【问题标题】:Is Keras thread safe?Keras 线程安全吗?
【发布时间】:2017-04-12 12:24:22
【问题描述】:
我正在使用 Python 和 Keras(目前使用 Theano 后端,但我对切换没有疑虑)。我有一个神经网络,我可以并行加载和处理多个信息源。目前,我一直在一个单独的进程中运行每一个,它从文件中加载自己的网络副本。这似乎是对 RAM 的浪费,所以我认为拥有一个多线程进程和一个由所有线程使用的网络实例会更有效。但是,我想知道 Keras 对于任一后端是否都是线程安全的。如果我在不同的线程中同时在两个不同的输入上运行.predict(x),我会遇到竞争条件或其他问题吗?
谢谢
【问题讨论】:
标签:
python
multithreading
keras
【解决方案1】:
是的,Keras 是线程安全的,如果你稍微注意一下的话。
事实上,在强化学习中有一种称为Asynchronous Advantage Actor Critics (A3C) 的算法,其中每个代理都依赖同一个神经网络来告诉他们在给定状态下应该做什么。换句话说,每个线程同时调用model.predict,就像您的问题一样。使用 Keras 的示例实现是 here。
但是,如果您查看代码,则应特别注意这一行:
model._make_predict_function() # have to initialize before threading
Keras 文档中从未提及这一点,但它必须让它同时工作。简而言之,_make_predict_function 是编译predict 函数的函数。在多线程设置中,必须手动调用该函数提前编译predict,否则predict函数要等到第一次运行后才会编译,多线程同时调用会出现问题。可以看详细解释here。
到目前为止,我还没有遇到过 Keras 中的多线程问题。
【解决方案2】:
引用那种fcholet:
_make_predict_function 是一个私有 API。我们不建议调用它。
在这里,用户应该首先调用 predict。
请注意,不能保证 Keras 模型是线程安全的。
考虑在每个线程中拥有模型的独立副本
用于 CPU 推理。