【问题标题】:Tensorflow slicing based on variable基于变量的Tensorflow切片
【发布时间】:2016-03-04 07:47:02
【问题描述】:

我发现索引在 tensorflow (#206) 中仍然是一个悬而未决的问题,所以我想知道目前我可以使用什么作为解决方法。我想根据每个训练示例更改的变量对矩阵的行/列进行索引/切片。

到目前为止我已经尝试过:

  1. 基于占位符的切片(不起作用)

以下(工作)代码片段基于固定数字。

import tensorflow as tf
import numpy as np

x = tf.placeholder("float")
y = tf.slice(x,[0],[1])

#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5]})
print(result)

但是,我似乎不能简单地将这些固定数字之一替换为 tf.placeholder。以下代码给了我错误“TypeError: List of Tensors when single Tensor expected.”

import tensorflow as tf
import numpy as np

x = tf.placeholder("float")
i = tf.placeholder("int32")
y = tf.slice(x,[i],[1])

#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5],i:0})
print(result)

这听起来像 [i] 周围的括号太多了,但删除它们也无济于事。如何使用占位符/变量作为索引?

  1. 基于 python 变量的切片(不能正确反向传播/更新)

我也尝试过使用普通的 python 变量作为索引。这不会导致错误,但网络在训练时不会学到任何东西。我想是因为更改的变量没有正确注册,所以图表格式错误并且更新不起作用?

  1. 通过 one-hot 向量 + 乘法进行切片(有效,但速度慢)

我发现的一种解决方法是使用单热向量。在 numpy 中创建一个单热向量,使用占位符传递它,然后通过矩阵乘法进行切片。这可行,但速度很慢。

任何想法如何根据变量有效地切片/索引?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    基于占位符的切片应该可以正常工作。由于形状和类型的一些微妙问题,您似乎遇到了类型错误。您有以下情况:

    x = tf.placeholder("float")
    i = tf.placeholder("int32")
    y = tf.slice(x,[i],[1])
    

    ...你应该有:

    x = tf.placeholder("float")
    i = tf.placeholder("int32")
    y = tf.slice(x,i,[1])
    

    ...然后您应该在对sess.run() 的调用中将i 作为[0] 提供。

    为了更清楚一点,我建议重写代码如下:

    import tensorflow as tf
    import numpy as np
    
    x = tf.placeholder(tf.float32, shape=[None])  # 1-D tensor
    i = tf.placeholder(tf.int32, shape=[1])
    y = tf.slice(x, i, [1])
    
    #initialize
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    
    #run
    result = sess.run(y, feed_dict={x: [1, 2, 3, 4, 5], i: [0]})
    print(result)
    

    tf.placeholder 操作的附加 shape 参数有助于确保您提供的值具有适当的形状,并且如果形状不正确,TensorFlow 也会引发错误。

    【讨论】:

    • 它给出了以下错误:ValueError: Shape () must have rank 1
    • 啊,是的,自从我发布这篇文章以来,TensorFlow 对标量和长度为 1 的向量之间的区别变得更加严格。更新了答案以修复它。
    【解决方案2】:

    如果你有一个额外的维度,这行得通。

    import tensorflow as tf
    import numpy as np
    
    def reorder0(e, i, length):
        '''
        e: a two dimensional tensor
        i: a one dimensional int32 tensor, of shape (e.shape[0])
        returns: a tensor of the same shape as e, where the jth entry is entry i[j] from e
        '''
        return tf.concat(
            [ tf.expand_dims( e[i[j],:], axis=0)  for j in range(length) ],
            axis=0
        )
    
    e = tf.placeholder(tf.float32, shape=(2,3,5), name='e' )  # sentences, words, embedding
    i = tf.placeholder(tf.int32, shape=(2,3), name='i' ) # for each word, index of parent
    p = tf.concat(
        [ tf.expand_dims(reorder0(e[k,:,:], i[k,:], 3), axis=0)  for k in range(2) ],
        axis=0,
        name='p'
    )
    
    #initialize
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    
    #run
    result = sess.run(p, feed_dict={
        e: [ 
            ( (1.0,1.1,1.2,1.3,1.4),(2.0,2.1,2.2,2.3,2.4),(3.0,3.1,3.2,3.3,3.4) ), 
            ( (21.0,21.1,21.2,21.3,21.4),(22.0,22.1,22.2,22.3,22.4),(23.0,23.1,23.2,23.3,23.4) ),  
        ], 
        i: [ (1,1,1), (2,0,2)]
    })
    print(result)
    

    【讨论】:

      【解决方案3】:

      如果在构建模型时不知道大小,请使用 TensorArray。

      e = tf.placeholder(tf.float32, shape=(3,5) )  # words, embedding
      i = tf.placeholder(tf.int32, shape=(3) ) # for each word, index of parent
      #p = reorder0(e, i, 3)
      a = tf.TensorArray(
          tf.float32, 
          size=e.get_shape()[0],
          dynamic_size=True,
          infer_shape= True,
          element_shape=e.get_shape()[1],
          clear_after_read = False
      )
      
      
      #initialize
      init = tf.initialize_all_variables()
      sess = tf.Session()
      sess.run(init)
      
      #run
      result = sess.run(
          a.unstack(e).gather(i), 
          feed_dict={
              e: ( (1.0,1.1,1.2,1.3,1.4),(2.0,2.1,2.2,2.3,2.4),(3.0,3.1,3.2,3.3,3.4) ),
                  #( (21.0,21.1,21.2,21.3,21.4),(22.0,22.1,22.2,22.3,22.4),(23.0,23.1,23.2,23.3,23.4) ),  
              i: (2,0,2)
          }
      )
      print(result)
      

      【讨论】:

        猜你喜欢
        • 2018-04-09
        • 2018-05-07
        • 1970-01-01
        • 2018-08-05
        • 2019-10-21
        • 1970-01-01
        • 2020-03-14
        • 2019-11-05
        • 1970-01-01
        相关资源
        最近更新 更多