Spark scala -根据另一列中定义的列名获取列值

kqqjbcuj  于 8个月前  发布在  Scala
关注(0)|答案(5)|浏览(101)

我有一个类似于以下的框架:

+-----------+-----------+---------------+------+---------------------+                                        
|best_col     |A          |B              |  C   |<many more columns>  |
+-----------+-----------+---------------+------+---------------------+
|     A#B     |    14     |        26     |  32  |       ...           |
|     C       |    13     |        17     |  96  |       ...           |
|     B#C     |    23     |        19     |  42  |       ...           |
+-----------+-----------+---------------+------+---------------------+

我想以这样的DataFrame结束:

+-----------+-----------+---------------+------+---------------------+----------+                                        
|best_col     |A          |B              |  C   |<many more columns>  | result   |
+-----------+-----------+---------------+------+---------------------+----------+
|     A#B     |    14     |        26     |  32  |       ...           |   14#26  |
|     C       |    13     |        17     |  96  |       ...           |   96     |
|     B#C     |    23     |        19     |  42  |       ...           |   19#42  |
+-----------+-----------+---------------+------+---------------------+----------+

本质上,我想添加一个列result,它将从best_col列中指定的列中选择值。best_col仅包含DataFrame中存在的列名。我不想检查像col(best_col) === A等。我试着做col(col("best_col").toString()),但这不起作用。有什么简单的方法吗?

laximzn5

laximzn51#

你不能完全动态地做到这一点,但你可以简单地使用一个case,当链和列出所有你可以支持的字段(*)时-显然它们必须都具有相同的类型。您的“我不想检查类似”是唯一的方法,AttributeReferences(列名)必须在查询期间固定。

  • 如果您可以构建when(col(bes_col)==“columnname”)系列调用,那么您可以使用模式来构造具有正确类型的字段,并接近完全动态。
qxgroojn

qxgroojn2#

您可以创建一个udf并使用withColumn将整行传递给它。并根据列A的值返回结果。
大概是这样的:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row

def concatFunc(row: Row) = ...
def combineUdf = udf((row: Row) => concatFunc(row))

然后使用这个:

df.withColumn("result", combineUdf(struct(columns.map(col): _*)))

// or try this as well if work
df.withColumn("result", combineUdf(*))
xmq68pz9

xmq68pz93#

下面是一个以编程方式生成适当的when序列的解决方案:

// cols is the list of considered columns
val selection = cols
    .foldLeft(when(col("best_col").isNull, lit(null)))
             ((cur, column) => cur.when(col("best_col") === column, col(column)))

df.withColumn("result", selection).show()
gwbalxhn

gwbalxhn4#

  • 将所需列转换为map< key, value >,其中key是列名,value是列值。
  • best_col列使用split函数拆分其值。
  • 使用transform高阶函数,遍历每个键,然后从map<string, string>数据类型中查找并获取值。
  • 使用concat_ws连接array< value >

下面是示例代码。

scala> df.printSchema
root
 |-- best_col: string (nullable = true)
 |-- A: integer (nullable = false)
 |-- B: integer (nullable = false)
 |-- C: integer (nullable = false)
val columnsExpr = df.columns
  .map(c => col(c).as(c)) ++ Seq(
  df.columns
    .filter(_ != "best_col")
    .map(c => map(lit(c), col(c)))
    .reduce(map_concat(_, _))
    .as("map_data")
)

// Exiting paste mode, now interpreting.

columnsExpr: Array[org.apache.spark.sql.Column] = Array(best_col AS best_col, A AS A, B AS B, C AS C, map_concat(map_concat(map(A, A), map(B, B)), map(C, C)) AS map_data)
val resultExpr =
  concat_ws(
       "#",
       transform(
          split($"best_col", "#"),
          t => col("map_data")(t)
       )
    )

// Exiting paste mode, now interpreting.

