【问题标题】:How to replicate numpy.choose() in tensorflow?如何在张量流中复制 numpy.choose()?
【发布时间】:2018-03-13 02:05:24
【问题描述】:

我正在尝试有效地复制 numpy 的 ndarray.choose() 方法。

这是我正在寻找的一个 numpy 示例:

b = np.arange(15).reshape(3, 5)
c = np.array([1,0,4])
c.choose(b.T)  # trying to replicate in tensorflow
  -> array([ 1,  5, 14]) 

我能做的最好的事情是生成一个 batch_size 方阵(如果批量很大,则它很大)并取它的对角线:

tf_b = tf.constant(b)
tf_c = tf.constant(c)
sess.run(tf.diag_part(tf.gather(tf.transpose(tf_b), tf_c)))
-> array([ 1,  5, 14])

有没有办法做到这一点,在第一维中只是线性的(而不是平方)?

【问题讨论】:

  • 您的numpy 代码等同于b[np.arange(3),c]choose 有一条说明不鼓励使用单个数组(如您的 b.T)为 choices
  • numpy,这个索引的一维版本是b.flat[np.arange(b.shape[0])*b.shape[1]+c]

标签: python numpy tensorflow


【解决方案1】:

是的,有一种更简单的方法可以做到这一点。将您的 b 数组展平为一维数组,因此它是 [0, 1, 2, ..., 13, 14]。采用一系列索引,这些索引在您正在采用的“选择”数量范围内(在您的情况下为 3)。那将是[0, 1, 2]。将此范围乘以原始形状的第二个维度,即每个选项的选项数(在您的情况下为 5 个)。这给了你[0, 5, 10]。然后将您的索引添加到此以获得[1, 5, 14]。现在你可以调用 tf.gather()。

这是我从here 获取的一些代码,它们对 RNN 输出执行类似的操作。你的会略有不同,但想法是一样的。

index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, out_size])
relevant = tf.gather(flat, index)
return relevant

总体而言,操作非常简单。您使用范围操作来获取每行开头的索引,然后添加您在每行中的位置的索引。我认为在 1D 中进行操作是最简单的,这就是我们将其展平的原因。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-04-23
    • 1970-01-01
    • 1970-01-01
    • 2017-07-31
    • 2016-08-10
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多