How to efficiently use a window function to decide the next n rows based on the previous n values

qyuhtwio · published 2021-05-18 · in Spark
Follow (0) | Answers (1) | Views (277)

Hi, I have the following data:

+----------+----+-------+-----------------------+
|      date|item|avg_val|conditions             |
+----------+----+-------+-----------------------+
|01-10-2020|   x|     10|                      0|
|02-10-2020|   x|     10|                      0|
|03-10-2020|   x|     15|                      1|
|04-10-2020|   x|     15|                      1|
|05-10-2020|   x|      5|                      0|
|06-10-2020|   x|     13|                      1|
|07-10-2020|   x|     10|                      1|
|08-10-2020|   x|     10|                      0|
|09-10-2020|   x|     15|                      1|
|01-10-2020|   y|     10|                      0|
|02-10-2020|   y|     18|                      0|
|03-10-2020|   y|      6|                      1|
|04-10-2020|   y|     10|                      0|
|05-10-2020|   y|     20|                      0|
+----------+----+-------+-----------------------+

I am trying to create a new column called flag, based on the conditions column:
If conditions is 0, the new column value is 0.
If conditions is 1, the new column value is 1 and the next n rows are forced to 0, i.e. the next n values do not need to be checked. This logic is applied per item, i.e. the partitioning is by item.
I used n=4 here.
I tried the code below, but it does not use windowing efficiently; is there an optimized way to do this?

DROP TEMPORARY TABLE t2;
CREATE TEMPORARY TABLE t2
SELECT *,
MAX(conditions) OVER (PARTITION BY item ORDER BY item,`date` ROWS 4 PRECEDING ) AS new_row
FROM record
ORDER BY item,`date`;

DROP TEMPORARY TABLE t3;
CREATE TEMPORARY TABLE t3
SELECT *,ROW_NUMBER() OVER (PARTITION BY item,new_row ORDER BY item,`date`) AS e FROM t2;

SELECT *,CASE WHEN new_row=1 AND e%5>1 THEN 0 
WHEN new_row=1 AND e%5=1 THEN 1 ELSE 0 END AS flag FROM t3;

The expected output looks like:

+----------+----+-------+-----------------------+-----+
|      date|item|avg_val|conditions             |flag |
+----------+----+-------+-----------------------+-----+
|01-10-2020|   x|     10|                      0|    0|
|02-10-2020|   x|     10|                      0|    0|
|03-10-2020|   x|     15|                      1|    1|
|04-10-2020|   x|     15|                      1|    0|
|05-10-2020|   x|      5|                      0|    0|
|06-10-2020|   x|     13|                      1|    0|
|07-10-2020|   x|     10|                      1|    0|
|08-10-2020|   x|     10|                      0|    0|
|09-10-2020|   x|     15|                      1|    1|
|01-10-2020|   y|     10|                      0|    0|
|02-10-2020|   y|     18|                      0|    0|
|03-10-2020|   y|      6|                      1|    1|
|04-10-2020|   y|     10|                      0|    0|
|05-10-2020|   y|     20|                      0|    0|
+----------+----+-------+-----------------------+-----+

But I am unable to get that output, even after trying several variations.
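To make the intended logic concrete, here is a minimal plain-Python sketch of the per-item rule (the function name `flag_rows` is illustrative, not part of the original code):

```python
def flag_rows(conditions, n=4):
    """Set flag=1 when conditions==1, then suppress the flag for the next n rows."""
    flags, cooldown = [], 0
    for c in conditions:
        if cooldown > 0:       # inside the "skip the next n rows" window
            flags.append(0)
            cooldown -= 1
        elif c == 1:           # flag fires and opens a new window of length n
            flags.append(1)
            cooldown = n
        else:
            flags.append(0)
    return flags

# item x, conditions in date order:
print(flag_rows([0, 0, 1, 1, 0, 1, 1, 0, 1]))  # [0, 0, 1, 0, 0, 0, 0, 0, 1]
```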

Answer by qacovj5a:

As suggested in the comments (by @nbk and @akina), you will need some kind of iterator to implement this logic. With Spark SQL and Spark 2.4+, we can use the built-in function aggregate with an array of structs plus a counter as the accumulator. Below, the sample data is registered as a temp view named record (assuming the values in the conditions column are either 0 or 1):

val df = Seq(
    ("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
    ("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
    ("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
    ("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
    ("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")

df.createOrReplaceTempView("record")

The SQL statement:

spark.sql("""
  SELECT t1.item, m.*
  FROM (
    SELECT item,
      sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
    FROM record
    GROUP BY item
  ) as t1 LATERAL VIEW OUTER inline(
    aggregate(
      /* expr: set up array `dta` from the 2nd element to the last 
       *       notice that indices for slice function is 1-based, dta[i] is 0-based
       */
      slice(dta,2,size(dta)),
      /* start: set up and initialize `acc` to a struct containing two fields:
       * - dta: an array of structs with a single element dta[0]
       * - counter: number of rows after flag=1, can be from `0` to `N+1`
       */
      (array(dta[0]) as dta, dta[0].conditions as counter),
      /* merge: iterate through the `expr` using x and update two fields of `acc`
       * - dta: append values from x to acc.dta array using concat + array functions
       *        update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
       * - counter: increment by 1 if acc.counter is between 1 and 4
       *            , otherwise set value to x.conditions
       */
      (acc, x) -> named_struct(
          'dta', concat(acc.dta, array(named_struct(
              'date', x.date,
              'avg_val', x.avg_val,
              'conditions', x.conditions,
              'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
            ))),
          'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
        ),
      /* finish: retrieve acc.dta only and discard acc.counter */
      acc -> acc.dta
    )
  ) m
""").show(50)

Result:

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   0|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+

Where:
Use GROUP BY to collect rows of the same item into an array of structs named dta, each struct holding 4 fields — date, avg_val, conditions and flag — sorted by date.
Use the aggregate function to iterate through this array of structs, updating the flag field based on the counter and conditions (see the comments in the SQL above for details).
Use LATERAL VIEW with the inline function to explode the resulting array of structs produced by the aggregate function.
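The accumulator semantics of the merge step can be mirrored in plain Python (a sketch only, with N=4; the name `aggregate_flags` is illustrative) to check that they reproduce the result above:

```python
from functools import reduce

def aggregate_flags(conditions, n=4):
    """Mirror of the SQL aggregate: acc = (flags, counter), initialized from dta[0]."""
    def merge(acc, c):
        flags, counter = acc
        # flag fires only when counter is 0 (idle) or n+1 (window just closed)
        flag = 1 if counter in (0, n + 1) and c == 1 else 0
        # counter advances while in 1..n, otherwise restarts from the current condition
        counter = counter + 1 if 0 < counter < n + 1 else c
        return (flags + [flag], counter)
    first = conditions[0]
    flags, _ = reduce(merge, conditions[1:], ([first], first))
    return flags
```

Running it on item x's conditions, in date order, gives the same flags as the SQL output for both N=4 and the N=2 variant discussed in the notes below.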
Notes:
(1) The suggested SQL is for N=4, via acc.counter IN (0,5) and acc.counter < 5 in the SQL. For an arbitrary N, adjust these to acc.counter IN (0,N+1) and acc.counter < N+1. Shown below is N=2 on the same sample data:

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   1|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+

(2) We use dta[0] to initialize acc, which fixes both the values and the data types of its fields. Ideally, we should make sure these fields have the correct data types so that all calculations are carried out correctly. For example, when computing acc.counter, if conditions is StringType, acc.counter+1 returns a StringType holding a DoubleType value:

spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
|                                    3.0|
+---------------------------------------+

When comparing with acc.counter IN (0,5) or acc.counter < 5, this produced incorrect results (per the OP's feedback) without any warning or error message.
One workaround is to specify the exact field types with CAST when setting up the second argument of the aggregate function, so that any type mismatch is reported as an error, as shown below:

CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),

Another solution is to force the types when creating the dta column; in this example, see int(conditions) as conditions in the following code:

SELECT item,
  sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
FROM record
GROUP BY item

We can also force the data types inside the calculations, e.g. int(acc.counter+1) below:

IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)
