【问题标题】:pytorch MNIST neural network produces several non-zero outputspytorch MNIST 神经网络产生几个非零输出
【发布时间】:2021-03-25 00:00:23
【问题描述】:

我尝试做一个在 MNIST 数据集上运行的神经网络。我主要关注 pytorch.nn 教程。结果,我得到了一个可以学习的模型,但是这个过程或模型本身有问题。我在输出端接收到多个神经元,而不是一个活跃的神经元。

这是模型本身:

model = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.ReLU(),
)

这是训练过程:

loss_func = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=lr)

for epoch in range(epochs):
    model.train()
    for xbt, ybt in train_dl:
        pred = model(xbt)
        loss = loss_func(pred, ybt)
        opt.zero_grad()
        loss.backward()
        opt.step()
        

    model.eval()
    # Validation
    if epoch % 10 == 0:
        with torch.no_grad():
            losses, nums = zip(
                *[(loss_func(model(xbv), ybv), len(xbv)) for xbv, ybv in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

这是每 10 个 epoch 的平均损失:

0 0.13384412774592638
10 0.0900113809091039
20 0.09795805384699234
30 0.10341344920364791
40 0.10804545368137551

这就是将模型应用于验证集的结果如下:

[[ 0.         0.         0.        ... 28.436266   0.         5.001435 ]
 [ 7.3331523 12.666427  31.898096  ...  0.         0.         0.       ]
 [ 0.        18.116354   8.049953  ...  4.330721   0.         0.       ]
 ...
 [ 8.504517   0.         6.302228  ...  0.         0.         0.       ]
 [ 1.7339934  0.         0.        ...  0.         2.1565871  0.       ]
 [45.750134   0.         6.2685804 ...  2.247082   0.         0.       ]]
 Shape: (9984, 10)

我尝试改变学习速度、模型层数、时期数,但似乎没有任何效果。

【问题讨论】:

    标签: python neural-network pytorch mnist


    【解决方案1】:

    你有 10 个神经元在最后一层使用 ReLU,是的,所有神经元都会激发/激活。在这种情况下,每个神经元都在线性激活的输出上应用ReLu 函数。即ReLu(w.x+b)。有 10 个这样的神经元,它们都会根据其输入给出特定的输出,是的,它们都会被激发/激活。从中推断输出的方法是采用与激活最大的神经元相对应的类(使用 np.argmax 或 torch.max)。

    【讨论】:

      【解决方案2】:

      这是绝对正常的: 你的输出应该是[bacth_size, 10] 的形状,因为在每次迭代中你都会给它提供一批batch_size 图像,并且输出层有10 神经元。

      对此的解释如下:

      • 输出张量的每一行都是对批处理中的一个输入图像的预测。
      • 对于一行,如果你只是想分类,你的猜测将是该行的 argmax。例如,如果output[0] = [ 0. , 0. , 0. , 2.1, 3.0, 0., 4., 28.436266 0., 5.001435 ],则意味着您的网络已经预测该图像属于第 n°7 类(28.436266 是最大值,在索引 7 处)。

      现在,您还可以使用某种概率来解释结果。为此,您需要将softmax 层应用于您的输出。然后,值output[i][j] 将被解释为图像i 属于j 类的概率。

      【讨论】:

        猜你喜欢
        • 2018-03-21
        • 1970-01-01
        • 2018-12-16
        • 2017-10-28
        • 1970-01-01
        • 2018-02-02
        • 2015-02-23
        • 2017-06-13
        相关资源
        最近更新 更多