如何用keras模型实现实时推理的spark?

mbyulnm0  于 2021-05-19  发布在  Spark
关注(0)|答案(0)|浏览(223)

下面是我的一段代码:

@pandas_udf(StringType())
def online_predict(values: pd.Series) -> pd.Series:
    pred = Model.from_config(bc_config.value)
    pred.set_weights(bc_weights.value)
    ds = tf.data.Dataset.from_tensor_slices(values)
    ds = ds.map(preprocessing).batch(batch_size)
    res = pred.predict(ds)
    res = tf.norm(res, axis=1)
    # res = tf.greater(res, 5.0)
    res = tf.strings.as_string(res).numpy()
    return pd.Series(res)

spark = SparkSession.builder.appName(
    'spark_tf').master("local[*]").getOrCreate()
weights = np.load('./ext/weights.npy', allow_pickle=True)
config = np.load('./ext/config.npy', allow_pickle=True).item()
bc_weights = spark.sparkContext.broadcast(weights)
bc_config = spark.sparkContext.broadcast(config)

stream = spark.readStream.format('kafka') \
    .option('kafka.bootstrap.servers', 'localhost:9092') \
    .option('subscribe', 'dlpred') \
    .load()

stream = stream.select(online_predict(col('value')).alias('value'))

x = stream.writeStream \
    .format('kafka') \
    .option("kafka.bootstrap.servers", 'localhost:9092') \
    .option('topic', 'dltest') \
    .option('checkpointLocation', './kafka_checkpoint') \
    .start()

x.awaitTermination()

所以基本上我的工作流程是:
广播模型的权重和配置。
从kafka初始化pyspark结构化流媒体管道,然后在其上应用pandas udf。
通过Pypark kafka接收器将消息发送回kafka。
这是个好习惯吗?我在pandas udf中初始化我的模型,因为我认为spark cluster处理pandas udf,所以在pandas udf之外初始化模型是没有意义的,即使使用广播的权重和配置,因为spark cluster不会将模型缩放到它的工作对象。
据我所知,每当有新行出现时,pyspark都会对每一行应用udf,所以模型初始化会重复,不是吗?当消息重复出现时,我还会得到“triggeredtf.function retracting”警告。总的来说,我对结构化流媒体和spark没有什么经验,所以我不知道它是否得到了正确的实现。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题