scala—查找具有一定公共值的所有用户对

q9yhzks0  于 2021-06-27  发布在  Hive
关注(0)|答案(2)|浏览(321)

我是spark的新手,我正试图找到一些关于我已经转换成两个独立Dataframe的数据列表的具体信息。
这两个Dataframe是:

Users:                item_Details:
user_id | item_id     item_id | item_name
-----------------     ----------------------
  1     | 123           123   |  phone
  2     | 223           223   |  game
  3     | 423           423   |  foo
  2     | 1223          1223  |  bar
  1     | 3213          3213  | foobar

我需要找到所有对用户有超过50个共同的项目和项目的数量排序。不能有重复项,这意味着只有一组userid1和userid2。
结果需要如下所示:

user_id1 | user_id2 | count_of_items | list_of_items
-------------------------------------------------------------
    1    |     2    |       51       |  phone,foo,bar,foobar
4si2a6ki

4si2a6ki1#

有一种方法:
装配 item pairs 通过自连接的每个不同用户对
生成 common itemsitem pairs 使用自定义项
按特定的公用项计数筛选结果数据集
如下图所示:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row

val users = Seq(
  (1, 123), (1, 223), (1, 423),
  (2, 123), (2, 423), (2, 1223), (2, 3213),
  (3, 223), (3, 423), (3, 1223), (3, 3213),
  (4, 123), (4, 1223), (4, 3213)
).toDF("user_id", "item_id")

val item_details = Seq(
  (123, "phone"), (223, "game"), (423, "foo"), (1223, "bar"), (3213, "foobar")
)toDF("item_id", "item_name")

val commonItems = udf( (itemPairs: Seq[Row]) =>
  itemPairs.collect{ case Row(a: Int, b: Int) if a == b => a }
)

val commonLimit = 2  // Replace this with any specific common item count

val user_common_items =
  users.as("u1").join(users.as("u2"), $"u1.user_id" < $"u2.user_id").
  groupBy($"u1.user_id", $"u2.user_id").agg(
    collect_set(
      struct($"u1.item_id".as("ui1"), $"u2.item_id".as("ui2"))
    ).as("item_pairs")).
  withColumn("common_items", commonItems($"item_pairs")).
  drop("item_pairs").
  where(size($"common_items") > commonLimit)

user_common_items.show(false)
// +-------+-------+-----------------+
// |user_id|user_id|common_items     |
// +-------+-------+-----------------+
// |2      |3      |[423, 3213, 1223]|
// |2      |4      |[3213, 123, 1223]|
// +-------+-------+-----------------+

如果需要通用项目名称而不是项目ID,则可以加入 item_details 在上述步骤中对项目名称进行汇总;或者,可以分解现有的 common item ids 加入 item_details 还有一个 collect_list 按用户对聚合:

user_common_items.
  withColumn("item_id", explode($"common_items")).
  join(item_details, Seq("item_id")).
  groupBy($"u1.user_id", $"u2.user_id").agg(collect_list($"item_name").as("common_items")).
  withColumn("item_count", size($"common_items")).
  show
// +-------+-------+--------------------+----------+
// |user_id|user_id|        common_items|item_count|
// +-------+-------+--------------------+----------+
// |      2|      3|  [foo, foobar, bar]|         3|
// |      2|      4|[foobar, phone, bar]|         3|
// +-------+-------+--------------------+----------+
ukxgm1gy

ukxgm1gy2#

另一个解决方案,不使用自定义项。因为我们需要公共项,所以可以在joinexprs本身中给出匹配。看看这个

val users = Seq(
  (1, 123), (1, 223), (1, 423),
  (2, 123), (2, 423), (2, 1223), (2, 3213),
  (3, 223), (3, 423), (3, 1223), (3, 3213),
  (4, 123), (4, 1223), (4, 3213)
).toDF("user_id", "item_id")

val items = Seq(
  (123, "phone"), (223, "game"), (423, "foo"), (1223, "bar"), (3213, "foobar")
)toDF("item_id", "item_name")

val common_items =
  users.as("t1").join(users.as("t2"),$"t1.user_id" < $"t2.user_id" and $"t1.item_id" === $"t2.item_id" )
      .join(items.as("it"),$"t1.item_id"===$"it.item_id","inner")
      .groupBy($"t1.user_id",$"t2.user_id")
      .agg(collect_set('item_name).as("items"))
      .filter(size('items)>2) // change here for count
      .withColumn("size",size('items))

common_items.show(false)

结果

+-------+-------+--------------------+----+
|user_id|user_id|items               |size|
+-------+-------+--------------------+----+
|2      |3      |[bar, foo, foobar]  |3   |
|2      |4      |[bar, foobar, phone]|3   |
+-------+-------+--------------------+----+

相关问题