【发布时间】:2017-01-26 03:51:24
【问题描述】:
我刚开始学习cntk。但是,我有一个基本问题阻碍了我的进步。我有以下测试通过:
import numpy as np
from cntk import input_variable, plus
def test_simple(self):
x_input = np.asarray([[1, 2, 2]], dtype=np.int64)
assert (1, 3) == x_input.shape
y_input = np.asarray([[5, 3, 3]], dtype=np.int64)
assert (1, 3) == y_input.shape
x = input_variable(x_input.shape[1])
assert (3, ) == x.shape
y = input_variable(y_input.shape[1])
assert (3, ) == y.shape
x_plus_y = plus(x, y)
assert (3, ) == x_plus_y.shape
res = x_plus_y.eval({x: x_input, y: y_input})
assert 6 == res[0, 0, 0]
assert 5 == res[0, 0, 1]
assert 5 == res[0, 0, 2]
我知道输出的形状是 (1, 1, 3),因为第一和第二轴分别是批处理轴和默认动态轴。
但是,为什么我需要将输入变量的形状设置为 (3,) 而不是 (1, 3)。使用 (1, 3) 失败。
为什么图中输入节点的形状与用作该节点输入的numpy数据不一致?
谢谢你, 稻谷
【问题讨论】:
标签: cntk