【发布时间】:2019-06-17 16:02:23
【问题描述】:
遵循tf.case 文档中的这个示例:
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
我也想这样做,但允许使用 feed_dict 作为输入,如下图所示:
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# result should be [17, -1, -1, 23]
所以,基本上我想输入三个长度相等的int-array 并接收一个包含 17、23 或 -1 的 int-values 数组。不幸的是,上面的代码给出了错误:
ValueError: Shape 必须为 0 级,但对于输入形状为 [?]、[?] 的“case/cond/Switch”(操作:“Switch”)为 1 级。
我明白,tf.case 需要布尔标量张量输入值,但有什么方法可以实现我想要的吗?我也试过tf.cond,但没有成功。
【问题讨论】:
标签: python tensorflow