【问题标题】:Numpy arrays from tuples of arrays for matrix based neural networks用于基于矩阵的神经网络的数组元组中的 Numpy 数组
【发布时间】:2018-06-09 00:01:55
【问题描述】:

为了在神经网络中实现学习,我使用了随机梯度下降,其中小批量通过以下列表理解表示:

mini_batches = [training_data[j:j+mini_batch_size] for j in range(0,len(training_data),mini_batch_size)]

在列表理解中,sn-ps mini_batch_sizetraining_data 源于随机梯度下降法的输入,该方法是神经网络类的一部分。

training_data 由数组元组(x,y) 组成,其中第一个数组x 包含输入数据,第二个数组y 包含输出数据(分类)。数组x 的形状为(784,1),数组y 的形状为(10,1)。长度len(training_data) 输出50,000。

我想构造两个 Numpy 数组,其中第一个数组是输入数据的矩阵,第二个数组是输出数据的矩阵。这两个阵列将用于学习算法的基于矩阵的方法中。但是,我不确定数组结构的简单实现。非常感谢您在这个方向上的帮助。

mini_batch_size 是一个整数,它指定training_data 的分区方式。切片后的训练数据,即training_data[0],格式如下(元组的前半部分包含从0到1的784个浮点数,后半部分包含10个浮点数):

(array([[ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.01171875],
    [ 0.0703125 ],
    [ 0.0703125 ],
    [ 0.0703125 ],
    [ 0.4921875 ],
    [ 0.53125   ],
    [ 0.68359375],
    [ 0.1015625 ],
    [ 0.6484375 ],
    [ 0.99609375],
    [ 0.96484375],
    [ 0.49609375],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.1171875 ],
    [ 0.140625  ],
    [ 0.3671875 ],
    [ 0.6015625 ],
    [ 0.6640625 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.87890625],
    [ 0.671875  ],
    [ 0.98828125],
    [ 0.9453125 ],
    [ 0.76171875],
    [ 0.25      ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.19140625],
    [ 0.9296875 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98046875],
    [ 0.36328125],
    [ 0.3203125 ],
    [ 0.3203125 ],
    [ 0.21875   ],
    [ 0.15234375],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0703125 ],
    [ 0.85546875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.7734375 ],
    [ 0.7109375 ],
    [ 0.96484375],
    [ 0.94140625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.3125    ],
    [ 0.609375  ],
    [ 0.41796875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.80078125],
    [ 0.04296875],
    [ 0.        ],
    [ 0.16796875],
    [ 0.6015625 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0546875 ],
    [ 0.00390625],
    [ 0.6015625 ],
    [ 0.98828125],
    [ 0.3515625 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.54296875],
    [ 0.98828125],
    [ 0.7421875 ],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.04296875],
    [ 0.7421875 ],
    [ 0.98828125],
    [ 0.2734375 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.13671875],
    [ 0.94140625],
    [ 0.87890625],
    [ 0.625     ],
    [ 0.421875  ],
    [ 0.00390625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.31640625],
    [ 0.9375    ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.46484375],
    [ 0.09765625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.17578125],
    [ 0.7265625 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.5859375 ],
    [ 0.10546875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0625    ],
    [ 0.36328125],
    [ 0.984375  ],
    [ 0.98828125],
    [ 0.73046875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.97265625],
    [ 0.98828125],
    [ 0.97265625],
    [ 0.25      ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.1796875 ],
    [ 0.5078125 ],
    [ 0.71484375],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.80859375],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.15234375],
    [ 0.578125  ],
    [ 0.89453125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.9765625 ],
    [ 0.7109375 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.09375   ],
    [ 0.4453125 ],
    [ 0.86328125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.78515625],
    [ 0.3046875 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.08984375],
    [ 0.2578125 ],
    [ 0.83203125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.7734375 ],
    [ 0.31640625],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0703125 ],
    [ 0.66796875],
    [ 0.85546875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.76171875],
    [ 0.3125    ],
    [ 0.03515625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.21484375],
    [ 0.671875  ],
    [ 0.8828125 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.953125  ],
    [ 0.51953125],
    [ 0.04296875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.53125   ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.828125  ],
    [ 0.52734375],
    [ 0.515625  ],
    [ 0.0625    ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ]], dtype=float32), array([[ 0.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 1.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 0.]]))

数据来自 MNIST 训练集。请注意,training_data = list(training_data) 是在最初加载数据时执行的——稍后会创建网络类。如果不使用list(foo)操作,则数据显示如下:<zip at 0x1d744027948>

【问题讨论】:

标签: python arrays python-3.x numpy neural-network


【解决方案1】:

这是一个解决方案,为输入数据和目标数据提供一个数组:

input_data_array = np.asarray([input_data.ravel() for input_data, target_data in mini_batch]).T
target_data_array = np.asarray([target_data.ravel() for input_data, target_data in mini_batch]).T

代码使用来自 Numpy 的 ravelasarray 的组合。

参考training_datamnist_loader 模块:

import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
training_data = list(training_data)

至于动机,矩阵形式的反向传播更快。基本的反向传播方程如下所示(来源:Michael Nielsen 的在线文本:Neural networks and deep learning)。内部带点的圆圈是 Hadamard 或 Schur 产品的符号:

【讨论】:

    猜你喜欢
    • 2017-10-30
    • 2018-06-12
    • 2019-05-15
    • 1970-01-01
    • 2018-05-24
    • 1970-01-01
    • 2013-09-24
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多