我正在寻找建议,以优化下面的代码。目前,在具有2个工作节点的emr集群上完成此代码大约需要25分钟。
wall_start_time = time.time()
## Experimenting with bringing all data onto the master node
uid_vector_df = document_vector_df.select('uid', 'Glove Document Vector')
print("#Rows = %d" % uid_vector_df.count())
print("#Columns = %d" % len(uid_vector_df.columns))
print("#Num Partitions = %d" % uid_vector_df.rdd.getNumPartitions())
print("Schema")
uid_vector_df.printSchema()
projected_data_on_master = uid_vector_df.collect()
total_wall_time = time.time() - wall_start_time
print("Wall Time taken (in secs) to collect data on master node = {}".format(total_wall_time))
这是print语句的输出
# Rows = 2000
# Columns = 2
# Num Partitions = 1
Schema
root
|-- uid: string (nullable = true)
|-- Glove Document Vector: array (nullable = true)
| |-- element: float (containsNull = true)
此代码花费的时间过长。我试过在本地pyspark安装程序上运行它,也试过在emr上运行它。以下是数据的性质。
每行由一个字符串(uid)和一个浮点向量组成。每个向量的大小是300。这意味着每行大约是64(字节)+300*8(字节),假设float是8字节,这意味着对于2000行,数据的大小小于5mb
我真的很惊讶为什么这么少量的数据要花这么长时间才能在主节点中收集到。请让我知道如何优化这段简单的代码。
谢谢
p、 我知道collect()操作是不受欢迎的,但是有一个合理的用例,我们希望在主节点上收集数据,使用faiss执行批聚类。因此,我不是在寻找如何完全避免collect()操作的建议
暂无答案!
目前还没有任何答案,快来回答吧!