【问题标题】:Reinforcement Learning with Keras model使用 Keras 模型进行强化学习
【发布时间】:2019-03-31 18:36:00
【问题描述】:

我试图在 Keras 中实现一个 q-learning 算法。根据文章,我发现了这些代码行。

for state, action, reward, next_state, done in sample_batch:
        target = reward
        if not done:
            #formula
          target = reward + self.gamma * np.amax(self.brain.predict(next_state)[0])
        target_f = self.brain.predict(state)
        #shape (1,2)
        target_f[0][action] = target
        print(target_f.shape)
        self.brain.fit(state, target_f, epochs=1, verbose=0)
    if self.exploration_rate > self.exploration_min:
        self.exploration_rate *= self.exploration_decay

变量sample_batch 是包含来自收集的数据的样本state, action, reward, next_state, done 的数组。 我还发现了下面的q-learning公式

为什么等式(代码)中没有- 符号?我发现np.amax 返回数组的最大值或沿轴的最大值。当我打电话给self.brain.predict(next_state) 时,我得到[[-0.06427538 -0.34116858]]。那么它在这个方程中起到了预测的作用呢?随着我们前进,target_f 是当前状态的预测输出,然后我们还通过这一步将奖励附加到它上面。然后,我们在当前的state(X) 和target_f(Y) 上训练模型。我有几个问题。 self.brain.predict(next_state) 的作用是什么,为什么没有减号?为什么我们在一个模型上预测两次?前self.brain.predict(state) and self.brain.predict(next_state)[0]

【问题讨论】:

    标签: python keras deep-learning reinforcement-learning q-learning


    【解决方案1】:

    为什么等式(代码)中没有 - 号?

    这是因为损失计算是在 fit 函数内部完成的。

    reward + self.gamma * np.amax(self.brain.predict(next_state)[0])
    

    这与损失函数中的 target 组件相同。

    在 keras 的 fit 方法中,损失将如下计算。 对于单个训练数据点(神经网络的标准符号),

    x = input state
    
    y = predicted value
    
    y_i = target value
    
    loss(x) = y_i - y
    

    在这一步目标 - 预测发生在内部。

    为什么我们在一个模型上预测两次?

    好问题!!!

     target = reward + self.gamma * np.amax(self.brain.predict(next_state)[0])
    

    在这一步中,我们预测下一个状态的值,以计算目标如果我们采取特定行动 a(表示为 Q(s,a))的状态 s 的值

     target_f = self.brain.predict(state)
    

    在这一步中,我们正在计算在状态 s 中我们可以采取的每个动作的所有 Q 值

    target = 1.00    // target is a single value for action a
    target_f = (0.25,0.25,0.25,0.25)   //target_f is a list of values for all actions
    

    然后执行以下步骤。

    target_f[0][action] = target
    

    我们只更改所选操作的值。 (如果我们采取行动 3)

    target_f = (0.25,0.25,1.00,0.25)  // only action 3 value will change
    

    现在 target_f 将是 实际目标值,我们正试图以正确的形状进行预测。

    【讨论】:

      猜你喜欢
      • 2022-11-10
      • 2023-03-27
      • 1970-01-01
      • 2019-11-21
      • 1970-01-01
      • 2019-06-03
      • 2018-02-25
      • 2020-10-07
      • 2020-02-11
      相关资源
      最近更新 更多