如何在执行map之前获取特定分区的所有输入?

xlpyo6sf  于 2021-05-29  发布在  Spark
关注(0)|答案(1)|浏览(246)

我有以下代码:

partitions = 5
stitching_result = sc.\
    parallelize(probes_combination, partitions).\
    map(lambda l: stitch_images_pair(project, l)).\
    collect()

如何获得回调函数,该函数将获取特定分区的所有输入,并在之前执行 map(lambda l: stitch_images_pair(project, l)) 部分?
在我的例子中,它应该运行5次-每个分区一次。
提前谢谢。

fjaof16o

fjaof16o1#

可以使用mappartitions:

def func(iterator):
    print("hello world")
    return iterator

stitching_result = sc.\
    parallelize(a, partitions).\
    mapPartitions(func, True).\
    map(lambda l: stitch_images_pair(project, l)).\
    collect()

打印五次 hello world 执行前 stitch_images_pair .
在内部使用迭代器时 func 迭代器应该具体化为一个列表,并返回一个新的迭代器。以下代码打印每个分区中的行数:

def func(iterator):
    data = list(iterator)
    print(len(data))
    return iter(data)

相关问题