I am trying to modify the "features" vector column by wiping some features (whose indices are stored in `feature_idx_to_wipe`). Pseudocode is below; the problem is that the UDF does not accept a `Set`. I would like to know how to fix this, or whether there is a better approach.
//data
val feature_idx_to_wipe = Set(1, 2)
val dfA = spark.createDataFrame(Seq(
(0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
(1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
(2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
+---+-------------------------+
|id |features |
+---+-------------------------+
|0 |(6,[0,1,2],[1.0,1.0,1.0])|
|1 |(6,[2,3,4],[1.0,1.0,1.0])|
|2 |(6,[0,2,4],[1.0,1.0,1.0])|
+---+-------------------------+
//udf
// Assumed aliased imports (NewSparseVector / NewVectors stand for the ml.linalg types):
// import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector, Vectors => NewVectors}
import scala.collection.mutable.ListBuffer

def wipe(v: NewSparseVector, idx2clean: Set[Int]): NewSparseVector = {
  val lb: ListBuffer[(Int, Double)] = ListBuffer()
  v.foreachActive { case (i, value) => // `value` avoids shadowing the outer `v`
    if (!idx2clean.contains(i)) {
      lb += ((i, value))
    }
  }
  NewVectors.sparse(v.size, lb.toSeq).toSparse
}
val udf_wipe = udf((x: NewSparseVector, idx2clean: Set[Int]) => wipe(x, idx2clean))
//apply udf
dfA.withColumn("features_wiped", udf_wipe(col("features"), feature_idx_to_wipe))
// error (tc below refers to the same Set as feature_idx_to_wipe):
// scala> dfA.withColumn("nf", udf_wipe(col("features"), tc))
// <console>:98: error: type mismatch;
// found : scala.collection.immutable.Set[Int]
// required: org.apache.spark.sql.Column
// dfA.withColumn("nf", udf_wipe(col("features"), tc))
// target: a new vector column is added, with the features at indices 1 and 2 removed
dfA.select("id","features_wiped").show(false)
+---+-------------------------+
|id |features_wiped |
+---+-------------------------+
|0 |(6,[0],[1.0]) |
|1 |(6,[3,4],[1.0,1.0]) |
|2 |(6,[0,4],[1.0,1.0]) |
+---+-------------------------+
2 Answers

zqdjd7g91#
The function `wipe` can be converted into a curried function as follows:
Then create a UDF for the curried function:
Finally, apply the UDF to the DataFrame:
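The answer's code blocks did not survive above, so here is a hedged sketch of the currying idea. The filtering core is shown on plain `(index, value)` pairs so it runs without Spark; `wipeActive` and `wiper` are hypothetical names, and the commented lines assume the question's `dfA`, `feature_idx_to_wipe`, and aliased vector types.

```scala
// Curried form: the Set goes in the first parameter list, the data in the second.
def wipeActive(idx2clean: Set[Int])(active: Seq[(Int, Double)]): Seq[(Int, Double)] =
  active.filterNot { case (i, _) => idx2clean.contains(i) }

// Partially applying the first parameter list bakes the Set in,
// leaving a one-argument function:
val wiper: Seq[(Int, Double)] => Seq[(Int, Double)] = wipeActive(Set(1, 2))

// With the question's types the same shape would be
//   def wipe(idx2clean: Set[Int])(v: NewSparseVector): NewSparseVector = ...
// and the UDF then takes a single Column, which removes the type mismatch:
//   val udf_wipe = udf(wipe(feature_idx_to_wipe) _)
//   dfA.withColumn("features_wiped", udf_wipe(col("features")))
```

Because the Set is captured when the UDF is built, only vector columns cross the driver/executor boundary at call time.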
s2j5cfk02#
Another option -
Test data:
Alternative 1: use `lit`, as below -
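The "lit" here is presumably Spark's column-literal helper; plain `lit` cannot wrap a Scala collection, so `typedLit` (Spark 2.2+) is the usual choice, and the resulting array column arrives in a Scala UDF as `Seq[Int]`. A hedged sketch, where `wipeFromSeq` is a hypothetical helper and the commented lines assume the question's `dfA` and `wipe`:

```scala
// The literal array column reaches the UDF as Seq[Int], not Set[Int],
// so rebuild the Set inside before filtering.
def wipeFromSeq(active: Seq[(Int, Double)], idx: Seq[Int]): Seq[(Int, Double)] = {
  val idx2clean = idx.toSet
  active.filterNot { case (i, _) => idx2clean.contains(i) }
}

// Spark wiring (needs a live SparkSession, so shown as comments):
//   import org.apache.spark.sql.functions.{udf, col, typedLit}
//   val udf_wipe = udf((v: NewSparseVector, idx: Seq[Int]) => wipe(v, idx.toSet))
//   dfA.withColumn("features_wiped",
//     udf_wipe(col("features"), typedLit(feature_idx_to_wipe.toSeq)))
```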
Alternative 2: use a broadcast variable via `sparkContext.broadcast`, as below -
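A hedged sketch of the broadcast pattern: the Set is shipped to executors once and the UDF closes over it, so the UDF again takes only one column. The closure is shown on plain pairs; the commented lines assume the question's `spark`, `dfA`, `feature_idx_to_wipe`, and `wipe`.

```scala
// The essential move: close over the shared value instead of passing it as a
// column. Locally this is just a captured val; in Spark it would be bc.value.
val idx2clean = Set(1, 2)
val wiper = (active: Seq[(Int, Double)]) =>
  active.filterNot { case (i, _) => idx2clean.contains(i) }

// Spark wiring (needs a live SparkSession, so shown as comments):
//   val bc = spark.sparkContext.broadcast(feature_idx_to_wipe)
//   val udf_wipe = udf((v: NewSparseVector) => wipe(v, bc.value))
//   dfA.withColumn("features_wiped", udf_wipe(col("features")))
```

Broadcasting avoids re-serializing the Set into every task closure, which matters mainly when the shared value is large; for a two-element Set a plain captured val behaves the same.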