PySpark: error when logging a Spark model with MLflow on Databricks - mlflow.spark.log_model()

Posted by lpwwtiir 5 months ago in Spark

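I am trying to log a Spark model with the code snippet below. The model metrics and parameters are saved in the MLflow run, but the model itself is not saved under the run's artifacts. However, when I log a scikit-learn model with mlflow.sklearn.log_model() in the same environment, the model is saved successfully.
Environment: Databricks Runtime 10.4 LTS ML cluster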

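import mlflow
import mlflow.spark
import numpy as np
from mlflow.tracking.artifact_utils import get_artifact_uri  # assumed source of get_artifact_uri
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# conf, data, train_test_random_split() and train_model() are the asker's own objects/helpers (not shown)
train, test = train_test_random_split(conf, data)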

experiment_name = "/mlflow_experiments/debug_spark_model"
mlflow.set_experiment(experiment_name)

evaluator = BinaryClassificationEvaluator()

rf = RandomForestClassifier()

param_grid = (
    ParamGridBuilder()
    .addGrid(rf.numTrees, [15])
    .addGrid(rf.maxDepth, [6])
    .addGrid(rf.minInstancesPerNode, [7])
    .build()
)

cv = CrossValidator(
    estimator=rf,
    estimatorParamMaps=param_grid,
    evaluator=BinaryClassificationEvaluator(metricName="areaUnderROC"),
    numFolds=10,
)
cv_model = cv.fit(train)

# best model
model = cv_model.bestModel

# Parameter map of the combination with the highest average CV metric (areaUnderROC: higher is better)
best_idx = int(np.argmax(cv_model.avgMetrics))
best_param_map = cv_model.getEstimatorParamMaps()[best_idx]

model_params_best = {
    "numTrees": best_param_map[cv_model.bestModel.numTrees],
    "maxDepth": best_param_map[cv_model.bestModel.maxDepth],
    "minInstancesPerNode": best_param_map[cv_model.bestModel.minInstancesPerNode],
}

model_metrics_best, artifacts_best, predicted_df_best = train_model(
    model, train, test, evaluator
)
with mlflow.start_run(run_name="debug_run_1"):
    run_id = mlflow.active_run().info.run_id
    mlflow.log_params(model_params_best)
    mlflow.log_metrics(model_metrics_best)

    # debug 1
    artifact_path = "best_model"
    mlflow.spark.log_model(spark_model=model, artifact_path=artifact_path)
    source = get_artifact_uri(run_id=run_id, artifact_path=artifact_path)

It produces the error below.
(error screenshot not reproduced)
I would appreciate any debugging direction or solution for this error.
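One quick debugging step is to confirm programmatically whether anything was written under `best_model` for the run, instead of checking the UI. The sketch below assumes that a successfully logged Spark model directory contains an `MLmodel` file, as MLflow model flavors normally produce, and that `client` is an `mlflow.tracking.MlflowClient` (whose `list_artifacts(run_id, path)` returns entries with a `.path` attribute). `has_logged_model` is a hypothetical helper name, not part of MLflow.

```python
def has_logged_model(client, run_id, artifact_path="best_model"):
    """Return True if an MLflow model directory (one containing an MLmodel file)
    exists under `artifact_path` in the run's artifacts.

    Assumption: `client` is an mlflow.tracking.MlflowClient, or any object with a
    compatible list_artifacts(run_id, path) method.
    """
    entries = client.list_artifacts(run_id, artifact_path)
    # Each entry's .path is relative to the run's artifact root, e.g. "best_model/MLmodel";
    # keep only the last path component so we can look for the MLmodel file by name.
    file_names = {entry.path.rsplit("/", 1)[-1] for entry in entries}
    return "MLmodel" in file_names


# Usage inside the notebook from the question (not executed here):
# from mlflow.tracking import MlflowClient
# print(has_logged_model(MlflowClient(), run_id))
```

If this returns False right after `mlflow.spark.log_model()` while the params and metrics are present, the failure is in the artifact upload step rather than in training.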

Answer by brjng4g3:

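I found a way around this error, and around most errors related to mlflowdbfs: disabling mlflowdbfs resolves the error above on a Databricks Runtime ML cluster. Another option is to use a standard (non-ML) Databricks Runtime cluster.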

import os

# Disable the mlflowdbfs filesystem before calling mlflow.spark.log_model()
os.environ["DISABLE_MLFLOWDBFS"] = "true"
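Setting the variable as above affects the rest of the notebook session. If you would rather limit the workaround to the logging call, a small context manager can set it temporarily and restore the previous value afterwards. This is a minimal sketch: `temporary_env` is a hypothetical helper, and it assumes MLflow reads `DISABLE_MLFLOWDBFS` when `log_model()` runs rather than once at import time.

```python
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(name, value):
    """Set environment variable `name` to `value` inside the `with` block,
    then restore whatever was there before (or unset it if it was absent)."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


# Usage in the notebook from the question (not executed here):
# with temporary_env("DISABLE_MLFLOWDBFS", "true"):
#     mlflow.spark.log_model(spark_model=model, artifact_path="best_model")
```

The `try/finally` ensures the original environment is restored even if `log_model()` raises.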

