【Question title】:Tensorflow: Slice a 3D tensor with list of indices along the second axis
【Posted】:2017-05-12 09:15:48
【Question description】:

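I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices with shape=[batch_size, num_indices]. The indices refer to the second axis: they are the positions of words in each sentence. batch_size and sentence_length are only known at runtime.

How do I extract a tensor of shape [batch_size, len(indices), word_dim]?

I have been reading about tensorflow.gather, but it seems that gather only collects slices along the first axis. Am I correct?

Edit: I managed to get it working with constants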

def tile_repeat(n, repTime):
    '''
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    ''' 
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
    However, since tensorflow does not fully support indexing,
    gather only work for the first axis. We have to reshape the input data, gather then reshape again
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)

The output is:

[[[ 1  2  3]
  [ 3  5  6]]
 [[10 11 12]
  [ 7  8  9]]
 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)

However, when the input is a placeholder, it does not work and returns this error:

idx = tf.tile(idx, [1, int(repTime)])  
TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, TensorFlow 0.12

Thank you in advance.

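【Comments】:

  • I would flatten the first two dimensions of the placeholder and compute the indices in the flattened dimension. After that, tf.gather does the job.
  • Hi @AllenLavoie: thanks. I also got that idea from here [github.com/tensorflow/tensorflow/issues/206]. However, I cannot make it work for a placeholder input. Could you take a quick look at my edited question?
  • You can use tf.shape to get the dimensions as an integer tensor. This works even when static shape information is not available.
  • Thanks. I get another error: idx = tf.tile(idx, [1, int(repTime)]) TypeError: int() argument must be a string or a number, not 'Tensor'. I am also not sure whether tf.tile can handle dimensions that are only known at runtime.
  • Yes, tile works with tensor dimensions. If repTime is already an integer tensor, you probably just need to remove the int cast (it tries to turn the tensor into a regular Python integer, which is not possible). If necessary, you can cast with tf.cast.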
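
The flattening idea from the comments can be sketched in plain Python, independent of TensorFlow. The helper names below (`gather_reference`, `gather_via_flat_indices`) are hypothetical, and nested lists stand in for tensors. The sketch only illustrates the index arithmetic: once the first two axes of a [batch_size, sentence_length, word_dim] tensor are merged, word `i` of batch `b` sits at flat row `b * sentence_length + i`.

```python
def gather_reference(x, idx):
    """Direct definition of the desired result: out[b][j] = x[b][idx[b][j]]."""
    return [[x[b][i] for i in idx[b]] for b in range(len(x))]


def gather_via_flat_indices(x, idx):
    """Same result, obtained by merging the first two axes and gathering rows."""
    batch_size = len(x)
    sentence_length = len(x[0])
    num_indices = len(idx[0])

    # Merge axes 0 and 1: [batch_size * sentence_length, word_dim].
    flat_x = [word for sentence in x for word in sentence]

    # Index i in batch b maps to flat row b * sentence_length + i.
    flat_idx = [b * sentence_length + i
                for b in range(batch_size)
                for i in idx[b]]

    # Gather along the (now first) axis.
    gathered = [flat_x[k] for k in flat_idx]

    # Split back into [batch_size, num_indices, word_dim].
    return [gathered[b * num_indices:(b + 1) * num_indices]
            for b in range(batch_size)]


# Same data as the example in the question.
x = [
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
]
idx = [[0, 1], [1, 0], [1, 1]]

result = gather_via_flat_indices(x, idx)
print(result)
```

In a TensorFlow graph the same quantities would come from tf.shape rather than len, so that they remain tensors when the sizes are only known at runtime.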

Tags: python tensorflow


【Solution 1】:

Thanks to @AllenLavoie's comments, I was finally able to come up with a solution:

def tile_repeat(n, repTime):
    '''
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, repTime])  # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    ''' 
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
    However, since tensorflow does not fully support indexing,
    gather only work for the first axis. We have to reshape the input data, gather then reshape again
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
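    # Note: repeating by tf.shape(x)[1] and reshaping to tf.shape(x) assumes
    # num_indices == sentence_length, which holds for the example below.
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))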
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)

【Comments】:

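    【Solution 2】:

    @Hoa Vu's answer was very helpful. The code works for the example x and idx, where sentence_length == len(indices), but it raises an error when sentence_length != len(indices).

    I changed the code slightly, and it now works when sentence_length >= len(indices).

    I tested it on Python 3.x with a new x and idx.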

    def tile_repeat(n, repTime):
        '''
        create something like 111..122..2333..33 ..... n..nn 
        one particular number appears repTime consecutively.
        This is for flattening the indices.
        '''
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
        idx = tf.tile(idx, [1, repTime])  # Create multiple columns, each column has one number repeats repTime 
        y = tf.reshape(idx, [-1])
        return y
    
    
    def gather_along_second_axis(x, idx):
        ''' 
        x has shape: [batch_size, sentence_length, word_dim]
        idx has shape: [batch_size, num_indices]
        Basically, in each batch, get words from sentence having index specified in idx
        However, since tensorflow does not fully support indexing,
        gather only work for the first axis. We have to reshape the input data, gather then reshape again
        '''
        reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
        idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(idx)[1]) * tf.shape(x)[1] + reshapedIdx
        y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]),  # flatten input
                    idx_flattened)
        y = tf.reshape(y, [tf.shape(x)[0],tf.shape(idx)[1],tf.shape(x)[2]])
        return y
    
    x = tf.constant([
                [[1,2,3],[1,2,3],[3,5,6],[3,5,6]],
                [[7,8,9],[7,8,9],[10,11,12],[10,11,12]],
                [[13,14,15],[13,14,15],[16,17,18],[16,17,18]]
        ])
    idx=tf.constant([[0,1],[1,2],[0,3]])
    
    y = gather_along_second_axis(x, idx)
    with tf.Session(''):
        print(y.eval())
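
To see why the batch offsets must repeat num_indices times rather than sentence_length times, the reshape, gather, reshape steps above can be mirrored in NumPy. This is only a sketch under the assumption that NumPy is available; `gather_along_second_axis_np` is a hypothetical name, and `np.take_along_axis` is used purely as an independent cross-check. With sentence_length = 4 and num_indices = 2, repeating by sentence_length would yield 12 offsets for only 6 flattened indices.

```python
import numpy as np


def gather_along_second_axis_np(x, idx):
    """NumPy mirror of the reshape -> gather -> reshape steps (hypothetical helper)."""
    batch_size, sentence_length, word_dim = x.shape
    num_indices = idx.shape[1]

    # Batch offsets: each batch number repeated num_indices times (0,0,1,1,2,2),
    # scaled by sentence_length so it points at the start of that batch's rows.
    batch_offsets = np.repeat(np.arange(batch_size), num_indices) * sentence_length
    flat_idx = batch_offsets + idx.reshape(-1)

    # Merge the first two axes, pick rows, then restore the batch axis.
    flat_x = x.reshape(-1, word_dim)
    return flat_x[flat_idx].reshape(batch_size, num_indices, word_dim)


# Same data as the example above: sentence_length = 4, num_indices = 2.
x = np.array([
    [[1, 2, 3], [1, 2, 3], [3, 5, 6], [3, 5, 6]],
    [[7, 8, 9], [7, 8, 9], [10, 11, 12], [10, 11, 12]],
    [[13, 14, 15], [13, 14, 15], [16, 17, 18], [16, 17, 18]],
])
idx = np.array([[0, 1], [1, 2], [0, 3]])

y = gather_along_second_axis_np(x, idx)

# Independent cross-check using NumPy's own per-axis gather.
expected = np.take_along_axis(x, idx[:, :, None], axis=1)
print(y.shape)
```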
    

    【Comments】:
