【发布时间】:2019-06-28 11:16:41
【问题描述】:
我在 numpy 中有这个例子:
import numpy as np
import tensorflow as tf
a = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11 , 12],
[13, 14, 15]])
res = np.zeros((5, 2), dtype=object)
for idx in range(0, len(a)-2, 2):
a0 = a[idx]
a1 = a[idx + 1]
a2 = a[idx + 2]
c = a0 + a1 + a2
res[idx:idx + 2] = ([idx, c])
res
array([[0, array([12, 15, 18])],
[0, array([12, 15, 18])],
[2, array([30, 33, 36])],
[2, array([30, 33, 36])],
[0, 0]], dtype=object)
我想在 tensorflow 中做:
a_tf = tf.convert_to_tensor(a)
res_tf = tf.zeros((5, 2), dtype=object)
for idx in range(0, a.shape[0]-2, 2):
a0 = tf.gather_nd(a, [idx])
a1 = tf.gather_nd(a, [idx + 1])
a2 = tf.gather_nd(a, [idx + 2])
c = a0 + a1 + a2
res = tf.gather_nd([idx, c], [idx:idx +2])
直到与c 的计算一致。
最后一行 (res) 它给了我:
res = tf.gather_nd([idx, c], [idx:idx +2])
^
SyntaxError: invalid syntax
我不确定如何收到结果。
更新
基本上,问题在于[idx, c] 是列表类型并试图做:tf.convert_to_tensor([idx, c],给出:
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [] != values[1].shape = [3] [Op:Pack] name: packed/
【问题讨论】:
标签: python-3.x tensorflow tensorflow2.0