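【Posted】: 2021-07-11 19:54:38
【Question】:
I am trying to implement a custom LSTM cell. Before customizing anything, I first tried to reproduce the original LSTM cell. However, I ran into a problem: the initial state is a single tensor rather than a tuple.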
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import RNN


class LSTMCell(keras.layers.Layer):
    def __init__(self, units, activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros', **kwargs):
        self.units = units
        self.state_size = units
        self.kernel_initializer = kernel_initializer
        self.use_bias = use_bias
        self.recurrent_initializer = recurrent_initializer
        self.bias_initializer = bias_initializer
        super(LSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # One fused weight matrix per source; the 4 blocks are the i, f, c, o gates.
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer=self.kernel_initializer)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer)
        self.bias = self.add_weight(
            shape=(self.units * 4,),
            name='bias',
            initializer=self.bias_initializer)

    def _compute_carry_and_output_fused(self, z, c_tm1):
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state
        z = K.dot(inputs, self.kernel)
        z += K.dot(h_tm1, self.recurrent_kernel)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        h = o * K.sigmoid(c)
        self.h = h
        self.c = c
        return h, [h, c]


cell = LSTMCell(32)
layer = RNN(cell)
a = np.random.rand(44, 10, 40)
out = layer(a)
I get the following error message:
    c_tm1 = states[1]
IndexError: tuple index out of range
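My guess is that the get_init_state function of the plain RNN cell is being called, and it returns a single tensor. I tried to work around this by initializing the states inside the class: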
self.initial = True
.....
if self.initial:
    h_tm1 = tf.zeros(shape=[inputs.shape[0], self.state_size], name='h')
    c_tm1 = tf.zeros(shape=[inputs.shape[0], self.state_size], name='c')
    self.initial = False
But that did not work either. How can I make my LSTMCell call the correct get_init_state function, or get it working in some other way?
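A possible explanation, offered as a sketch rather than a confirmed answer: `keras.layers.RNN` builds the zero initial state from the cell's `state_size` attribute, so an integer `state_size` produces a single state tensor and `states[1]` is out of range. Declaring two state sizes, one for `h` and one for `c`, should hand the cell both states. In the sketch below the class name `TwoStateLSTMCell` is hypothetical, the gate math uses plain `tf` ops instead of `K`, the initializers are hard-coded, and the output uses `tanh(c)` as in the standard LSTM rather than the `K.sigmoid(c)` from the code above.

```python
import numpy as np
import tensorflow as tf


class TwoStateLSTMCell(tf.keras.layers.Layer):
    """Minimal LSTM cell sketch; the class name is hypothetical."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Two entries: RNN creates one zero tensor per entry (h and c).
        self.state_size = [units, units]

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # Fused weights: the 4 blocks correspond to the i, f, c, o gates.
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            initializer='glorot_uniform', name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            initializer='orthogonal', name='recurrent_kernel')
        self.bias = self.add_weight(
            shape=(self.units * 4,),
            initializer='zeros', name='bias')

    def call(self, inputs, states, training=None):
        h_tm1, c_tm1 = states[0], states[1]  # previous output and carry
        z = tf.matmul(inputs, self.kernel)
        z += tf.matmul(h_tm1, self.recurrent_kernel)
        z += self.bias
        z0, z1, z2, z3 = tf.split(z, num_or_size_splits=4, axis=1)
        i = tf.sigmoid(z0)               # input gate
        f = tf.sigmoid(z1)               # forget gate
        c = f * c_tm1 + i * tf.tanh(z2)  # new carry state
        o = tf.sigmoid(z3)               # output gate
        h = o * tf.tanh(c)               # new output
        return h, [h, c]


layer = tf.keras.layers.RNN(TwoStateLSTMCell(32))
a = np.random.rand(44, 10, 40).astype('float32')
out = layer(a)
print(out.shape)
```

With the input shape from the question, `out` should have shape `(44, 32)`, since the layer returns only the last output by default. With `state_size` declared this way, the `self.initial` workaround should not be needed.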
【Discussion】:
Tags: lstm tensorflow2.0 recurrent-neural-network keras-layer tf.keras