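【Posted】: 2019-10-09 11:18:44
【Question】:
I am building an LSTM for a report and want to summarize its properties. However, I have seen two different ways of building an LSTM in Keras, and they give two different parameter counts.
I would like to understand why the parameter counts differ in this way.
This question correctly explains why this code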
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(256, input_dim=4096, input_length=16))
model.summary()
produces 4,457,472 parameters.
As far as I can tell, the following two LSTMs should be identical
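For reference, that count can be checked by hand with the usual parameter formula for a standard LSTM layer with biases: four gates, each with an input kernel, a recurrent kernel, and a bias vector. This is a minimal sketch in plain Python; the helper name `lstm_param_count` is hypothetical and not part of Keras.

```python
def lstm_param_count(units, input_dim):
    """Count trainable parameters of one standard LSTM layer with biases.

    Hypothetical helper for illustration; it does not call Keras.
    """
    kernel = input_dim * units        # input-to-hidden weights, per gate
    recurrent_kernel = units * units  # hidden-to-hidden weights, per gate
    bias = units                      # one bias per unit, per gate
    # An LSTM has four such blocks: input, forget, cell candidate, output.
    return 4 * (kernel + recurrent_kernel + bias)


# LSTM(256, input_dim=4096, ...) from the code above
print(lstm_param_count(units=256, input_dim=4096))  # 4457472
```

Note that the sequence length (`input_length=16`) does not appear in the formula, because the same weights are reused at every timestep.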
m2 = Sequential()
m2.add(LSTM(1, input_dim=5, input_length=1))
m2.summary()
m3 = Sequential()
m3.add(LSTM((1),batch_input_shape=(None,5,1)))
m3.summary()
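However, m2 reports 28 parameters while m3 reports 12. Why?
How is 12 computed for a 1-unit LSTM with a 5-dimensional input?
I have included the warning messages in case they help.
Output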
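One likely explanation, offered as a sketch and not a definitive answer: Keras recurrent layers read the input shape as (batch, timesteps, features), so `batch_input_shape=(None,5,1)` describes 5 timesteps of 1 feature each, whereas `input_dim=5, input_length=1` describes 1 timestep of 5 features. The plain-Python snippet below applies the standard LSTM parameter formula to both shapes; `lstm_params_from_shape` is a hypothetical helper, not a Keras function.

```python
def lstm_params_from_shape(units, batch_input_shape):
    """Parameter count of a standard LSTM layer, given a Keras-style shape.

    Hypothetical helper. The shape is read as (batch, timesteps, features);
    only the feature dimension affects the number of weights.
    """
    _batch, _timesteps, features = batch_input_shape
    return 4 * (units * (features + units) + units)


m2_shape = (None, 1, 5)  # equivalent of input_dim=5, input_length=1
m3_shape = (None, 5, 1)  # as written in m3: 5 timesteps, 1 feature

print(lstm_params_from_shape(1, m2_shape))  # 28
print(lstm_params_from_shape(1, m3_shape))  # 12
```

Under this reading, the two models are not the same: m3 would need `batch_input_shape=(None,1,5)` to match m2.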
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 256) 4457472
=================================================================
Total params: 4,457,472
Trainable params: 4,457,472
Non-trainable params: 0
_________________________________________________________________
Warning (from warnings module):
File "difparam.py", line 11
m2.add(LSTM(1, input_dim=5, input_length=1))
UserWarning: The `input_dim` and `input_length` arguments in recurrent layers are deprecated. Use `input_shape` instead.
Warning (from warnings module):
File "difparam.py", line 11
m2.add(LSTM(1, input_dim=5, input_length=1))
UserWarning: Update your `LSTM` call to the Keras 2 API: `LSTM(1, input_shape=(1, 5))`
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_2 (LSTM) (None, 1) 28
=================================================================
Total params: 28
Trainable params: 28
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_3 (LSTM) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________
m2 was built from the information in the Stack Overflow question, and m3 was built from this video on YouTube.
【Comments】:
- m2 appears to follow 4 x ((1 x 5) + (1^2) + 1) = 4 x (5 + 1 + 1) = 4 x 7 = 28. @987654331@ does not make sense.
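The arithmetic in this comment is an instance of the general parameter count for a standard LSTM layer with biases. Written out, with $n$ the number of units and $m$ the number of input features per timestep (symbols introduced here for illustration):

```latex
P = 4\,(n m + n^{2} + n)
\qquad\text{so for } n = 1,\ m = 5:\quad
P = 4\,(1\cdot 5 + 1^{2} + 1) = 4 \cdot 7 = 28 .
```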
Tags: python-3.x keras lstm