【Question Title】: Mixing feed forward layers and recurrent layers in Tensorflow?
【Posted】: 2016-07-25 15:27:53
【Question】:

Has anyone been able to mix feed-forward layers and recurrent layers in Tensorflow?

For example: input -> conv -> GRU -> linear -> output

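I can imagine that one could define one's own cell that contains feed-forward layers and has no state, and then stack such cells using the MultiRNNCell function, e.g.: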

cell = tf.nn.rnn_cell.MultiRNNCell([conv_cell,GRU_cell,linear_cell])

That would make life a lot easier...
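The idea in the question can be sketched outside TensorFlow: a feed-forward layer becomes a "cell" that takes `(input, state)` and hands its (empty) state back unchanged, so a stacking wrapper can treat it exactly like a recurrent cell. Below is a minimal NumPy sketch under that assumption. `GRUCell`, `LinearCell` and `StackedCell` are hypothetical names for illustration, not TensorFlow classes, and the GRU step uses the common update/reset-gate formulation with biases omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


class GRUCell:
    """Minimal GRU step (no biases); the state is a [batch, units] array."""

    def __init__(self, input_dim, units):
        self.units = units
        # one weight matrix per gate, acting on the concatenation [x, h]
        self.Wz = rng.standard_normal((input_dim + units, units)) * 0.1
        self.Wr = rng.standard_normal((input_dim + units, units)) * 0.1
        self.Wh = rng.standard_normal((input_dim + units, units)) * 0.1

    def zero_state(self, batch):
        return np.zeros((batch, self.units))

    def __call__(self, x, h):
        xh = np.concatenate([x, h], axis=1)
        z = sigmoid(xh @ self.Wz)  # update gate
        r = sigmoid(xh @ self.Wr)  # reset gate
        h_tilde = np.tanh(np.concatenate([x, r * h], axis=1) @ self.Wh)
        h_new = (1 - z) * h + z * h_tilde
        return h_new, h_new  # for a GRU the output is the new state


class LinearCell:
    """Stateless feed-forward 'cell': its state is empty and passed through."""

    def __init__(self, input_dim, output_dim):
        self.W = rng.standard_normal((input_dim, output_dim)) * 0.1
        self.b = np.zeros(output_dim)

    def zero_state(self, batch):
        return ()  # nothing to carry between time steps

    def __call__(self, x, state):
        return x @ self.W + self.b, state


class StackedCell:
    """Feeds each cell's output into the next, keeping one state per cell."""

    def __init__(self, cells):
        self.cells = cells

    def zero_state(self, batch):
        return tuple(c.zero_state(batch) for c in self.cells)

    def __call__(self, x, states):
        new_states = []
        for cell, s in zip(self.cells, states):
            x, s = cell(x, s)
            new_states.append(s)
        return x, tuple(new_states)


# input -> GRU -> linear -> output, unrolled over 5 time steps
batch, time, in_dim = 4, 5, 3
cell = StackedCell([GRUCell(in_dim, 8), LinearCell(8, 2)])
state = cell.zero_state(batch)
inputs = rng.standard_normal((time, batch, in_dim))
outputs = []
for t in range(time):
    out, state = cell(inputs[t], state)
    outputs.append(out)
outputs = np.stack(outputs)  # [time, batch, 2]
```

Only the GRU entry of the state tuple changes from step to step; the linear cell's entry stays `()`, which is the pass-through behaviour a stateless cell needs in order to be stackable.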

【Question Discussion】:

    Tags: python tensorflow recurrent-neural-network gated-recurrent-unit


    【Solution 1】:

    Can't you just do the following:

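    # inputs: a Python list of [batch, input_dim] tensors, one per time step
    rnnouts, _ = tf.nn.static_rnn(grucell, inputs, dtype=tf.float32)  # tf.nn.rnn before TF 1.0
    # the same weights and bias are applied to every time step's output
    linearout = [tf.matmul(rnnout, weights) + bias for rnnout in rnnouts]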
    

    and so on.
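The list comprehension applies one shared weight matrix to each time step separately. The NumPy sketch below (shapes are illustrative assumptions) shows that this gives the same result as stacking the per-step outputs and doing a single matrix multiply, which is often the cheaper way to write it.

```python
import numpy as np

rng = np.random.default_rng(1)
time, batch, hidden, out_dim = 6, 4, 8, 3

# stand-in for the list of per-step RNN outputs, each [batch, hidden]
rnnouts = [rng.standard_normal((batch, hidden)) for _ in range(time)]
weights = rng.standard_normal((hidden, out_dim))
bias = rng.standard_normal(out_dim)

# per-step version, as in the answer above
linearout = [rnnout @ weights + bias for rnnout in rnnouts]

# batched version: stack to [time, batch, hidden], merge time and batch,
# multiply once, then restore the [time, batch, out_dim] shape
stacked = np.stack(rnnouts).reshape(time * batch, hidden)
batched = (stacked @ weights + bias).reshape(time, batch, out_dim)
```

Both forms share the same `weights` and `bias` across time steps; the batched form just hands the whole sequence to one matmul.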

    【Discussion】:

      【Solution 2】:

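      This tutorial gives an example of how to use convolutional layers together with recurrent layers. For instance, the last convolutional layer looks like this: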

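      ...
      # conv_pre, pool and flatten are helper functions defined in the tutorial
      l_conv4_a = conv_pre(l_pool3, 16, (5, 5), scope="l_conv4_a")
      l_pool4 = pool(l_conv4_a, scope="l_pool4")  # pool the output of l_conv4_a
      l_flatten = flatten(l_pool4, scope="flatten")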
      

      and the RNN part, using an already defined cell `shape_cell`, is:

      _, shape_state = tf.nn.dynamic_rnn(cell=shape_cell,
          inputs=tf.expand_dims(batch_norm(x_shape_pl), 2), dtype=tf.float32, scope="shape_rnn")
      

      You can concatenate both outputs and use the result as the input to the next layer:

      # TF >= 1.0 signature is tf.concat(values, axis); older versions used concat_dim=
      features = tf.concat([x_margin_pl, shape_state, x_texture_pl, l_flatten], axis=1, name="features")
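Concatenating along axis 1 only requires all inputs to agree on the batch dimension; their feature widths simply add up. A NumPy illustration with made-up widths (the real widths depend on the tutorial's data and layers):

```python
import numpy as np

batch = 5
# hypothetical feature blocks, each shaped [batch, width]
x_margin = np.zeros((batch, 64))
shape_state = np.zeros((batch, 32))   # final RNN state
x_texture = np.zeros((batch, 64))
l_flatten = np.zeros((batch, 128))    # flattened CNN output

# join along the feature axis: 64 + 32 + 64 + 128 = 288 columns
features = np.concatenate([x_margin, shape_state, x_texture, l_flatten], axis=1)
```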
      

      Or you can simply use the output of the CNN layers as the input to the RNN cell:

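      # dynamic_rnn expects [batch, time, features]; l_flatten is [batch, features],
      # so add a time axis (here each feature becomes one time step of size 1)
      _, shape_state = tf.nn.dynamic_rnn(cell=shape_cell,
          inputs=tf.expand_dims(l_flatten, 2), dtype=tf.float32, scope="shape_rnn")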
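`tf.nn.dynamic_rnn` (with its default `time_major=False`) expects a 3-D `[batch, time, features]` input, whereas a flattened CNN output is 2-D. The NumPy sketch below shows the two usual ways to add the missing axis; which one fits depends on whether each flattened feature should be a time step or the whole vector should be a single step. The sizes are assumptions for illustration.

```python
import numpy as np

batch, n_features = 5, 128
l_flatten = np.zeros((batch, n_features))  # stand-in for the flattened CNN output

# each feature becomes one time step with a single input value
per_feature = np.expand_dims(l_flatten, 2)   # [batch, 128, 1]

# the whole feature vector becomes one time step
single_step = np.expand_dims(l_flatten, 1)   # [batch, 1, 128]
```

The first layout matches what the tutorial does with `tf.expand_dims(batch_norm(x_shape_pl), 2)`; the second is closer to "one CNN summary per time step".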
      

      【Discussion】:

        【Solution 3】:

        Here is what I have so far; improvements are welcome:

        import tensorflow as tf
        from tensorflow.python.ops import rnn_cell_impl

        class LayerCell(rnn_cell_impl.RNNCell):
            """Wraps a stateless (feed-forward) layer so it can be stacked in a MultiRNNCell."""

            def __init__(self, tf_layer, **kwargs):
                ''' :param tf_layer: a tensorflow layer class, e.g. tf.layers.Conv2D or
                    tf.keras.layers.Conv2D. NOT the functional tf.layers.conv2d !
                    All other layer params can be passed as keyword arguments
                    by name: paramname=param'''
                super(LayerCell, self).__init__()
                self.layer_fn = tf_layer(**kwargs)

            def __call__(self, inputs, state, scope=None):
                ''' Every `RNNCell` must implement `call` with
                  the signature `(output, next_state) = call(input, state)`.  The optional
                  third input argument, `scope`, is allowed for backwards compatibility
                  purposes; but should be left off for new subclasses.'''
                # apply the wrapped layer; the (empty) state is passed through unchanged
                return (self.layer_fn(inputs), state)

            def __str__(self):
                return "Cell wrapper of " + str(self.layer_fn)

            def __getattr__(self, attr):
                '''Delegate unknown attributes to the wrapped layer.
                credits to https://stackoverflow.com/questions/1382871/dynamically-attaching-a-method-to-an-existing-python-object-generated-with-swig/1383646#1383646'''
                if attr == 'layer_fn':
                    # avoid infinite recursion while layer_fn is not set yet
                    raise AttributeError(attr)
                return getattr(self.layer_fn, attr)

            @property
            def state_size(self):
                """size(s) of state(s) used by this cell.

                It can be represented by an Integer, a TensorShape or a tuple of Integers
                or TensorShapes.
                """
                # a single zero-sized state: the wrapped layer keeps nothing between steps
                return (0,)

            @property
            def output_size(self):
                """Integer or TensorShape: size of outputs produced by this cell."""
                # use with caution; could be uninitialized
                return self.layer_fn.output_shape
        

        (Of course, don't use it to wrap recurrent layers, since their state would not be carried over between time steps.)
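The reason is that the wrapper returns the incoming state untouched, so whatever a wrapped recurrent layer would need to carry from one step to the next is thrown away. A toy pure-Python illustration, assuming a hypothetical "recurrent layer" that is just a running sum:

```python
def running_sum_step(x, h):
    """Toy recurrent step: the new state is the previous state plus the input."""
    h_new = h + x
    return h_new, h_new


def proper_unroll(xs):
    # the state is threaded from one step to the next
    h, outs = 0, []
    for x in xs:
        out, h = running_sum_step(x, h)
        outs.append(out)
    return outs


def passthrough_unroll(xs):
    # mimics LayerCell: the wrapped layer starts from a fresh state on every
    # call and its new state is discarded; only the outer (empty) state is
    # handed back, as in `return (self.layer_fn(inputs), state)`
    outs = []
    for x in xs:
        out, _ = running_sum_step(x, 0)
        outs.append(out)
    return outs
```

With inputs `[1, 2, 3]` the properly threaded version accumulates `[1, 3, 6]`, while the pass-through version only ever sees the current input and yields `[1, 2, 3]`.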

        Seems to work with: tf.layers.Conv2D, tf.keras.layers.Conv2D, tf.keras.layers.Activation, tf.layers.BatchNormalization

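        Does not work with: tf.keras.layers.BatchNormalization. At least it failed for me when used inside a tf.while_loop, complaining about combining variables from different frames, similar to here. Maybe Keras uses tf.Variable() instead of tf.get_variable()...?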


        Usage:

        import numpy as np
        import tensorflow as tf

        # a recurrent ConvLSTM layer followed by two stateless (feed-forward) layers
        cell0 = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=[40, 40, 3], output_channels=16, kernel_shape=[5, 5])
        cell1 = LayerCell(tf.keras.layers.Conv2D, filters=8, kernel_size=[5, 5], strides=(1, 1), padding='same')
        cell2 = LayerCell(tf.layers.BatchNormalization, axis=-1)

        # one time step for a batch of 10 images of size 40x40x3
        inputs = tf.constant(np.random.rand(10, 40, 40, 3).astype(np.float32))
        multicell = tf.contrib.rnn.MultiRNNCell([cell0, cell1, cell2])
        state = multicell.zero_state(batch_size=10, dtype=tf.float32)

        # MultiRNNCell returns (output, new_state)
        output, new_state = multicell(inputs, state)
        

        【Discussion】:
