How to add an encounter id by comparing column values in PySpark

izj3ouym  asked on 2021-05-27  in Spark

My CSV has two columns, dept and date. I want to add a new column named id. Conditions:
The default id is 1.
If the current dept == the next dept and the date difference is greater than 2 days, the id increments.
If the current dept == the next dept and the date difference is 2 days or less, the id stays the same as the previous id.
If the current dept != the next dept, the id restarts at 1.
I implemented this in pandas with a row iterator; how can I do the same in PySpark (about 500,000 rows)? A minimal sketch of that row-iterator approach is shown after the expected result below.
Expected result:

+------+------------+----+
| dept |    date    | id |
+------+------------+----+
|    1 | 04/10/2018 |  1 |
|    1 | 27/11/2018 |  2 |
|    1 | 27/11/2018 |  2 |
|    1 | 27/11/2018 |  2 |
|    1 | 27/12/2018 |  3 |
|    1 | 27/01/2019 |  4 |
|    1 | 27/02/2019 |  5 |
|    1 | 27/03/2019 |  6 |
|    1 | 27/04/2019 |  7 |
|    1 | 27/05/2019 |  8 |
|    2 | 28/12/2018 |  1 |
|    2 | 28/12/2018 |  1 |
|    2 | 28/12/2018 |  1 |
|    2 | 09/01/2019 |  2 |
|    2 | 09/01/2019 |  2 |
|    2 | 15/02/2019 |  3 |
|    2 | 15/02/2019 |  3 |
|    2 | 15/02/2019 |  3 |
|    2 | 28/02/2019 |  4 |
|    2 | 28/02/2019 |  4 |
|    2 | 02/04/2019 |  5 |
|    2 | 08/04/2019 |  6 |
|    2 | 08/04/2019 |  6 |
|    2 | 08/04/2019 |  6 |
|    2 | 09/04/2019 |  6 |
|    2 | 10/04/2019 |  6 |
|    2 | 10/04/2019 |  6 |
|    2 | 29/04/2019 |  7 |
|    2 | 06/02/2019 |  8 |
|    2 | 06/02/2019 |  8 |
|    2 | 06/02/2019 |  8 |
|    2 | 06/02/2019 |  8 |
|    2 | 06/02/2019 |  8 |
|    2 | 20/09/2018 |  9 |
|    2 | 20/09/2018 |  9 |
|    2 | 05/10/2018 | 10 |
|    2 | 05/10/2018 | 10 |
|    2 | 22/03/2019 | 11 |
|    2 | 22/03/2019 | 11 |
|    2 | 17/05/2019 | 12 |
|    3 | 20/09/2018 |  1 |
|    3 | 20/09/2018 |  1 |
|    3 | 20/09/2018 |  1 |
|    3 | 12/10/2018 |  2 |
|    3 | 12/10/2018 |  2 |
|    3 | 09/11/2018 |  3 |
|    3 | 20/12/2018 |  4 |
|    3 | 22/03/2019 |  5 |
|    3 | 22/03/2019 |  5 |
|    3 | 09/04/2019 |  6 |
+------+------------+----+
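
For reference, a minimal pandas sketch of the row-iterator approach mentioned above could look like this (simplified and assumed rather than the exact original code; "input.csv" is a placeholder, and it assumes the rows are already grouped by dept and ordered by date):

import pandas as pd

df = pd.read_csv("input.csv")   # placeholder file with columns dept, date
df["date"] = pd.to_datetime(df["date"], format="%d/%m/%Y")

ids = []
prev_dept, prev_date, current_id = None, None, 1
for row in df.itertuples(index=False):
    if row.dept != prev_dept:
        current_id = 1                      # new dept: restart at 1
    elif (row.date - prev_date).days > 2:
        current_id += 1                     # same dept, gap > 2 days: increment
    # same dept, gap of 2 days or less: keep the previous id
    ids.append(current_id)
    prev_dept, prev_date = row.dept, row.date

df["id"] = ids
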
smdncfj3 #1

I had to tweak your data a little, because the dates in your example are not quite in order.
For your use case, a lag combined with a cumulative sum should do the trick.
Here is a working PySpark version:

df = spark.createDataFrame(
    [
    (1, "04/10/2018", 1),
    (1, "27/11/2018", 2),
    (1, "27/11/2018", 2),
    (1, "27/11/2018", 2),
    (1, "27/12/2018", 3),
    (1, "27/01/2019", 4),
    (1, "27/02/2019", 5),
    (1, "27/03/2019", 6),
    (1, "27/04/2019", 7),
    (1, "27/05/2019", 8),
    (2, "28/12/2018", 1),
    (2, "28/12/2018", 1),
    (2, "28/12/2018", 1),
    (2, "09/01/2019", 2),
    (2, "09/01/2019", 2),
    (2, "15/02/2019", 3),
    (2, "15/02/2019", 3),
    (2, "15/02/2019", 3),
    (2, "28/02/2019", 4),
    (2, "28/02/2019", 4),
    (2, "02/04/2019", 5),
    (2, "08/04/2019", 6),
    (2, "08/04/2019", 6),
    (2, "08/04/2019", 6),
    (2, "09/04/2019", 6),
    (2, "10/04/2019", 6),
    (2, "10/04/2019", 6),
    (2, "29/04/2019", 7),
    (2, "06/05/2019", 8),
    (2, "06/05/2019", 8),
    (2, "06/05/2019", 8),
    (2, "06/05/2019", 8),
    (2, "06/05/2019", 8),
    (2, "20/09/2019", 9),
    (2, "20/09/2019", 9),
    (2, "05/10/2019", 10),
    (2, "05/10/2019", 10),
    (2, "22/03/2020", 11),
    (2, "22/03/2020", 11),
    (2, "17/05/2020", 12),
    (3, "20/09/2018", 1),
    (3, "20/09/2018", 1),
    (3, "20/09/2018", 1),
    (3, "12/10/2018", 2),
    (3, "12/10/2018", 2),
    (3, "09/11/2018", 3),
    (3, "20/12/2018", 4),
    (3, "22/03/2019", 5),
    (3, "22/03/2019", 5),
    (3, "09/04/2019", 6),
    ],
    ['dept', 'date', 'target']
)
from pyspark.sql.functions import col, to_timestamp, when, coalesce, lit, datediff, lag, sum
from pyspark.sql import Window

# per-dept window ordered by the parsed date, used for lag()
window = Window.partitionBy("dept").orderBy("date_parsed")

# same partitioning and ordering, but as a running window for the cumulative sum;
# rangeBetween lets rows with an identical date share the same running total
window_cusum = (
    Window
    .partitionBy('dept')
    .orderBy('date_parsed')
    .rangeBetween(Window.unboundedPreceding, 0)
)

final_df = (
    df
    # parse the dd/MM/yyyy strings into timestamps
    .withColumn('date_parsed', to_timestamp(col('date'), 'dd/MM/yyyy'))
    # diff is True when the gap to the previous row of the same dept is 2 days or less
    .withColumn('diff',
        when(
            datediff(col("date_parsed"), lag(col("date_parsed")).over(window)) <= 2,
            True
        ).otherwise(False)
    )
    # running count of the rows where diff is False (first row of a dept or a gap
    # greater than 2 days); this count is exactly the wanted id
    .withColumn('cusum_of_false',
        sum(
            when(~ col("diff"), lit(1)
            ).otherwise(lit(0))
        ).over(window_cusum)
    )
    # compare against the expected id (target) from the question
    .withColumn("check_working", col("target") == col("cusum_of_false"))
)

final_df.orderBy("dept", "date_parsed", "cusum_of_false").show()

# sanity check: every computed id should match the expected target column
row_count = final_df.count()
check_working = final_df.agg(sum(when(col("check_working"), lit(1)).otherwise(lit(0)))).collect()[0][0]

assert row_count == check_working
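
If you only need the id column from the question, cusum_of_false already holds it. A small follow-up sketch (the result name and the show() call are purely illustrative):

result = (
    final_df
    .select("dept", "date", col("cusum_of_false").alias("id"), "date_parsed")
    .orderBy("dept", "date_parsed")
    .drop("date_parsed")
)
result.show(60, truncate=False)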
