【问题标题】:Keras multiply layer output with scalar带有标量的 Keras 多层输出
【发布时间】:2017-01-14 03:15:59
【问题描述】:

我有一个层输出,我想乘以一个标量。我可以使用 lambda 层来做到这一点,即

sc_mult = Lambda(lambda x: x * 2)(layer)

效果很好。但是,如果我想为每个示例使用不同的标量,我会尝试将它们作为第二个输入提供,形状为 (Examples, 1)

input_scalar = Input(shape = (1L,))

因此我的 lambda 层变为

sc_mult = Lambda(lambda x: x * input_scalar)(layer)

但现在这会在训练时引发错误。注意,32 是批量大小,128 是层输入(和输出)的维度 - 层输入乘以标量是 (batch_size x 32(前一层的过滤器) x 128(空间暗淡) x 128 (空间暗淡))。

GpuElemwise. Input dimension mis-match. Input 5 (indices start at 0) has shape[2] == 32, but the output's size on that axis is 128.

我假设我没有通过输入层输入正确的形状,但不知道为什么。

【问题讨论】:

    标签: lambda theano keras


    【解决方案1】:

    不确定回答一个老问题是否有用,但也许其他人也遇到过同样的问题。

    问题确实是标量的形状与输入(或 x)的形状。您应该使用np.reshape 重塑您的标量,使其具有与您相乘的矩阵一样多的维度,例如:

    from keras import *
    from keras.layers import *
    import numpy as np
    
    # inputs
    X = np.ones((32,32,128,128))
    s = np.arange(32).reshape(-1,1,1,1) # 1 different scalar per batch example, reshaped
    print(X.shape, s.shape)
    
    # model
    input_X = Input(shape=(32,128,128))
    input_scalar = Input(shape = (1,1,1))
    sc_mult = Lambda(lambda x: x * input_scalar)(input_X)
    model = Model(inputs=[input_X, input_scalar], outputs=sc_mult)
    
    out = model.predict([X,s])
    out
    

    现在out[0,:,:,:] 全部为零,out[1,:,:,:] 全部为 1,out[31,:,:,:] 全部为 31s,等等。

    【讨论】:

      【解决方案2】:

      另一种方法是将 lambda 层内容放入这样的函数中:

      def mul_sca(x):
          return tf.multiply(x[0],x[1])
      

      然后调用它:

      sc_mult = Lambda(mul_sca)([input_scalar,layer])
      

      这帮助我在使用多个 GPU 时避免了一些错误。

      【讨论】:

        【解决方案3】:

        乘以它

        x = Dense(10, activation="sigmoid")(inputs)
        x_multiplied = x * 5.0
        

        => 可能的输出范围 = 0..5

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2021-05-07
          • 2019-04-08
          • 1970-01-01
          • 1970-01-01
          • 2017-07-23
          • 2019-09-17
          • 2019-03-28
          相关资源
          最近更新 更多