【问题标题】:How to get accurate predictions from a neural network如何从神经网络中获得准确的预测
【发布时间】:2018-05-05 03:31:22
【问题描述】:

我在下面为 3 输入逻辑与门的真值表创建了神经网络,但 [1,1,0] 的预期输出不正确。输出应为 0。但它预测为 0.9,这意味着大约为 1。所以输出不正确。所以我需要知道的是如何使输出预测更准确。请指导我。

import numpy as np

class NeuralNetwork():
    def __init__(self):

        self.X = np.array([[0, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0],
                          [0, 1, 1],
                          [1, 0, 0],
                          [1, 0, 1],
                          [1, 1, 1]])

        self.y = np.array([[0],
                           [0],
                           [0],
                           [0],
                           [0],
                           [0],
                           [1]])

        np.random.seed(1)

        # randomly initialize our weights with mean 0
        self.syn0 = 2 * np.random.random((3, 4)) - 1
        self.syn1 = 2 * np.random.random((4, 1)) - 1

    def nonlin(self,x, deriv=False):
        if (deriv == True):
            return x * (1 - x)

        return 1 / (1 + np.exp(-x))

    def train(self,steps):
        for j in xrange(steps):

            # Feed forward through layers 0, 1, and 2
            l0 = self.X
            l1 = self.nonlin(np.dot(l0, self.syn0))
            l2 = self.nonlin(np.dot(l1, self.syn1))

            # how much did we miss the target value?
            l2_error = self.y - l2

            if (j % 10000) == 0:
                print "Error:" + str(np.mean(np.abs(l2_error)))

            # in what direction is the target value?
            # were we really sure? if so, don't change too much.
            l2_delta = l2_error * self.nonlin(l2, deriv=True)

            # how much did each l1 value contribute to the l2 error (according to the weights)?
            l1_error = l2_delta.dot(self.syn1.T)

            # in what direction is the target l1?
            # were we really sure? if so, don't change too much.
            l1_delta = l1_error * self.nonlin(l1, deriv=True)

            self.syn1 += l1.T.dot(l2_delta)
            self.syn0 += l0.T.dot(l1_delta)

        print("Output after training:")
        print(l2)

    def predict(self,newInput):
        # Multiply the input with weights and find its sigmoid activation for all layers
        layer0 = newInput
        print("predict -> layer 0 : "+str(layer0))
        layer1 = self.nonlin(np.dot(layer0, self.syn0))
        print("predict -> layer 1 : "+str(layer1))
        layer2 = self.nonlin(np.dot(layer1, self.syn1))
        print("predicted output is : "+str(layer2))




if __name__ == '__main__':
    ann=NeuralNetwork()
    ann.train(100000)
    ann.predict([1,1,0])

输出:

Error:0.48402933124
Error:0.00603525276229
Error:0.00407346660344
Error:0.00325224335386
Error:0.00277628698655
Error:0.00245737222701
Error:0.00222508289674
Error:0.00204641406194
Error:0.00190360175536
Error:0.00178613765229
Output after training:
[[  1.36893057e-04]
 [  5.80758383e-05]
 [  1.19857670e-03]
 [  1.85443483e-03]
 [  2.13949603e-03]
 [  2.19360982e-03]
 [  9.95769492e-01]]
predict -> layer 0 : [1, 1, 0]
predict -> layer 1 : [ 0.00998162  0.91479567  0.00690524  0.05241988]
predicted output is : [ 0.99515547]

【问题讨论】:

  • 您的输入是否使用训练后获得的权重正确预测?
  • @NanduKalidindi 这是我需要清除的一点,根据我的理解,权重会自动生成以更准确地猜测输出。所以你在这里问什么我不清楚。如果我错了,请纠正我。
  • 是的,您使用现有输入进行训练以计算可以预测值的权重,而不仅仅是您的输入。验证权重是否正确的一种方法是在所有 8 给定输入上运行 ann.predict() 方法,并将计算值与相应的输出进行交叉检查。
  • @NanduKalidindi 我检查了所有 8 个输入,它不能预测正确的输出,我还不清楚的是,根据我的理解,神经网络中发生的事情是将权重调整为得到我们训练输入集的预期输出。这就是你在上面评论中的意思。那么我该怎么做才能准确地预测网络呢?请指导我。提前谢谢。

标签: python-2.7 numpy neural-network


【解决方案1】:

您在与门中遗漏的每个输入似乎都会发生这种情况。例如,尝试用[1, 1, 0] 替换[0, 1, 1] 输入,然后尝试预测[0, 1, 1],它预测的最终值接近1。我尝试包含biaseslearning rate,但似乎没有任何效果。

就像 Prune 提到的那样,这可能是因为反向传播网络无法使用不完整的模型。

要充分训练您的网络并获得最佳权重,请提供所有可能的输入,即与门的 8 个输入。然后你总能得到正确的预测,因为你已经用这些输入训练了网络,在这种情况下,这可能对预测没有意义。可能是对小数据集的预测效果不佳。

这只是我的猜测,因为我用于预测的几乎所有网络都拥有相当大的数据集。

【讨论】:

    【解决方案2】:

    实际上,它确实产生了正确的输出——模型是模棱两可的。您的输入数据适合A*B;第三个输入的值 从不 影响给定的输出,因此您的模型无法知道它在案例 110 中是否重要。就纯信息论而言,您没有输入强制得到你想要的结果。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-06-08
      • 1970-01-01
      • 1970-01-01
      • 2016-07-28
      • 2017-08-31
      • 2017-12-02
      • 1970-01-01
      相关资源
      最近更新 更多