【问题标题】:Gradient Descent Runtime Error梯度下降运行时错误
【发布时间】:2017-08-07 13:35:53
【问题描述】:

我已经为 python 做了一个梯度下降的简单实现,它对大多数参数都很好,但对于某些学习率和迭代次数的参数,它会给我一个运行时错误。

RuntimeWarning:double_scalars 中遇到溢出

RuntimeWarning: double_scalars 中遇到无效值

现在我假设因为出现溢出错误,b 和 m 值变得太大而无法存储在内存中,这个假设是否正确?

以及如何防止程序崩溃,因为主程序中的异常处理似乎不起作用,你能想出一种不进行异常处理的方法来从逻辑上防止错误吗?

def compute_error(points,b,m):
    error = 0
    for i in range(len(points)):
        y = ponts[i][1]
        x = points[i][0]
        error +=  (y - (m*x + b))**2
    return error/len(points)

def gradient_runner(points,LR,num_iter,startB=0,startM=0):
    b = startB
    m = startM
    for i in range(num_iter):
        b,m = step_gradient(points,b,m,LR)
    return [b,m]

def step_gradient(points,b,m,LR):
    b_gradient = 0
    m_gradient = 0
    N = float(len(points))
    for i in range(len(points)):
        x = points[i][0]
        y = points[i][1]
        b_gradient+= (-2/N)*(y - ((m*x)+b))
        m_gradient+= (-2/N)*x*(y - ((m*x)+b))
##    print "Value for b_gradient",b_gradient
##    print "Value for b is ",b
##    print "Value for learning rate is ",LR
    new_b = b - (LR * b_gradient)
    new_m = m - (LR * m_gradient)
    return [new_b,new_m]    

import numpy as np
a = np.array([[1,1],[4,2],[6,3],[8,4],[11,5],[12,6],[13,7],[16,8]])

b,m=gradient_runner(a,0.0001,1000) # These parameters work
# b,m=gradient_runner(a,0.1,10000) #Program Crashes
yguesses = [m * i + b for i in a[:,0]]


import matplotlib.pyplot as plt

guezz= yguesses

plt.scatter(a[:,0], a[:,1] ,color="green")
plt.plot(a[:,0],guezz,color="red")

plt.show()

【问题讨论】:

    标签: python machine-learning runtime-error gradient


    【解决方案1】:

    问题在于“学习率”LR(仅通过更改 LR 来测试这一点——你会发现如果你走得足够低,算法就会收敛)。如果LR 的值太高,您每次都会迈出太大的一步(假设您在每一步都“跳过”正确的值)。有一些方法可以计算最大步长应该是多少。谷歌一下(例如“梯度下降步长”)。

    但是,正如您所注意到的,如果出现溢出,则结果很可能是错误的。在 Python 中你可以catch warnings,你可以用它来告诉用户结果是错误的。

    【讨论】:

    • 如果学习率现在变得太低而无法避免溢出,我们有什么选择。在我的例子中,将学习率设置为 0.0001 会使错误消失,但我的梯度下降即使在 2000 次迭代中也不会收敛。
    • @AbhyudayaSrinet 假设您的算法是正确的,那么您的问题可能不适合简单的梯度下降。您可以尝试在谷歌上搜索“改进条件梯度下降”之类的内容。
    • 谢谢。原来我的实现是不正确的。我正在做 oneVsAll 分类并使用标签作为我的输出值,而我应该一直使用类标签向量(例如 [0,1,0,0])。这个answer 帮我弄明白了。
    猜你喜欢
    • 2011-09-27
    • 1970-01-01
    • 2020-08-07
    • 2021-02-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-06-13
    相关资源
    最近更新 更多