如何在spark中获得数组列的所有组合?

ctzwtxfj  于 2021-07-14  发布在  Spark
关注(0)|答案(3)|浏览(363)

假设我有一个数组列 group_ids ```
+-------+----------+
|user_id|group_ids |
+-------+----------+
|1 |[5, 8] |
|3 |[1, 2, 3] |
|2 |[1, 4] |
+-------+----------+

架构:

root
|-- user_id: integer (nullable = false)
|-- group_ids: array (nullable = false)
| |-- element: integer (containsNull = false)

我想得到所有成对的组合:

+-------+------------------------+
|user_id|group_ids |
+-------+------------------------+
|1 |5, 8 |
|3 |[[1, 2], [1, 3], [2, 3]]|
|2 |1, 4 |
+-------+------------------------+

到目前为止,我用自定义项创建了最简单的解决方案:

spark.udf.register("permutate", udf((xs: Seq[Int]) => xs.combinations(2).toSeq))

dataset.withColumn("group_ids", expr("permutate(group_ids)"))

我要找的是通过spark内置函数实现的东西。有没有一种方法可以在没有自定义项的情况下实现相同的代码?
ru9i0ody

ru9i0ody1#

一些高阶函数可以做到这一点。需要Spark>=2.4。

val df2 = df.withColumn(
    "group_ids", 
    expr("""
        filter(
            transform(
                flatten(
                    transform(
                        group_ids, 
                        x -> arrays_zip(
                            array_repeat(x, size(group_ids)), 
                            group_ids
                        )
                    )
                ), 
                x -> array(x['0'], x['group_ids'])
            ), 
            x -> x[0] < x[1]
        )
    """)
)

df2.show(false)
+-------+------------------------+
|user_id|group_ids               |
+-------+------------------------+
|1      |[[5, 8]]                |
|3      |[[1, 2], [1, 3], [2, 3]]|
|2      |[[1, 4]]                |
+-------+------------------------+
jyztefdp

jyztefdp2#

基于 explode 以及 joins 解决方案

val exploded = df.select(col("user_id"), explode(col("group_ids")).as("e"))

// to have combinations
val joined1 = exploded.as("t1")
                      .join(exploded.as("t2"), Seq("user_id"), "outer")
                      .select(col("user_id"), col("t1.e").as("e1"), col("t2.e").as("e2"))

// to filter out redundant combinations
val joined2 = joined1.as("t1")
                     .join(joined1.as("t2"), $"t1.user_id" === $"t2.user_id" && $"t1.e1" === $"t2.e2" && $"t1.e2"=== $"t2.e1")
                     .where("t1.e1 < t2.e1")
                     .select("t1.*")

// group into array
val result = joined2.groupBy("user_id")
                    .agg(collect_set(struct("e1", "e2")).as("group_ids"))
ikfrs5lh

ikfrs5lh3#

可以得到列的最大大小 group_ids . 然后,在范围内使用组合 (1 - maxSize)when 表达式从原始数组创建子数组组合,并最终从结果数组中筛选空元素:

val maxSize = df.select(max(size($"group_ids"))).first.getAs[Int](0)

val newCol = (1 to maxSize).combinations(2)
  .map(c =>
    when(
      size($"group_ids") >= c(1),
      array(element_at($"group_ids", c(0)), element_at($"group_ids", c(1)))
    )
  ).toSeq

df.withColumn("group_ids", array(newCol: _*))
  .withColumn("group_ids", expr("filter(group_ids, x -> x is not null)"))
  .show(false)

//+-------+------------------------+
//|user_id|group_ids               |
//+-------+------------------------+
//|1      |[[5, 8]]                |
//|3      |[[1, 2], [1, 3], [2, 3]]|
//|2      |[[1, 4]]                |
//+-------+------------------------+

相关问题