【问题标题】:How to summarize pytorch model如何总结pytorch模型
【发布时间】:2022-11-17 10:03:44
【问题描述】:

你好,我正在为 cartpole 上的强化学习构建一个 DQN 模型,并想像 keras model.summary() 函数一样打印我的模型摘要

这是我的模型类。

class DQN():
    ''' Deep Q Neural Network class. '''
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=0.05):
            super(DQN, self).__init__()
            self.criterion = torch.nn.MSELoss()
            self.model = torch.nn.Sequential(
                            torch.nn.Linear(state_dim, hidden_dim),
                            torch.nn.ReLU(),
                            torch.nn.Linear(hidden_dim, hidden_dim*2),
                            torch.nn.ReLU(),
                            torch.nn.Linear(hidden_dim*2, action_dim)
                    )
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr)



    def update(self, state, y):
        """Update the weights of the network given a training sample. """
        y_pred = self.model(torch.Tensor(state))
        loss = self.criterion(y_pred, Variable(torch.Tensor(y)))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


    def predict(self, state):
        """ Compute Q values for all actions using the DQL. """
        with torch.no_grad():
            return self.model(torch.Tensor(state))

这是传递了参数的模型实例。

# Number of states = 4
n_state = env.observation_space.shape[0]
# Number of actions = 2
n_action = env.action_space.n
# Number of episodes
episodes = 150
# Number of hidden nodes in the DQN
n_hidden = 50
# Learning rate
lr = 0.001


simple_dqn = DQN(n_state, n_action, n_hidden, lr)


我尝试使用 torchinfo summary 但我得到一个 AttributeError: “DQN”对象没有属性“named_pa​​rameters”

from torchinfo import summary
simple_dqn = DQN(n_state, n_action, n_hidden, lr)
summary(simple_dqn, input_size=(4, 2, 50))

任何帮助表示赞赏。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    你的DQN应该是nn.Module的子类

    class DQN(nn.Module):
        def __init__(self, state_dim, action_dim, hidden_dim=64, lr=0.05):
            ...
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-07-17
      • 2021-04-10
      • 2020-06-14
      • 1970-01-01
      • 1970-01-01
      • 2021-06-23
      • 2018-08-18
      相关资源
      最近更新 更多