我试图通过spark来适应xgboost车型。这是我的最小可复制实现。 task-config
包含要分析的字段和模型参数。
private var spark: SparkSession = _
private val taskConfig = JsonLoader.parseJsonConfig[TaskConfig]("/task-config/local_train.json")
private val Some(groupId: String) = this.taskConfig.taskInfo.groupColumn
private val Some(ranking: String) = this.taskConfig.taskInfo.rankingColumn
private val numPartition = 1000
private val spark = SparkSession
.builder()
.master("local[*]")
.appName("test")
.getOrCreate()
在我使用的产品中 SparkSession
. 作为赫比,我用 csv
我决定把它作为一个入口点,而不是进行sql查询。
val data = this.spark.read
.format("csv")
.option("header", "true")
.option("mode", "DROPMALFORMED")
.load(this.taskConfig.training.db + this.taskConfig.training.table)
当我打印我的模式时,阅读似乎工作得很好 data.schema.fields
我看到我的字段列表: Array(StructField(col1 col2 col3 ...)
当我尝试将数据传递到管道中时抛出异常
private def pipeline: Pipeline = {
val num_features = getFeaturesByType(this.taskConfig, "numeric")
val numFeatureAssembler = new VectorAssembler()
.setInputCols(num_features)
.setOutputCol("num_features")
val pipelineStages = Array(numFeatureAssembler)
new Pipeline().setStages(pipelineStages)
}
val pipelineModel = Some(pipeline.fit(data))
我在执行sql查询时没有注意到同样的问题,所以我想我应该更深入地研究数据读取器输出的正确格式。不过,如有任何建议,我们将不胜感激。
暂无答案!
目前还没有任何答案,快来回答吧!