【问题标题】:Tensorflow 2 custom LSTM Cell initial statesTensorflow 2 自定义 LSTM Cell 初始状态
【发布时间】:2021-07-11 19:54:38
【问题描述】:

我尝试实现一个自定义 LSTM 单元。首先,在进行自定义之前,我尝试重现原始 LSTM 单元。但是,我遇到了一个问题,初始状态是单个张量而不是元组。

class LSTMCell(keras.layers.Layer):
   def __init__(self, units, activation='tanh',
           recurrent_activation='hard_sigmoid',
           use_bias=True,
           kernel_initializer='glorot_uniform',
           recurrent_initializer='orthogonal',
           bias_initializer='zeros',**kwargs):
    self.units = units
    self.state_size = units
    self.kernel_initializer = kernel_initializer
    self.use_bias = use_bias
    self. recurrent_initializer = recurrent_initializer
    self.bias_initializer = bias_initializer
    super(LSTMCell, self).__init__(**kwargs)
def build(self, input_shape):
    input_dim = input_shape[-1]
    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer)
    self.bias = self.add_weight(
      shape=(self.units * 4,),
      name='bias',
      initializer=self.bias_initializer)
    
    
       
def _compute_carry_and_output_fused(self, z, c_tm1):
    z0, z1, z2, z3 = z
    i = K.sigmoid(z0)
    f = K.sigmoid(z1)
    c = f * c_tm1 + i * K.tanh(z2)
    o = K.sigmoid(z3)
    return c, o
    
def call(self, inputs, states, training=None):
    
    
   
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state
    
    z = K.dot(inputs, self.kernel)
    z += K.dot(h_tm1, self.recurrent_kernel)
    z = K.bias_add(z, self.bias)
    z = tf.split(z, num_or_size_splits=4, axis=1)
    c, o = self._compute_carry_and_output_fused(z, c_tm1)

    h = o * K.sigmoid(c)
    self.h = h
    self.c = c
    
    return h, [h,c]

cell= LSTMCell(32)  
layer = RNN(cell) 
a = np.random.rand(44,10,40)
out = layer(a)

我收到错误消息: c_tm1 = 状态[1]

IndexError: 元组索引超出范围

我猜问题是调用了普通 rnn 单元的 get_init_state 函数,它返回了一个张量。我试图通过初始化班级中的状态来解决这个问题:

self.initial = True
.....

if self.initial:
 h_tm1 = tf.zeros(shape=[inputs.shape[0], self.state_size], name='h')
 c_tm1 = tf.zeros(shape=[inputs.shape[0], self.state_size], name='c')
 self.initial = False

'''

但它也没有工作。如何让我的 LSTMCell 调用正确的 get_init_state 函数或以任何方式工作?

【问题讨论】:

    标签: lstm tensorflow2.0 recurrent-neural-network keras-layer tf.keras


    【解决方案1】:

    你应该用 self.state_size 给出正确的形状:

    self.state_size = (self.units, self.units)
    

    【讨论】:

      【解决方案2】:

      我刚刚使用 Tensorflow 2.7.0 包测试了您的代码,您的代码似乎运行良好。

      也许您定义的函数(def __ init __、def build、def _compute_carry_and_output_fused 和 def 调用) 在 class LSTMCell 中没有正确对齐。

      而且,也许您的 self 实例对象(units、state_size、...、bias_initializer)和 super init 在 init 定义的函数内也没有对齐。 p>

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2019-05-05
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2016-11-09
        • 2016-12-13
        相关资源
        最近更新 更多