【问题标题】:Predict/infer asynchronously with a Keras or TF model?使用 Keras 或 TF 模型进行异步预测/推断?
【发布时间】:2021-06-25 04:47:01
【问题描述】:

我一直在尝试使用队列和线程在 keras 模型上异步运行预测,同时编译输入然后即时检索输出。

为此,我尝试在从队列中提取输入的生成器上使用 model.predict(),然后在计算结果后使用自定义 on_predict_batch_end 回调将结果推送到另一个队列。

我的方法运行,但不幸的是,我似乎无法从回调中检索预测。我只能得到一个不让我访问实际输出的非急切张量。 tf.config.run_functions_eagerly(True) 没有帮助。

以下是我正在尝试做的事情的简要总结:

from queue import Queue
from threading import Thread


# i/o queues
input_queue = Queue(1)
output_queue = Queue(1)

# inputs generator:
def feed():
    while True:
        yield input_queue.get()

# Custom Keras callback that is supposed to pull predictions on the fly
# and add them to output queue:
class CustomCallback(tf.keras.callbacks.Callback):
    def on_predict_batch_end(self, batch, logs=None):
        # Retrieve "output" put in output queue:
        output_queue.put(self.model.layers[-1].output)

# Prediction function to use in thread call:
def prediction_function():
    model.predict(feed(),verbose=1, callbacks=[CustomCallback()])

# Start prediction thread:
pt = Thread(target=prediction_function, daemon=True)
pt.start()

# Add input to the queue and retrieve when output is ready:
for i in input_data:
    iq.put(i)
    output_data += [oq.get()]

这运行得很好,但我在 output_data 列表中得到的只是一个非急切的张量。因为它不急于.numpy() 不起作用,.eval() 会破坏目的。有没有办法在不为队列中的每个新输入再次调用 model.predict() 的情况下即时访问预测,这非常慢?

【问题讨论】:

    标签: tensorflow keras tensorflow2.0 tf.keras


    【解决方案1】:

    解决了我的问题。如果您查看 keras 的 predict 方法,它会通过日志将批处理输出传递给 on_predict_batch_end() 。在回调中你需要做的就是将它转发到队列中:

    class CustomCallback(tf.keras.callbacks.Callback):
        def on_predict_batch_end(self, batch, logs=None):
            # Retrieve "output" put in output queue:
            output_queue.put(logs['outputs'])
    

    【讨论】:

      猜你喜欢
      • 2021-02-20
      • 2018-07-24
      • 2020-07-14
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-05-18
      • 2021-03-24
      • 2021-07-22
      相关资源
      最近更新 更多