【问题标题】:Pytorch to Keras code equivalencePytorch 到 Keras 代码等价
【发布时间】:2018-04-02 16:01:39
【问题描述】:

给定 PyTorch 中的以下代码,Keras 等价物是什么?

class Network(nn.Module):

    def __init__(self, state_size, action_size):
        super(Network, self).__init__()

        # Inputs = 5, Outputs = 3, Hidden = 30
        self.fc1 = nn.Linear(5, 30)
        self.fc2 = nn.Linear(30, 3)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        outputs = self.fc2(x)
        return outputs

是这个吗?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='relu'))
model.add(Dense(units=30, activation='relu'))
model.add(Dense(units=3, activation='linear'))

还是这个?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='linear'))
model.add(Dense(units=30, activation='relu'))
model.add(Dense(units=3, activation='linear'))

是吗?

model = Sequential()
model.add(Dense(units=30, input_dim=5, activation='relu'))
model.add(Dense(units=30, activation='linear'))
model.add(Dense(units=3, activation='linear'))

谢谢

【问题讨论】:

    标签: keras pytorch


    【解决方案1】:

    据我所知,它们看起来都不正确。正确的 Keras 等效代码是:

    model = Sequential()
    model.add(Dense(30, input_shape=(5,), activation='relu')) 
    model.add(Dense(3)) 
    

    model.add(Dense(30, input_shape=(5,), activation='relu'))

    模型将采用形状 (*, 5) 的输入数组和形状 (*, 30) 的输出数组。除了input_shape,您也可以使用input_diminput_dim=5 等价于input_shape=(5,)

    model.add(密集(3))

    在第一层之后,不再需要指定输入的大小。此外,如果您不指定任何激活,则不会应用任何激活(相当于线性激活)。


    另一种选择是:

    model = Sequential()
    model.add(Dense(30, input_dim=5)) 
    model.add(Activation('relu'))
    model.add(Dense(3)) 
    

    希望这是有道理的!

    【讨论】:

    • 2个密集层,一个激活relu,另一个是linear
    • 还有一个“隐藏的输入层”在这种情况下也没有出现。如果算上这个,它们是 3 层。
    【解决方案2】:

    看起来像

    model = Sequential()
    model.add(InputLayer(input_shape=input_shape(5,)) 
    model.add(Dense(30, activation='relu')
    model.add(Dense(3))
    

    如果您尝试将 Pytorch 模型转换为 Keras 模型,您也可以尝试使用Pytorch2Keras 转换器。

    它支持 Conv2d、Linear、Activations、Element-wise 操作等基础层。所以,我已经转换了 ResNet50,错误为 1e-6。

    【讨论】:

      【解决方案3】:
        model = Sequential()
        model.add(Dense(30, input_dim=5, activation='relu'))
        model.add(Dense(3, activation=None))
      

      【讨论】:

      • 请不要只发布代码作为答案,还要解释您的代码的作用以及它如何解决问题的问题。带有解释的答案通常更有帮助,质量更高,更有可能吸引投票。
      猜你喜欢
      • 2018-07-13
      • 1970-01-01
      • 1970-01-01
      • 2018-02-23
      • 2022-11-28
      • 2020-07-31
      • 2019-09-02
      • 2020-03-26
      • 2022-01-19
      相关资源
      最近更新 更多