【问题标题】:How to wrap a tensorflow RNNCell in keras?如何在keras中包装张量流RNNCell?
【发布时间】:2019-05-14 01:33:34
【问题描述】:

我想在 keras 层中实现自定义 LSTM 单元。实际上这个实现存在于 tensorflow 中,所以我想知道是否可以将其包装为 keras 层并在模型中调用它。

我发现官方documentation 太简单了,看不到如何构建自定义 RNN 层。 herehere 也有类似的问题,但似乎没有得到解决。

提前感谢您的帮助!

【问题讨论】:

    标签: python tensorflow keras rnn keras-layer


    【解决方案1】:

    现在 tensorflow 的文档可能在问题发布后有所改进。

    您可能需要查看this guidethis SO answer 以供参考。

    【讨论】:

      【解决方案2】:

      根据我的理解,您应该能够在类层的 init() 中初始化单元格,然后在调用方法中使用您的输入引用它。

      例如:

      class MySimpleLayer(Layer):
        def __init__(self, lstm_size):
          super(MySimpleLayer, self).__init__()
          self.lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
      
        def call(self, batch, state):
          return self.lstm(batch, state)
      
      layer = MySimpleLayer(lstm_size)
      logits = layer(batch, state)
      

      这个实现是最基本的,所以你可能需要研究 build() 和 compute_output_shape() 方法来处理更复杂的用例。

      【讨论】:

      • 对不起,call 这样的定义与Layer 不匹配;我得到 TypeError: call() missing 1 required positional argument: 'states'
      • Call() 肯定与 Layer 类一起使用,正如 here 所指定的那样。对我来说,这看起来像是一个实施错误。尝试将 [batch, state] 作为单个列表输入传递给 call()。
      猜你喜欢
      • 1970-01-01
      • 2020-02-18
      • 1970-01-01
      • 1970-01-01
      • 2021-08-24
      • 1970-01-01
      • 1970-01-01
      • 2019-10-07
      • 1970-01-01
      相关资源
      最近更新 更多