PySpark: get related records from an array column of IDs

t30tvxxf · posted 2021-05-17 in Spark

I have a Spark DataFrame with an ID column and, among other columns, an array column whose values are the IDs of related records.
A sample DataFrame:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam  | [789,999]
789 | marc | [111]
555 | dan  | [333]

From the above, I need to append all the transitively related child IDs to each parent ID's array column, so the expected output is:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam  | [789,999,111]
789 | marc | [111]
555 | dan  | [333]

I need your help with this. Thanks.


wtzytmuj 1#

One way to handle this task is to do a self left-join that updates RELATED_IDLIST over multiple iterations, until some conditions are satisfied (this approach only works when the max depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-separated StringType column, and use the SQL built-in function find_in_set plus a new column PROCESSED_IDLIST to set up the join conditions. The main logic is below:
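For context, the built-in find_in_set(str, str_list) returns the 1-based position of str inside the comma-separated str_list, and 0 when it is absent; a minimal sketch (assuming an active SparkSession named spark):

# find_in_set returns the 1-based position of a value inside a
# comma-separated string, or 0 when the value is missing
spark.sql("SELECT find_in_set('789', '345,456,789') AS hit, find_in_set('111', '345,456,789') AS miss").show()
# +---+----+
# |hit|miss|
# +---+----+
# |  3|   0|
# +---+----+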
Functions:

from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left-join
# and then returns another dataframe with exactly the same schema as the
# input dataframe. repeat the same until some conditions are satisfied

def recursive_join(d, max_iter=10):
  # function to find direct child-IDs and merge them into RELATED_IDLIST
  def find_child_idlist(_df):
    return _df.alias('d1').join(
        _df.alias('d2'),
        F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
        "left"
      ).groupby("d1.ID", "d1.NAME").agg(
        F.expr("""
          /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
           * and remove the trailing comma left when all d2.RELATED_IDLIST are NULL */
          trim(TRAILING ',' FROM
              concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
          ) as RELATED_IDLIST"""),
        F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
    )
  # below is the main code logic
  d = find_child_idlist(d).persist()
  if (d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0) and (max_iter > 1):
    d = recursive_join(d, max_iter-1)
  return d

# define a pandas_udf to remove duplicates from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([ list(set(x)) for x in s ]), "array<int>")

Where:

In the function find_child_idlist(), the left join must satisfy the following two conditions:

- d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID, d1.RELATED_IDLIST) > 0
- d2.ID is not in d1.PROCESSED_IDLIST: find_in_set(d2.ID, d1.PROCESSED_IDLIST) < 1

recursive_join() exits when no rows satisfy RELATED_IDLIST != PROCESSED_IDLIST, or when max_iter has been reduced to 1.
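The same stopping logic can also be written without recursion; below is a minimal iterative sketch (an illustration only, assuming find_child_idlist() is lifted out of recursive_join() to module scope):

# hypothetical iterative equivalent of recursive_join(); assumes
# find_child_idlist() is defined at module scope
def iterative_join(d, max_iter=10):
  for _ in range(max_iter):
    d = find_child_idlist(d).persist()
    # stop as soon as no row picked up new child-IDs in this pass
    if d.filter("RELATED_IDLIST != PROCESSED_IDLIST").count() == 0:
      break
  return d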
Processing:

Set up the DataFrame:

df = spark.createDataFrame([
  (123, "mike", [345,456]), (345, "alen", [789]), (456, "sam", [789,999]),
  (789, "marc", [111]), (555, "dan", [333])
],["ID", "NAME", "RELATED_IDLIST"])

Add a new column PROCESSED_IDLIST to save the RELATED_IDLIST from the previous join, then run recursive_join():
df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
    .withColumn('PROCESSED_IDLIST', F.col('ID'))

df_new = recursive_join(df1, 5)
df_new.show(10,0)
+---+----+-----------------------+-----------------------+
|ID |NAME|RELATED_IDLIST         |PROCESSED_IDLIST       |
+---+----+-----------------------+-----------------------+
|555|dan |333                    |333                    |
|789|marc|111                    |111                    |
|345|alen|789,111                |789,111                |
|123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
|456|sam |789,999,111            |789,999,111            |
+---+----+-----------------------+-----------------------+

Notice that mike's RELATED_IDLIST now contains 789 twice (reached through both 345 and 456). Split RELATED_IDLIST into an array of integers, then use the pandas_udf defined earlier to remove the duplicate array elements:

df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
+---+----+-------------------------+-----------------------+
|ID |NAME|RELATED_IDLIST           |PROCESSED_IDLIST       |
+---+----+-------------------------+-----------------------+
|555|dan |[333]                    |333                    |
|789|marc|[111]                    |111                    |
|345|alen|[789, 111]               |789,111                |
|123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
|456|sam |[111, 789, 999]          |789,999,111            |
+---+----+-------------------------+-----------------------+
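If your cluster runs Spark 2.4 or later, the pandas_udf is not needed at all: the built-in array_distinct function can deduplicate the array directly. A minimal sketch of that variant:

# Spark 2.4+ variant (sketch): deduplicate with the built-in
# array_distinct instead of the pandas_udf
df_new.withColumn(
    "RELATED_IDLIST",
    F.array_distinct(F.split("RELATED_IDLIST", ",").cast("array<int>"))
).show(10, 0)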
