Passing a non-Column variable to a UDF

iibxawm4 · posted 2021-05-27 in Spark

I am trying to modify the "features" vector column by wiping some features (whose indexes are stored in `feature_idx_to_wipe`). Pseudocode is below; the problem is that the UDF does not accept a `Set`. How can I work around this, or is there a better way?

//data
import org.apache.spark.ml.linalg.Vectors

val feature_idx_to_wipe = Set(1, 2)
val dfA = spark.createDataFrame(Seq(
  (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
  (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
  (2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
+---+-------------------------+
|id |features                 |
+---+-------------------------+
|0  |(6,[0,1,2],[1.0,1.0,1.0])|
|1  |(6,[2,3,4],[1.0,1.0,1.0])|
|2  |(6,[0,2,4],[1.0,1.0,1.0])|
+---+-------------------------+

//udf
import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector, Vectors => NewVectors}
import scala.collection.mutable.ListBuffer

// keep only the active (index, value) pairs whose index is NOT in idx2clean
def wipe(v: NewSparseVector, idx2clean: Set[Int]): NewSparseVector = {
    val lb: ListBuffer[(Int, Double)] = ListBuffer()
    v.foreachActive {
        case (i, value) =>
            if (!idx2clean.contains(i)) {
                lb += ((i, value))
            }
    }
    NewVectors.sparse(v.size, lb.toSeq).toSparse
}
val udf_wipe = udf((x: NewSparseVector, idx2clean: Set[Int]) => wipe(x, idx2clean))

//apply udf
dfA.withColumn("features_wiped", udf_wipe(col("features"), feature_idx_to_wipe))
// error: 
// scala> dfA.withColumn("nf", udf_wipe(col("features"), tc))
// <console>:98: error: type mismatch;
//  found   : scala.collection.immutable.Set[Int]
//  required: org.apache.spark.sql.Column
//        dfA.withColumn("nf", udf_wipe(col("features"), tc))

//target (a new vector column is added, with the features at indexes 1 and 2 removed)
dfA.select("id","features_wiped").show(false)
+---+-------------------------+
|id |features_wiped           |
+---+-------------------------+
|0  |(6,[0],[1.0])            |
|1  |(6,[3,4],[1.0,1.0])      |
|2  |(6,[0,4],[1.0,1.0])      |
+---+-------------------------+

Answer 1 (zqdjd7g9)

You can turn `wipe` into a curried function, like this:

def wipe(v: NewSparseVector)(idx2clean:Set[Int]) : NewSparseVector

Then create the UDF from it, capturing the set in its closure:

val udf_wipe = udf((x: NewSparseVector) => wipe(x)(feature_idx_to_wipe))

Finally, apply the UDF to the DataFrame:

dfA.withColumn("features_wiped", udf_wipe(col("features")))
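This fix works through ordinary Scala closure capture: baking `feature_idx_to_wipe` into the function at UDF-creation time leaves a one-argument function, which is the shape `udf` can lift over a column. The pattern can be sketched in plain Scala outside Spark (the names `wipePairs` and `wipeFixed` here are illustrative, not from the answer):

```scala
// Curried filter over the active (index, value) pairs of a sparse vector:
// the first parameter list is the per-row data, the second the fixed set.
def wipePairs(pairs: Seq[(Int, Double)])(idx2clean: Set[Int]): Seq[(Int, Double)] =
  pairs.filterNot { case (i, _) => idx2clean.contains(i) }

// Partially apply once; the closure now carries the set, and the
// remaining one-argument function has the shape a UDF body needs.
val idxToWipe = Set(1, 2)
val wipeFixed: Seq[(Int, Double)] => Seq[(Int, Double)] =
  pairs => wipePairs(pairs)(idxToWipe)
```

Applying `wipeFixed` to the active entries of row 0, `Seq((0, 1.0), (1, 1.0), (2, 1.0))`, leaves `Seq((0, 1.0))`, matching the `(6,[0],[1.0])` row in the expected output.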

Answer 2 (s2j5cfk0)

Here is another option.

Test data:

//data
    val dfA = spark.createDataFrame(Seq(
      (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
      (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
      (2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
    )).toDF("id", "features")
    dfA.show(false)

    /**
      * +---+-------------------------+
      * |id |features                 |
      * +---+-------------------------+
      * |0  |(6,[0,1,2],[1.0,1.0,1.0])|
      * |1  |(6,[2,3,4],[1.0,1.0,1.0])|
      * |2  |(6,[0,2,4],[1.0,1.0,1.0])|
      * +---+-------------------------+
      */

Alternative 1: pass the indexes in via `lit`, as shown below:

// Alternative-1
    //udf
    val feature_idx_to_wipe = Array(1, 2)
    import org.apache.spark.ml.linalg.{Vectors, SparseVector => NewSparseVector}
    import scala.collection.mutable.ListBuffer
    def wipe(v: NewSparseVector, idx2clean:Seq[Int]) : NewSparseVector = {
      val lb:ListBuffer[(Int, Double)]=ListBuffer()
      v.foreachActive {
        case (i, v) =>
          if(!idx2clean.contains(i)){
            lb += ((i, v))
          }
      }

      Vectors.sparse(v.size, lb.toSeq).toSparse
    }
    val udf_wipe = udf((x: NewSparseVector, idx2clean:Seq[Int]) => wipe(x, idx2clean))

    //apply udf
    val newDF = dfA.withColumn("features_wiped", udf_wipe(col("features"), lit(feature_idx_to_wipe)))

    //target (a new vector column is added, with the features at indexes 1 and 2 removed)
    newDF.select("id","features_wiped").show(false)
    /**
      * +---+-------------------+
      * |id |features_wiped     |
      * +---+-------------------+
      * |0  |(6,[0],[1.0])      |
      * |1  |(6,[3,4],[1.0,1.0])|
      * |2  |(6,[0,4],[1.0,1.0])|
      * +---+-------------------+
      */
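One detail worth noting: this variant receives the indexes as a `Seq[Int]`, so each `contains` call is a linear scan over the sequence. For a handful of indexes that is irrelevant, but converting to a `Set` once inside the function makes each membership test constant-time. A small pure-Scala sketch of that tweak (the helper name `wipeSeq` is mine, not from the answer):

```scala
// Same filtering as wipe, but converting the Seq to a Set up front
// so each membership test on an active entry is O(1).
def wipeSeq(pairs: Seq[(Int, Double)], idx2clean: Seq[Int]): Seq[(Int, Double)] = {
  val idxSet = idx2clean.toSet
  pairs.filterNot { case (i, _) => idxSet.contains(i) }
}
```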

Alternative 2: use a broadcast variable (`sparkContext.broadcast`), as shown below:

//    Alternative2
    //data
    val feature_idx_to_wipe1 = Set(1, 2)
    val broadcastSet = spark.sparkContext.broadcast(feature_idx_to_wipe1)

    //udf
    import org.apache.spark.ml.linalg.{Vectors, SparseVector => NewSparseVector}
    import scala.collection.mutable.ListBuffer
    def wipe1(v: NewSparseVector) : NewSparseVector = {
      val idx2clean = broadcastSet.value
      val lb:ListBuffer[(Int, Double)]=ListBuffer()
      v.foreachActive {
        case (i, v) =>
          if(!idx2clean.contains(i)){
            lb += ((i, v))
          }
      }

      Vectors.sparse(v.size, lb.toSeq).toSparse
    }
    val udf_wipe1 = udf((x: NewSparseVector) => wipe1(x))

    //apply udf
    val newDF1 = dfA.withColumn("features_wiped", udf_wipe1(col("features")))

    //target (a new vector column is added, with the features at indexes 1 and 2 removed)
    newDF1.select("id","features_wiped").show(false)

    /**
      * +---+-------------------+
      * |id |features_wiped     |
      * +---+-------------------+
      * |0  |(6,[0],[1.0])      |
      * |1  |(6,[3,4],[1.0,1.0])|
      * |2  |(6,[0,4],[1.0,1.0])|
      * +---+-------------------+
      */
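The broadcast variant changes only where the set comes from: `wipe1` takes just the vector, and the executor-side code reads the set through `broadcastSet.value` instead of a parameter. That shape can be sketched in plain Scala, with an ordinary outer `val` standing in for the broadcast value (so this shows the closure shape only, not broadcast's distribution behavior; the names below are illustrative):

```scala
// Stand-in for broadcastSet.value: in Spark, a broadcast is shipped to
// each executor once rather than serialized into every task closure.
val idxToWipeBc: Set[Int] = Set(1, 2)

// One-argument UDF body: the set is read from the enclosing scope.
def wipeActive(active: Seq[(Int, Double)]): Seq[(Int, Double)] =
  active.filterNot { case (i, _) => idxToWipeBc.contains(i) }
```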
