Apache Spark: TypeError when operating on arrays in PySpark

Posted by v9tzhpje on 2021-05-27 in Spark

I'm trying to compute the dot product (the sum of the element-wise products) between "user_features" and "movie_features":

+------+-------+--------------------+--------------------+
|userId|movieId|       user_features|      movie_features|
+------+-------+--------------------+--------------------+
|    18|      1|[0.0, 0.5, 0.0, 0...|[1, 0, 0, 0, 0, 1...|
|    18|      2|[0.1, 0.0, 0.0, 0...|[1, 0, 0, 0, 0, 0...|
|    18|      3|[0.2, 0.0, 0.3, 0...|[0, 0, 0, 0, 0, 1...|
|    18|      4|[0.0, 0.1, 0.0, 0...|[0, 0, 0, 0, 0, 1...|
+------+-------+--------------------+--------------------+

Data types:

df.printSchema()
_____________________________________________
root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- user_features: array (nullable = false)
 |    |-- element: double (containsNull = true)
 |-- movie_features: array (nullable = false)
 |    |-- element: float (containsNull = true)

None

This is what I'm using:

from pyspark.sql import functions as F
from pyspark.sql.functions import udf

class Solution:
    """
    Data reading, pre-processing...
    """
    @udf("array<double>")
    def miltiply(self, x, y):
        if x and y:
            return [float(a * b) for a, b in zip(x, y)]

    def get_dot_product(self):

        df = self.user_DF.crossJoin(self.movies_DF)
        output = df.withColumn("zipxy", self.miltiply("user_features", "movie_features")) \
                   .withColumn('sumxy', sum([F.col('zipxy').getItem(i) for i in range(20)]))

This raises the following error:

TypeError: Invalid argument, not a string or column: <__main__.Solution instance at 0x000000000A777EC8> of type <type 'instance'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

What am I missing? I'm doing this with a udf because I'm on Spark 1.6, so I can't use the aggregate or zip_with functions.

Answer #1 (epggiuax)

If you can use numpy, then:

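import numpy as np

# Small sample DataFrame with the same column layout as in the question.
df = spark.createDataFrame([(18, 1, [1, 0, 1], [1, 1, 1])]).toDF('userId', 'movieId', 'user_features', 'movie_features')

# Compute the dot product per row on the underlying RDD and append it as a new 'dot' column.
df.rdd.map(lambda x: (x[0], x[1], x[2], x[3], float(np.dot(np.array(x[2]), np.array(x[3]))))).toDF(df.columns + ['dot']).show()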

+------+-------+-------------+--------------+---+
|userId|movieId|user_features|movie_features|dot|
+------+-------+-------------+--------------+---+
|    18|      1|    [1, 0, 1]|     [1, 1, 1]|2.0|
+------+-------+-------------+--------------+---+
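
A note on the error itself, offered as a sketch rather than a tested fix for this exact setup: the message suggests that the `Solution` instance ends up as the first argument passed to the UDF. Since `miltiply` is decorated inside the class and called as `self.miltiply(...)`, Python appears to bind `self` as the first positional argument, and Spark then tries to interpret it as a column. One way around this is to keep the function free of `self`, for example by wrapping a module-level function. The names `dot_product` and `add_dot_column` below are hypothetical, and the imports are placed inside the helper only so that the plain function can be tried without Spark installed.

```python
def dot_product(x, y):
    """Plain-Python dot product; returns None if either array is missing."""
    if x is None or y is None:
        return None
    # zip stops at the shorter array, as in the original list comprehension.
    return float(sum(a * b for a, b in zip(x, y)))


def add_dot_column(df):
    """Append a 'dot' column to df.

    Assumes df has array columns named user_features and movie_features.
    """
    # Imported here so that dot_product stays usable without a Spark install.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    # The UDF wraps a module-level function, so no `self` is ever passed in.
    dot_udf = udf(dot_product, DoubleType())
    return df.withColumn("dot", dot_udf("user_features", "movie_features"))
```

With the DataFrame from the question this would be called as `output = add_dot_column(df)`. If the UDF has to stay inside the class, declaring it as a `@staticmethod` without the `self` parameter should have the same effect.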