resultExpr: org.apache.spark.sql.Column = concat_ws(#, transform(split(best_col, #, -1), lambdafunction(map_data[x_20], x_20)))
scala> df
        .select(columnsExpr:_*)
        .show(false)

+--------+---+---+---+---------------------------+
|best_col|A  |B  |C  |map_data                   |
+--------+---+---+---+---------------------------+
|A#B     |14 |26 |32 |{A -> 14, B -> 26, C -> 32}|
|C       |13 |17 |96 |{A -> 13, B -> 17, C -> 96}|
|B#C     |23 |19 |42 |{A -> 23, B -> 19, C -> 42}|
+--------+---+---+---+---------------------------+
scala> df
        .select(columnsExpr:_*)
        .withColumn("result", resultExpr)
        .show(false)

+--------+---+---+---+---------------------------+------+
|best_col|A  |B  |C  |map_data                   |result|
+--------+---+---+---+---------------------------+------+
|A#B     |14 |26 |32 |{A -> 14, B -> 26, C -> 32}|14#26 |
|C       |13 |17 |96 |{A -> 13, B -> 17, C -> 96}|96    |
|B#C     |23 |19 |42 |{A -> 23, B -> 19, C -> 42}|19#42 |
+--------+---+---+---+---------------------------+------+

下面的方法将所有列转换为map<string, string>类型。

df
  .withColumn(
    "map_data",
    from_json(
        to_json(struct($"*")), 
        MapType(StringType, StringType)
    ) // Converting all columns to map<string, string>
  )
  .withColumn(
    "result",
    concat_ws(
      "#",
      transform(
        split($"best_col", "#"), // splitting best_col based on "#"
        t => col("map_data")(t)  // getting value from map data type & concating with #
      )
    )
  )
  .show(false)

使用udf

scala> val concat_cols = udf( 
    (map: Map[String, String], data: String) => 
        data.split("#").map(c => map.getOrElse(c, null)).reduceOption(_ +"#"+ _)
    )

concat_cols: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$4451/448249168@34adac5b,StringType,List(Some(class[value[0]: map<string,string>]), Some(class[value[0]: string])),Some(class[value[0]: string]),None,true,true)

scala> df
    .withColumn("map_data", 
        from_json(
            to_json(
                struct($"*")
            ), 
            MapType(StringType, StringType)
        )
    )
    .withColumn("result", concat_cols($"map_data", $"best_col"))
    .show(false)
+--------+---+---+---+--------------------------------------------+------+
|best_col|A  |B  |C  |map_data                                    |result|
+--------+---+---+---+--------------------------------------------+------+
|A#B     |14 |26 |32 |{best_col -> A#B, A -> 14, B -> 26, C -> 32}|14#26 |
|C       |13 |17 |96 |{best_col -> C, A -> 13, B -> 17, C -> 96}  |96    |
|B#C     |23 |19 |42 |{best_col -> B#C, A -> 23, B -> 19, C -> 42}|19#42 |
+--------+---+---+---+--------------------------------------------+------+
t98cgbkg

t98cgbkg5#

列“best_col”可以拆分,并且对于每个其他列可以添加技术列,如果列在“best_col”中不存在,则为空;以及使用“concat_ws”函数连接的列之后:

val df = Seq(
  ("A#B", 14, 26, 32),
  ("C", 13, 17, 96),
  ("B#C", 23, 19, 42)).toDF("best_col", "A", "B", "C")

val withBestColumns = df.withColumn("best_columns", split($"best_col", "#"))

val columnNames = df.columns.tail
val withAdditionalFields = columnNames
  .foldLeft(withBestColumns) { case (df, columnName) => df.withColumn(columnName + "_required",
    when(array_contains($"best_columns", columnName),col(columnName).cast(StringType)).otherwise(null))
  }

val technicalColumnNames = columnNames.map(name=>name + "_required")

val result = withAdditionalFields
  .withColumn("result",  concat_ws("#", technicalColumnNames.map(col): _*))
  .drop(technicalColumnNames: _*)
  .drop("best_columns")

结果是:

+--------+---+---+---+------+
|best_col|A  |B  |C  |result|
+--------+---+---+---+------+
|A#B     |14 |26 |32 |14#26 |
|C       |13 |17 |96 |96    |
|B#C     |23 |19 |42 |19#42 |
+--------+---+---+---+------+

相关问题