【发布时间】:2018-06-09 00:01:55
【问题描述】:
为了在神经网络中实现学习,我使用了随机梯度下降,其中小批量通过以下列表理解表示:
mini_batches = [training_data[j:j+mini_batch_size] for j in range(0,len(training_data),mini_batch_size)]
在列表理解中,sn-ps mini_batch_size 和 training_data 源于随机梯度下降法的输入,该方法是神经网络类的一部分。
training_data 由数组元组(x,y) 组成,其中第一个数组x 包含输入数据,第二个数组y 包含输出数据(分类)。数组x 的形状为(784,1),数组y 的形状为(10,1)。长度len(training_data) 输出50,000。
我想构造两个 Numpy 数组,其中第一个数组是输入数据的矩阵,第二个数组是输出数据的矩阵。这两个阵列将用于学习算法的基于矩阵的方法中。但是,我不确定数组结构的简单实现。非常感谢您在这个方向上的帮助。
mini_batch_size 是一个整数,它指定training_data 的分区方式。切片后的训练数据,即training_data[0],格式如下(元组的前半部分包含从0到1的784个浮点数,后半部分包含10个浮点数):
(array([[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.01171875],
[ 0.0703125 ],
[ 0.0703125 ],
[ 0.0703125 ],
[ 0.4921875 ],
[ 0.53125 ],
[ 0.68359375],
[ 0.1015625 ],
[ 0.6484375 ],
[ 0.99609375],
[ 0.96484375],
[ 0.49609375],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.1171875 ],
[ 0.140625 ],
[ 0.3671875 ],
[ 0.6015625 ],
[ 0.6640625 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.87890625],
[ 0.671875 ],
[ 0.98828125],
[ 0.9453125 ],
[ 0.76171875],
[ 0.25 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.19140625],
[ 0.9296875 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98046875],
[ 0.36328125],
[ 0.3203125 ],
[ 0.3203125 ],
[ 0.21875 ],
[ 0.15234375],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0703125 ],
[ 0.85546875],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.7734375 ],
[ 0.7109375 ],
[ 0.96484375],
[ 0.94140625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.3125 ],
[ 0.609375 ],
[ 0.41796875],
[ 0.98828125],
[ 0.98828125],
[ 0.80078125],
[ 0.04296875],
[ 0. ],
[ 0.16796875],
[ 0.6015625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0546875 ],
[ 0.00390625],
[ 0.6015625 ],
[ 0.98828125],
[ 0.3515625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.54296875],
[ 0.98828125],
[ 0.7421875 ],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.04296875],
[ 0.7421875 ],
[ 0.98828125],
[ 0.2734375 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.13671875],
[ 0.94140625],
[ 0.87890625],
[ 0.625 ],
[ 0.421875 ],
[ 0.00390625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.31640625],
[ 0.9375 ],
[ 0.98828125],
[ 0.98828125],
[ 0.46484375],
[ 0.09765625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.17578125],
[ 0.7265625 ],
[ 0.98828125],
[ 0.98828125],
[ 0.5859375 ],
[ 0.10546875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0625 ],
[ 0.36328125],
[ 0.984375 ],
[ 0.98828125],
[ 0.73046875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.97265625],
[ 0.98828125],
[ 0.97265625],
[ 0.25 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.1796875 ],
[ 0.5078125 ],
[ 0.71484375],
[ 0.98828125],
[ 0.98828125],
[ 0.80859375],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.15234375],
[ 0.578125 ],
[ 0.89453125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.9765625 ],
[ 0.7109375 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.09375 ],
[ 0.4453125 ],
[ 0.86328125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.78515625],
[ 0.3046875 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.08984375],
[ 0.2578125 ],
[ 0.83203125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.7734375 ],
[ 0.31640625],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0703125 ],
[ 0.66796875],
[ 0.85546875],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.76171875],
[ 0.3125 ],
[ 0.03515625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.21484375],
[ 0.671875 ],
[ 0.8828125 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.953125 ],
[ 0.51953125],
[ 0.04296875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.53125 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.828125 ],
[ 0.52734375],
[ 0.515625 ],
[ 0.0625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ]], dtype=float32), array([[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 1.],
[ 0.],
[ 0.],
[ 0.],
[ 0.]]))
数据来自 MNIST 训练集。请注意,training_data = list(training_data) 是在最初加载数据时执行的——稍后会创建网络类。如果不使用list(foo)操作,则数据显示如下:<zip at 0x1d744027948>。
【问题讨论】:
-
看看这是否有帮助 - stackoverflow.com/questions/40084931/…
标签: python arrays python-3.x numpy neural-network