【发布时间】:2021-06-25 04:47:01
【问题描述】:
我一直在尝试使用队列和线程在 keras 模型上异步运行预测,同时编译输入然后即时检索输出。
为此,我尝试在从队列中提取输入的生成器上使用 model.predict(),然后在计算结果后使用自定义 on_predict_batch_end 回调将结果推送到另一个队列。
我的方法运行,但不幸的是,我似乎无法从回调中检索预测。我只能得到一个不让我访问实际输出的非急切张量。 tf.config.run_functions_eagerly(True) 没有帮助。
以下是我正在尝试做的事情的简要总结:
from queue import Queue
from threading import Thread
# i/o queues
input_queue = Queue(1)
output_queue = Queue(1)
# inputs generator:
def feed():
while True:
yield input_queue.get()
# Custom Keras callback that is supposed to pull predictions on the fly
# and add them to output queue:
class CustomCallback(tf.keras.callbacks.Callback):
def on_predict_batch_end(self, batch, logs=None):
# Retrieve "output" put in output queue:
output_queue.put(self.model.layers[-1].output)
# Prediction function to use in thread call:
def prediction_function():
model.predict(feed(),verbose=1, callbacks=[CustomCallback()])
# Start prediction thread:
pt = Thread(target=prediction_function, daemon=True)
pt.start()
# Add input to the queue and retrieve when output is ready:
for i in input_data:
iq.put(i)
output_data += [oq.get()]
这运行得很好,但我在 output_data 列表中得到的只是一个非急切的张量。因为它不急于.numpy() 不起作用,.eval() 会破坏目的。有没有办法在不为队列中的每个新输入再次调用 model.predict() 的情况下即时访问预测,这非常慢?
【问题讨论】:
标签: tensorflow keras tensorflow2.0 tf.keras