【问题标题】:Keras model: TypeError: can't pickle _thread.lock objectsKeras 模型:TypeError:无法腌制 _thread.lock 对象
【发布时间】:2019-05-11 06:00:25
【问题描述】:

我无法在 PySpark 中使用经过训练的 Keras 模型。使用以下版本的库:

tensorflow==1.1.0
h5py==2.7.0
keras==2.0.4

另外,我使用 Spark 2.4.0。

from pyspark.sql import SparkSession
import pyspark.sql.functions as func
from keras.models import load_model

spark = SparkSession \
    .builder \
    .appName("Test") \
    .master("local[2]") \
    .getOrCreate()

my_model = load_model("my_model.h5")
spark.sparkContext.addFile("my_model.h5")
my_model_bcast = spark.sparkContext.broadcast(my_model)

# ...

get_prediction_udf = func.udf(get_prediction, IntegerType())
ds = ds\
    .withColumn("predicted_value", get_prediction_udf(my_model_bcast,
                                                      func.col("col1"),
                                                      func.col("col2"))))

函数get_prediction如下(简化代码):

def get_prediction(my_model_bcast, col1, col2):
    cur_state = np.array([col1,col2])
    state = cur_state.reshape(1,2)
    ynew = my_model_bcast.predict(state)
    return np.argmax(ynew[0])

下面的错误是由my_model_bcast = spark.sparkContext.broadcast(my_model)这一行触发的:

  File "/usr/local/spark-2.4.0-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/broadcast.py", line 110, in dump
    pickle.dump(value, f, 2)
TypeError: can't pickle _thread.lock objects

我正在阅读类似的主题以找到解决方案。据我了解,keras 不支持申请pickle。但在这种情况下,如何使用经过训练的模型在 PySpark 中进行预测?

【问题讨论】:

    标签: python apache-spark keras pyspark keras-2


    【解决方案1】:

    似乎无法序列化 keras 模型,所以也许只是分发文件并作为 spark 文件?所以在你的函数内部(你期望模型作为输入)你可以从那个路径读取文件并在里面创建模型?

    path = SparkFiles.get("mode_file.h5")
    model =  load_model(path)
    

    【讨论】:

    • 谢谢。我完全按照你的建议做了。我收到了错误Caused by: net.razorvine.pickle.objects.ClassDictConstructor.construct
    猜你喜欢
    • 2019-10-20
    • 2018-04-14
    • 2017-12-18
    • 2017-10-23
    • 1970-01-01
    • 2021-05-30
    • 2021-05-16
    • 2017-12-04
    • 2019-01-21
    相关资源
    最近更新 更多