【问题标题】:Implementing momentum weight update for neural network实现神经网络的动量权重更新
【发布时间】:2018-05-11 06:36:30
【问题描述】:

我正在关注 mnielsen 的在线 book。我正在尝试按照他的代码here 定义here 来实现动量权重更新。总体思路是,对于动量权重更新,您不会直接更改具有负梯度的权重向量。您有一个参数velocity,您将其设置为零开始,然后将超参数mu 通常设置为0.9

# Momentum update
v = mu * v - learning_rate * dx # integrate velocity
x += v # integrate position

所以我在下面的代码 sn-p 中有权重 w 和权重变化为 nebla_w

def update_mini_batch(self, mini_batch, eta):
        """Update the network's weights and biases by applying
        gradient descent using backpropagation to a single mini batch.
        The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
        is the learning rate."""
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        for x, y in mini_batch:
            delta_nabla_b, delta_nabla_w = self.backprop(x, y)
            nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
            nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
        self.weights = [w-(eta/len(mini_batch))*nw
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch))*nb
                       for b, nb in zip(self.biases, nabla_b)]

所以在最后两行中你将self.weight 更新为

self.weights = [w-(eta/len(mini_batch))*nw
                for w, nw in zip(self.weights, nabla_w)]

对于动量权重更新,我正在执行以下操作:

self.momentum_v = [ (momentum_mu * self.momentum_v) - ( ( float(eta) / float(len(mini_batch)) )* nw) 
                   for nw in nebla_w ]
self.weights = [ w + v 
                for w, v in zip (self.weights, self.momentum_v)]

但是,我收到以下错误:

 TypeError: can't multiply sequence by non-int of type 'float'

momentum_v 更新。我的eta 超参数已经是浮动的,尽管我再次用浮动函数包装了它。我也用浮动包裹了len(mini_batch)。我也尝试过nw.astype(float),但我仍然会收到错误消息。我不确定为什么。 nabla_w 是一个 numpy 浮点数组。

【问题讨论】:

  • TypeError: can't multiply sequence by non-int of type 'float' 当它们都是 numpy 数组时不会发生。某些东西是一个列表或一个元组或一些其他序列,而 不是 一个 numpy 数组。为每个被相乘的变量打印type(variable),你会看到一些东西不是一个numpy数组。您没有显示您对momentum_mumomentum_v 的定义,也许他们是违规者?例如从同一错误消息中查看我今天早些时候here 的回答。
  • 另外,错误并没有告诉你应该使用float,而是告诉你使用float而你不能,因此将一堆东西投射为float 并不能帮助解决问题。某处你将一个元组或列表乘以一个浮点数。
  • @AlexanderReynolds 我现在看到了!我以错误的方式初始化了momentum_v,这不是一个numpy数组!。太感谢了。你可以写在答案部分,然后我可以接受你的答案。

标签: python numpy machine-learning neural-network mnist


【解决方案1】:

正如 cmets 中所讨论的,这里的东西不是 numpy 数组。上面给出的错误

TypeError: can't multiply sequence by non-int of type 'float'

是 Python 针对序列类型(列表、元组等)发出的错误。错误消息意味着序列不能乘以非整数。它们可以乘以一个int,但这不会改变值——它只是重复序列,即

>>> [1, 0] * 3
[1, 0, 1, 0, 1, 0]

当然在这个框架中,乘以一个浮点数是没有意义的:

>>> [1, 0] * 3.14
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can't multiply sequence by non-int of type 'float'

您会看到与此处相同的错误消息。因此,您要相乘的变量之一确实不是 numpy 数组,而是通用序列类型之一。一个简单的 np.array() 围绕有问题的变量将修复它,或者当然你可以将定义更改为一个数组。

【讨论】:

    猜你喜欢
    • 2015-05-03
    • 1970-01-01
    • 2012-07-09
    • 2021-08-06
    • 1970-01-01
    • 2018-09-13
    • 2019-05-05
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多