【发布时间】:2019-08-13 22:15:23
【问题描述】:
我很难理解为什么以下两个代码示例会产生不同的结果:
代码 1:
for h in range(n_H):
for w in range(n_W):
# Find indices
vert_start = h * stride # Starting row-index for current slice
vert_end = vert_start + f # Final row-index (+1) for current slice
horiz_start = w * stride # Starting column-index for current slice
horiz_end = horiz_start + f # Final column-index (+1) for current slice
for c in range(n_C):
Aux = (W[:, :, :, c] * Z[:, h, w, c, np.newaxis, np.newaxis, np.newaxis])
A[:, vert_start:vert_end, horiz_start:horiz_end, :] += Aux
代码 2:
for h in range(n_H):
for w in range(n_W):
# Find indices
vert_start = h * stride # Starting row-index for current slice
vert_end = vert_start + f # Final row-index (+1) for current slice
horiz_start = w * stride # Starting column-index for current slice
horiz_end = horiz_start + f # Final column-index (+1) for current slice
Aux = np.zeros((m, f, f, n_CP))
for c in range(n_C):
Aux += (W[:, :, :, c] * Z[:, h, w, c, np.newaxis, np.newaxis, np.newaxis])
A[:, vert_start:vert_end, horiz_start:horiz_end, :] += Aux
在这两种情况下
- n_H、n_W、n_C、n_HP、n_WP、n_CP、m、stride 和 f 是标量
- W 是一个形状数组 (f, f, n_CP, n_C)
- Z 是一个形状为 (m, n_H, n_W, n_C) 的数组
- A 是一个形状为 (m, n_HP, n_WP, n_CP) 的数组
我注意到,当“索引范围”(vert_start:vert_end 和 horiz_start:horiz_end)是标量时,这两种方法会产生相同的结果,即 f=1。但是,我无法弄清楚为什么它也不适用于范围。
您可以在下面找到一个示例,该示例的代码示例会产生不同的输出:
np.random.seed(1)
m = 2
f = 2
stride = 1
n_C = 3
n_CP = 1
n_H = 2
n_W = 2
n_HP = 3
n_WP = 3
W = np.random.randn(f, f, n_CP, n_C)
Z = np.random.rand(m, n_H, n_W, n_C)
A = np.zeros((m, n_HP, n_WP, n_CP))
A2 = np.zeros((m, n_HP, n_WP, n_CP))
for h in range(n_H):
for w in range(n_W):
# Find indices
vert_start = h * stride # Starting row-index for current slice
vert_end = vert_start + f # Final row-index (+1) for current slice
horiz_start = w * stride # Starting column-index for current slice
horiz_end = horiz_start + f # Final column-index (+1) for current slice
for c in range(n_C):
Aux = (W[:, :, :, c] * Z[:, h, w, c, np.newaxis, np.newaxis, np.newaxis])
A[:, vert_start:vert_end, horiz_start:horiz_end, :] += Aux
Aux = np.zeros((m, f, f, n_CP))
for c in range(n_C):
Aux += (W[:, :, :, c] * Z[:, h, w, c, np.newaxis, np.newaxis, np.newaxis])
A2[:, vert_start:vert_end, horiz_start:horiz_end, :] += Aux
print(A == A2)
【问题讨论】:
-
我建议 unsar 函数 print() 以了解代码中发生的情况或逐步执行它。能举例说明错误发生时inputs()的值吗?
-
我已经编辑了我的问题并添加了一个参数值示例,这两种方法提供了不同的结果。
标签: python indexing nested-loops cumulative-sum