【问题标题】:How to teach keras neural network to solve sqrt如何教keras神经网络解决sqrt
【发布时间】:2019-10-17 16:10:54
【问题描述】:

我正在学习使用 python 和 keras 进行机器学习。我创建了一个神经网络,从 {1, 4, 9, 16, 25, 36, ..., 100} 范围内的偶数整数中预测平方根。我已经编写了代码来做到这一点,但结果远非如此(无论我向网络提供什么数字,它都预测它是 1.0)。

我尝试过更改层数、每层中的神经元数量、激活函数,但没有任何帮助。

这是我目前写的代码:

from numpy import loadtxt
from keras.models import Sequential
from keras.layers import Dense
from keras import optimizers

# laod dataset
# dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
dataset = loadtxt('sqrt.csv', delimiter=',')

# split into input (X) and output (y) variables
X = dataset[:,0:1] * 1.0
y = dataset[:,1] * 1.0

# define the keras model
model = Sequential()
model.add(Dense(6, input_dim=1, activation='relu'))
model.add(Dense(1, activation='linear'))

# compile the keras model
opt = optimizers.adam(lr=0.01)
model.compile(loss='mean_squared_error', optimizer=opt, metrics=['accuracy'])

# fit the keras model on the dataset (CPU)
model.fit(X, y, epochs=150, batch_size=10, verbose=0)

# evaluate the keras model
_, accuracy = model.evaluate(X, y, verbose=0)
print('Accuracy: %.2f' % (accuracy*100))

# make class predictions with the model
predicitions = model.predict_classes(X)

# summarize the first 10 cases
for i in range(10):
    print('%s => %.2f (expected %.2f)' % (X[i].tolist(), predicitions[i], y[i]))

这是数据集:

1,1
4,2
9,3
16,4
25,5
36,6
49,7
64,8
81,9
100,10

当我运行这个网络时,我得到以下结果:

[1.0] => 0.00 (expected 1.00)
[4.0] => 0.00 (expected 2.00)
[9.0] => 1.00 (expected 3.00)
[16.0] => 1.00 (expected 4.00)
[25.0] => 1.00 (expected 5.00)
[36.0] => 1.00 (expected 6.00)
[49.0] => 1.00 (expected 7.00)
[64.0] => 1.00 (expected 8.00)
[81.0] => 1.00 (expected 9.00)
[100.0] => 1.00 (expected 10.00)

我做错了什么?

【问题讨论】:

    标签: python tensorflow machine-learning keras neural-network


    【解决方案1】:

    这是一个回归问题。所以你应该使用model.predict() 而不是model.predict_classes()

    数据集也不够大。但是,您可以使用以下代码获得一些明智的预测。

    from numpy import loadtxt
    from keras.models import Sequential
    from keras.layers import Dense
    from keras import optimizers
    
    # laod dataset
    # dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
    dataset = loadtxt('sqrt.csv', delimiter=',')
    
    # split into input (X) and output (y) variables
    X = dataset[:,0:1] * 1.0
    y = dataset[:,1] * 1.0
    
    # define the keras model
    model = Sequential()
    model.add(Dense(6, input_dim=1, activation='relu'))
    model.add(Dense(10, activation='relu'))
    model.add(Dense(1))
    
    # compile the keras model
    opt = optimizers.adam(lr=0.001)
    model.compile(loss='mean_squared_error', optimizer=opt)
    
    # fit the keras model on the dataset (CPU)
    model.fit(X, y, epochs=1500, batch_size=10, verbose=0)
    
    # evaluate the keras model
    _, accuracy = model.evaluate(X, y, verbose=0)
    print('Accuracy: %.2f' % (accuracy*100))
    
    # make class predictions with the model
    predicitions = model.predict(X)
    
    # summarize the first 10 cases
    for i in range(10):
        print('%s => %.2f (expected %.2f)' % (X[i].tolist(), predicitions[i], y[i]))
    

    输出:

    [1.0] => 1.00 (expected 1.00)
    [4.0] => 2.00 (expected 2.00)
    [9.0] => 3.32 (expected 3.00)
    [16.0] => 3.89 (expected 4.00)
    [25.0] => 4.61 (expected 5.00)
    [36.0] => 5.49 (expected 6.00)
    [49.0] => 6.52 (expected 7.00)
    [64.0] => 7.72 (expected 8.00)
    [81.0] => 9.07 (expected 9.00)
    [100.0] => 10.58 (expected 10.00)
    

    编辑:

    正如@desertnaut 在 cmets 中指出的那样,指标 accuracy 在回归任务中没有任何意义。因此,通常使用自定义 R_squared 值(AKA 确定系数)作为指标。 R_squared 值表示回归模型的拟合优度。下面是计算R_squared的代码。

    def r_squared(y_true, y_pred):
        from keras import backend as K
        SS_res =  K.sum(K.square(y_true - y_pred)) 
        SS_tot = K.sum(K.square(y_true - K.mean(y_true))) 
        return ( 1 - SS_res/(SS_tot + K.epsilon()) )
    

    现在,你可以编译模型了;

    model.compile(loss='mean_squared_error', optimizer=opt, metrics=[r_squared])
    

    【讨论】:

    猜你喜欢
    • 2015-04-24
    • 2010-12-09
    • 2021-01-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-03-17
    • 2020-02-25
    • 2018-10-10
    相关资源
    最近更新 更多