【问题标题】:Does model.reset_states for LSTM affect any other non-LSTM layers in the model?LSTM 的 model.reset_states 会影响模型中的任何其他非 LSTM 层吗?
【发布时间】:2020-04-26 12:14:41
【问题描述】:
我在tf.keras 中使用LSTM 的有状态模式,当我处理我的序列数据时,我需要手动执行reset_states,如here 所述。似乎通常人们会使用model.reset_states(),但在我的情况下,我的 LSTM 层嵌入在一个更复杂的网络中,其中包括各种其他层,如 Dense、Conv 等。我的问题是,如果我只是在嵌入了 LSTM(并且只有一个 LSTM)的主模型上调用 model.reset_states(),我是否应该担心重置会影响模型中的其他层,例如 Dense 或 Conv 层?寻找 LSTM 层并将 reset_states 调用隔离到该层会更好吗?
【问题讨论】:
标签:
python
tensorflow
keras
lstm
tf.keras
【解决方案1】:
TLDR:像LSTM/GRU 这样的层有权重和状态,而像Conv/Dense/Embedding 这样的层只有权重。 reset_state() 只影响有状态的层。
reset_states() 所做的是,对于 LSTM,它会重置层中的 c_t 和 h_t 输出。这些是您通常通过设置LSTM(n, return_state=True) 获得的值。
Embedding、Dense、Conv 层中没有这样的状态。所以model.reset_states() 不会影响那些前馈层。只是像 LSTM 和 GRU 这样的顺序层。
如果您愿意,可以查看source code 并验证该函数是否看起来每个层都具有reset_state 属性(前馈层没有)。
【解决方案2】:
任何具有可设置 stateful 属性的层都受reset_states() 的约束;该方法遍历每一层,检查它是否有stateful=True - 如果有,调用它的reset_states() 方法;见source。
在 Keras 中,包括ConvLSTM2D 在内的所有循环层都有一个可设置的stateful 属性——我不知道还有其他属性。然而,tensorflow.keras 有很多自定义层实现,它们可能会;您可以使用下面的代码来确定:
def print_statefuls(model):
for layer in model.layers:
if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
print(layer.name, "is stateful")