【Question title】: Soft actor critic with discrete action space
【Posted】: 2019-10-07 03:08:06
【Question description】:

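I am trying to implement the soft actor-critic (SAC) algorithm for a discrete action space, and I am having trouble with the loss function.

Here is a description of SAC for continuous action spaces: https://spinningup.openai.com/en/latest/algorithms/sac.html

I don't know what I am doing wrong.

The problem is that the network does not learn anything on the CartPole environment.

The full code is on GitHub: https://github.com/tk2232/sac_discrete/blob/master/sac_discrete.py

Here is my idea for how to calculate the loss for discrete actions.

Value network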

class ValueNet:
    def __init__(self, sess, state_size, hidden_dim, name):
        self.sess = sess

        with tf.variable_scope(name):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='value_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='value_targets')
            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            self.values = Dense(units=1, activation=None)(x)

            optimizer = tf.train.AdamOptimizer(0.001)

            loss = 0.5 * tf.reduce_mean((self.values - tf.stop_gradient(self.targets)) ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def get_value(self, s):
        return self.sess.run(self.values, feed_dict={self.states: s})

    def update(self, s, targets):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.targets: targets})

In the Q network, I gather the Q-values that correspond to the actions that were actually taken.

Example

q_out = [[0.5533, 0.4444], [0.2222, 0.6666]]
collected_actions = [0, 1]
gather = [[0.5533], [0.6666]]

Gather function

def gather_tensor(params, idx):
    idx = tf.stack([tf.range(tf.shape(idx)[0]), idx[:, 0]], axis=-1)
    params = tf.gather_nd(params, idx)
    return params
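
To make the indexing concrete, here is a small plain-Python sketch of what `gather_tensor` is meant to compute for the example above. The helper names `gather_rows` and `gather_rows_one_hot` are made up for illustration and are not part of the original code; the second one shows the one-hot-mask formulation, which is a common alternative to index-based gathering.

```python
def gather_rows(q_out, actions):
    """Pick q_out[i][actions[i][0]] for every row i, keeping shape (batch, 1)."""
    return [[row[a[0]]] for row, a in zip(q_out, actions)]


def gather_rows_one_hot(q_out, actions):
    """Same result via a one-hot mask: multiply, then sum over the action axis."""
    result = []
    for row, a in zip(q_out, actions):
        mask = [1.0 if j == a[0] else 0.0 for j in range(len(row))]
        result.append([sum(q * m for q, m in zip(row, mask))])
    return result


q_out = [[0.5533, 0.4444], [0.2222, 0.6666]]
collected_actions = [[0], [1]]  # shape (batch, 1), like the q_actions placeholder

print(gather_rows(q_out, collected_actions))          # [[0.5533], [0.6666]]
print(gather_rows_one_hot(q_out, collected_actions))  # [[0.5533], [0.6666]]
```

Both forms select one Q-value per row, so the Q loss only receives gradient through the output unit of the action that was taken.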

Q network

class SoftQNetwork:
    def __init__(self, sess, state_size, action_size, hidden_dim, name):
        self.sess = sess

        with tf.variable_scope(name):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='q_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='q_targets')
            self.actions = tf.placeholder(dtype=tf.int32, shape=[None, 1], name='q_actions')

            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            x = Dense(units=action_size, activation=None)(x)
            self.q = tf.reshape(gather_tensor(x, self.actions), shape=(-1, 1))

            optimizer = tf.train.AdamOptimizer(0.001)

            loss = 0.5 * tf.reduce_mean((self.q - tf.stop_gradient(self.targets)) ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def update(self, s, a, target):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.actions: a, self.targets: target})

    def get_q(self, s, a):
        return self.sess.run(self.q, feed_dict={self.states: s, self.actions: a})

Policy network

class PolicyNet:
    def __init__(self, sess, state_size, action_size, hidden_dim):
        self.sess = sess

        with tf.variable_scope('policy_net'):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='policy_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='policy_targets')
            self.actions = tf.placeholder(dtype=tf.int32, shape=[None, 1], name='policy_actions')

            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            self.logits = Dense(units=action_size, activation=None)(x)
            dist = Categorical(logits=self.logits)

            optimizer = tf.train.AdamOptimizer(0.001)

            # Get action
            self.new_action = dist.sample()
            self.new_log_prob = dist.log_prob(self.new_action)

            # Calc loss
            log_prob = dist.log_prob(tf.squeeze(self.actions))
            loss = tf.reduce_mean(tf.squeeze(self.targets) - 0.2 * log_prob)
            self.train_op = optimizer.minimize(loss, var_list=_params('policy_net'))

    def get_action(self, s):
        action = self.sess.run(self.new_action, feed_dict={self.states: s[np.newaxis, :]})
        return action[0]

    def get_next_action(self, s):
        next_action, next_log_prob = self.sess.run([self.new_action, self.new_log_prob], feed_dict={self.states: s})
        return next_action.reshape((-1, 1)), next_log_prob.reshape((-1, 1))

    def update(self, s, a, target):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.actions: a, self.targets: target})

Update function

def soft_q_update(batch_size, frame_idx):
    gamma = 0.99
    alpha = 0.2

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    action = action.reshape((-1, 1))
    reward = reward.reshape((-1, 1))
    done = done.reshape((-1, 1))

Q_target

v_ = value_net_target.get_value(next_state)
q_target = reward + (1 - done) * gamma * v_

V_target

next_action, next_log_prob = policy_net.get_next_action(state)
q1 = soft_q_net_1.get_q(state, next_action)
q2 = soft_q_net_2.get_q(state, next_action)
q = np.minimum(q1, q2)
v_target = q - alpha * next_log_prob

Policy_target

q1 = soft_q_net_1.get_q(state, action)
q2 = soft_q_net_2.get_q(state, action)
policy_target = np.minimum(q1, q2)

【Discussion】:

    Tags: python tensorflow machine-learning reinforcement-learning


    【Solution 1】:

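    Since the algorithm is generic to both discrete and continuous policies, the key idea is that we need a discrete distribution that is reparametrizable. The extension should then require only minimal changes to the continuous SAC code: just swap the policy distribution class.

    There is one such distribution: the GumbelSoftmax distribution. PyTorch does not have it built in, so I simply extended a close cousin that has the right rsample() and added a correct log prob calculation. With the ability to compute a reparametrized action and its log probability, SAC handles discrete actions with minimal extra code, as shown below.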

        def calc_log_prob_action(self, action_pd, reparam=False):
            '''Calculate log_probs and actions with option to reparametrize from paper eq. 11'''
            samples = action_pd.rsample() if reparam else action_pd.sample()
            if self.body.is_discrete:  # this is straightforward using GumbelSoftmax
                actions = samples
                log_probs = action_pd.log_prob(actions)
            else:
                mus = samples
                actions = self.scale_action(torch.tanh(mus))
                # paper Appendix C. Enforcing Action Bounds for continuous actions
                log_probs = (action_pd.log_prob(mus) - torch.log(1 - actions.pow(2) + 1e-6).sum(1))
            return log_probs, actions
    
    
    # ... for discrete action, GumbelSoftmax distribution
    
    class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
        '''
        A differentiable Categorical distribution using reparametrization trick with Gumbel-Softmax
        Explanation http://amid.fish/assets/gumbel.html
        NOTE: use this in place of PyTorch's RelaxedOneHotCategorical distribution since its log_prob is not working right (returns positive values)
        Papers:
        [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al, 2017)
        [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017)
        '''
    
        def sample(self, sample_shape=torch.Size()):
            '''Gumbel-softmax sampling. Note rsample is inherited from RelaxedOneHotCategorical'''
            u = torch.empty(self.logits.size(), device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
            noisy_logits = self.logits - torch.log(-torch.log(u))
            return torch.argmax(noisy_logits, dim=-1)
    
        def log_prob(self, value):
            '''value is one-hot or relaxed'''
            if value.shape != self.logits.shape:
                value = F.one_hot(value.long(), self.logits.shape[-1]).float()
                assert value.shape == self.logits.shape
            return - torch.sum(- value * F.log_softmax(self.logits, -1), -1)
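
For readers who have not seen the trick before, the following is a minimal plain-Python sketch of the two sampling modes that the class above relies on. It is not the SLM Lab code and carries no gradients; the function names (`sample_hard`, `sample_relaxed`, `gumbel_noise`) and the example logits are assumptions made for illustration. The hard sample takes the argmax of the logits plus Gumbel noise, which yields an exact draw from the categorical distribution; the relaxed sample replaces the argmax with a temperature-scaled softmax, which is what makes a reparametrized `rsample()` possible.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_noise(rng, n):
    """Draw n independent Gumbel(0, 1) samples as -log(-log(u)), u ~ Uniform(0, 1)."""
    eps = 1e-12  # clamp u away from 0 and 1 so both logs stay finite
    return [-math.log(-math.log(min(max(rng.random(), eps), 1.0 - eps)))
            for _ in range(n)]


def sample_hard(logits, rng):
    """Gumbel-max: the argmax of the noisy logits is a categorical sample."""
    noisy = [l + g for l, g in zip(logits, gumbel_noise(rng, len(logits)))]
    return max(range(len(noisy)), key=noisy.__getitem__)


def sample_relaxed(logits, rng, temperature=1.0):
    """Gumbel-softmax: a soft stand-in for the one-hot sample."""
    noisy = [(l + g) / temperature
             for l, g in zip(logits, gumbel_noise(rng, len(logits)))]
    return softmax(noisy)


rng = random.Random(0)
logits = [1.0, 0.0, -1.0]  # hypothetical policy logits for one state
probs = softmax(logits)

# Empirical frequencies of hard samples should approach softmax(logits).
n = 20000
counts = [0] * len(logits)
for _ in range(n):
    counts[sample_hard(logits, rng)] += 1
freqs = [c / n for c in counts]

relaxed = sample_relaxed(logits, rng, temperature=0.5)
print("softmax probs:", probs)
print("hard-sample frequencies:", freqs)
print("one relaxed sample:", relaxed)
```

Lower temperatures push the relaxed sample closer to a one-hot vector, at the price of noisier gradients in an actual autodiff implementation.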
    

    Here is the result on LunarLander. SAC learns to solve it quite quickly.

    The full implementation is in SLM Lab: https://github.com/kengz/SLM-Lab/blob/master/slm_lab/agent/algorithm/sac.py

    SAC benchmark results on Roboschool (continuous) and LunarLander (discrete) are shown here: https://github.com/kengz/SLM-Lab/pull/399

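    【Discussion】:

    • Did you compare the performance of your implementation with the one proposed by Petros Christodoulou (arxiv.org/abs/1910.07207)? In his paper, he replaces action sampling with a direct sum over actions, since the action space is discrete and finite. I would guess that this helps reduce variance and speeds up convergence.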
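
The variance argument in this comment can be illustrated with a small plain-Python sketch. The policy probabilities and Q-values below are made-up numbers for a single state, and the function names are hypothetical; `alpha = 0.2` matches the value used in the question. The exact soft value sums over all actions, while the sampled version draws one action and is only correct on average.

```python
import math
import random
import statistics

alpha = 0.2          # entropy temperature, same value the question uses
pi = [0.7, 0.3]      # hypothetical policy probabilities for one state
q_min = [1.0, 0.5]   # hypothetical min(Q1, Q2) for each action


def soft_value_exact(pi, q, alpha):
    """V(s) = sum_a pi(a|s) * (Q(s,a) - alpha * log pi(a|s))."""
    return sum(p * (qa - alpha * math.log(p)) for p, qa in zip(pi, q))


def soft_value_sampled(pi, q, alpha, rng):
    """Single-sample estimate: draw a ~ pi, return Q(s,a) - alpha * log pi(a|s)."""
    a = rng.choices(range(len(pi)), weights=pi)[0]
    return q[a] - alpha * math.log(pi[a])


rng = random.Random(0)
exact = soft_value_exact(pi, q_min, alpha)
estimates = [soft_value_sampled(pi, q_min, alpha, rng) for _ in range(10000)]

print("exact soft value:    ", exact)
print("mean of estimates:   ", statistics.mean(estimates))
print("std of one estimate: ", statistics.pstdev(estimates))
```

The sampled estimates agree with the exact value on average, but each individual estimate is noisy, whereas the sum over actions has no sampling noise at all. The sum costs one term per action, so it is only practical when the action set is small.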
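    【Solution 2】:

    This repo may help. Its description says that it contains a PyTorch implementation of SAC for discrete action spaces. There is one file with the SAC algorithm for continuous action spaces and another file with the SAC algorithm for discrete action spaces.

    【Discussion】: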

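    • Thanks. I looked at the script and found that there is another paper on SAC. The architecture and the way the loss is computed are different. I don't quite understand how he computes policy_loss in the repo. In the paper he mentions in the description, policy_loss is (alpha * log(pi) - Q), but he uses (pi * (log(pi) - q_min)).
    • As far as I understand, the author of the repo was guided by this thread on reddit.
    • I also found this repo with discrete SAC (script and the corresponding PR). As far as I understand, that implementation is also in PyTorch and is based on the latest SAC paper.
    • It looks like sac_discrete has been removed again. I tried running the algorithm with cartpole and CategoricalDistParams. Unfortunately, it does not run without errors.
    • Still, this repo and the corresponding paper are great. But I have large discrete spaces in mind, for example combinatorial optimization problems. Such problems usually have very large action spaces, which this solution cannot handle. I think in that case we have no choice but to use the Gumbel softmax solution.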
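
On the policy_loss question raised in the first comment above: one way to read the two expressions is that the second is the expectation of the first under the policy, written as an explicit sum over actions. The sketch below is a plain-Python illustration under that reading, not code from the repo; it writes the temperature `alpha` explicitly (the comment's version omits it), and the Q-values and candidate policies are made-up numbers. It also checks the known closed form: the summed loss is minimized by the Boltzmann policy softmax(Q / alpha), where it equals -alpha * logsumexp(Q / alpha).

```python
import math

alpha = 0.2
q_min = [1.0, 0.5, 0.0]  # hypothetical min(Q1, Q2) per action for one state


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def policy_loss(pi, q, alpha):
    """sum_a pi(a) * (alpha * log pi(a) - Q(a)) = E_{a~pi}[alpha * log pi(a) - Q(a)]."""
    return sum(p * (alpha * math.log(p) - qa) for p, qa in zip(pi, q))


uniform = [1 / 3] * 3
near_greedy = [0.98, 0.01, 0.01]
boltzmann = softmax([qa / alpha for qa in q_min])

for name, pi in [("uniform", uniform),
                 ("near-greedy", near_greedy),
                 ("softmax(Q/alpha)", boltzmann)]:
    print(f"{name:>18}: loss = {policy_loss(pi, q_min, alpha):.5f}")

# Closed-form minimum of the loss over all distributions pi.
closed_form = -alpha * math.log(sum(math.exp(qa / alpha) for qa in q_min))
print("closed-form minimum:", closed_form)
```

Because every action appears in the sum, the gradient reaches all policy outputs directly and no reparametrized sample is needed, which is why this form only applies to finite action sets.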
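    【Solution 3】:

    There is a paper on SAC with discrete action spaces. It says that SAC for discrete action spaces does not need a reparameterization trick such as Gumbel softmax. Instead, SAC itself needs some modifications. See the paper for details.

    Paper / Author's implementation (without code for Atari) / Reproduction (with code for Atari)

    Hope this helps.

    【Discussion】:

    • Sorry, I forgot to include the link to the paper. I have edited the answer above.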
    【Solution 4】:

    PyTorch 1.8 has RelaxedOneHotCategorical, which supports reparameterized sampling using Gumbel softmax.

    import torch
    import torch.nn as nn
    from torch.distributions import RelaxedOneHotCategorical

    class Policy(nn.Module):
        def __init__(self, input_dims, hidden_dims, actions):
            super().__init__()
            # MLP that maps a state to one logit per discrete action
            self.mlp = nn.Sequential(nn.Linear(input_dims, hidden_dims), nn.SELU(inplace=True),
                                     nn.Linear(hidden_dims, hidden_dims), nn.SELU(inplace=True),
                                     nn.Linear(hidden_dims, actions))

        def forward(self, state):
            logits = torch.log_softmax(self.mlp(state), dim=-1)
            return RelaxedOneHotCategorical(logits=logits, temperature=torch.ones(1) * 1.0)
    
    >>> policy = Policy(4, 16, 2)
    >>> a_dist = policy(torch.randn(8, 4))
    >>> a_dist.rsample()
    tensor([[0.0353, 0.9647],
            [0.1348, 0.8652],
            [0.1110, 0.8890],
            [0.4956, 0.5044],
            [0.6941, 0.3059],
            [0.6126, 0.3874],
            [0.2932, 0.7068],
            [0.0498, 0.9502]], grad_fn=<ExpBackward>)
    

    【Discussion】:
