【问题标题】:Using multiprocessing.pool load keras model cause predict "ValueError Tensor Tensor"使用 multiprocessing.pool 加载 keras 模型导致预测“ValueError Tensor Tensor”
【发布时间】:2019-05-26 06:46:52
【问题描述】:

我为每个项目保存了 1000 多个模型。现在我需要将所有这些模型加载到内存(数据帧)中进行预测。如果我只是使用“for”循环来加载这些模型,每次加载都会比之前的模型加载慢 3 秒。所以我尝试使用multiprocessing.pool(ThreadPool)。

但是,奇怪的是,使用 ThreadPool 会导致预测“ValueError: Tensor Tensor”。如果使用正常加载,则预测很好。

我试过线程也有错误消息

#following code will lead to ValueError
from multiprocessing.pool import ThreadPool as Pool
def load_model(stock):
    model_pred.at[0, stock] = keras.models.load_model (
        'C:/Users/chenp/Documents/rqpro/models/{}_model.h5'.format (stock))


pool = Pool(processes=16)
for stock in trade_stocks['stock']:
    pool.map (load_model, (stock,))

#Prediction
for stock in trade_stocks['stock']:
    model = model_pred.loc[0, stock]
    prediction = model.predict(pred_data)

#Get following msg:
ValueError: Tensor Tensor("dense_9/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

#Normal code but too low efficient
for stock in trade_stocks['stock']:
    model_pred.at[0, stock] = keras.models.load_model(
           'C:/Users/chenp/Documents/rqpro/models/{}_model.h5'.format(stock))





#Get following msg:
ValueError: Tensor Tensor("dense_9/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

【问题讨论】:

    标签: python-3.x tf.keras


    【解决方案1】:

    这是因为 Keras 不是线程安全的。为了解决这个问题,请在预测之前使用_make_predict_function()。详细解答请check

    【讨论】:

      猜你喜欢
      • 2017-06-20
      • 2020-03-17
      • 1970-01-01
      • 2018-04-25
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-06-03
      • 2021-09-10
      相关资源
      最近更新 更多