【发布时间】:2018-03-13 02:05:24
【问题描述】:
我正在尝试有效地复制 numpy 的 ndarray.choose() 方法。
这是我正在寻找的一个 numpy 示例:
b = np.arange(15).reshape(3, 5)
c = np.array([1,0,4])
c.choose(b.T) # trying to replicate in tensorflow
-> array([ 1, 5, 14])
我能做的最好的事情是生成一个 batch_size 方阵(如果批量很大,则它很大)并取它的对角线:
tf_b = tf.constant(b)
tf_c = tf.constant(c)
sess.run(tf.diag_part(tf.gather(tf.transpose(tf_b), tf_c)))
-> array([ 1, 5, 14])
有没有办法做到这一点,在第一维中只是线性的(而不是平方)?
【问题讨论】:
-
您的
numpy代码等同于b[np.arange(3),c]。choose有一条说明不鼓励使用单个数组(如您的b.T)为choices。 -
在
numpy,这个索引的一维版本是b.flat[np.arange(b.shape[0])*b.shape[1]+c]
标签: python numpy tensorflow