【问题标题】:How to get a proper prediction from an neural net trained on MNIST from kaggle?如何从 kaggle 的 MNIST 训练的神经网络中获得正确的预测?
【发布时间】:2019-06-27 08:38:16
【问题描述】:

我已经在来自 kaggle 的 MNIST 数据集上训练了一个神经网络。我无法让神经网络预测它正在接收的数字。

我不知道该如何解决这个问题。

'''蟒蛇

    import pandas as pd
    from tensorflow import keras
    import matplotlib.pyplot as plt
    import numpy as np


    mnist=pd.read_csv(r"C:\Users\Chandrasang\python projects\digit-recognizer\train.csv").values
    xtest=pd.read_csv(r"C:\Users\Chandrasang\python projects\digit-recognizer\test.csv").values

    ytrain=mnist[:,0]
    xtrain=mnist[:,1:]

    x_train=keras.utils.normalize(xtrain,axis=1)
    x_test=keras.utils.normalize(xtest,axis=1)

    x=0
    xtrain2=[]
    while True:
        d=x_train[x]
        d.shape=(28,28)
        xtrain2.append(d)
        x+=1
        if x==42000:
            break

    y=0
    xtest2=[]
    while True:
        b=x_test[y]
        b.shape=(28,28)
        xtest2.append(b)
        y+=1
        if y==28000:
            break

    train=np.array(xtrain2,dtype=np.float32)
    test=np.array(xtest2,dtype=np.float32)

    model=keras.models.Sequential()
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(256,activation=keras.activations.relu))
    model.add(keras.layers.Dense(256,activation=keras.activations.relu))
    model.add(keras.layers.Dense(10,activation=keras.activations.softmax))

    model.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])
    model.fit(train,ytrain,epochs=10)

    ans=model.predict(x_test)
    print(ans[3])

'''

我希望输出是一个整数,而不是它给我以下数组:

[2.7538205e-02 1.0337318e-11 2.9973364e-03 5.7095995e-06 1.6916725e-07 6.9060135e-08 1.3406207e-09 1.1861910e-06 1.4758119e-06 9.6945578e-01]

【问题讨论】:

  • 试试ans = model.predict_classes(x_test)
  • 或者,您可以使用numpy.ndarray.argmax...ans.argmax(1)

标签: python pandas numpy keras


【解决方案1】:

你的输出是正常的,它是一个概率向量。您有 10 个类别(数字从 0 到 9),并且您的网络计算您的图像在每个类别中的概率。查看您的结果,您的网络将您的输入分类为 9,概率约为 0.96。

如果您只想查看预测的课程,如 Chris A. 所说,请使用 predict_classes

【讨论】:

    猜你喜欢
    • 2017-06-19
    • 1970-01-01
    • 2017-08-14
    • 2018-05-05
    • 1970-01-01
    • 2018-11-26
    • 2019-11-16
    • 2018-10-30
    • 2020-12-18
    相关资源
    最近更新 更多