【问题标题】:Usage of argmax from tf.nn.max_pool_with_argmax tensorflowtf.nn.max_pool_with_argmax 张量流中 argmax 的使用
【发布时间】:2018-06-21 07:24:10
【问题描述】:

我正在尝试使用 tf.nn.max_pool_with_argmax() 的 argmax 结果来索引另一个张量。为简单起见,假设我正在尝试实现以下内容:

output, argmax = tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
tf.assert_equal(input[argmax],output)

现在我的问题是如何实现必要的索引操作input[argmax] 以达到预期的结果?我猜这涉及tf.gather_nd() 和相关调用的一些用法,但我无法弄清楚。如有必要,我们可以假设输入具有[BatchSize, Height, Width, Channel] 维度。

感谢您的帮助!

垫子

【问题讨论】:

  • 您是否检查过 input[argmax] 本身不起作用? Tensorflow 支持相对先进的跨步切片(如 numpy),因此这可能会起作用(尽管您可能需要将其应用于填充输入)。
  • 是的,我查过了。那行不通……

标签: python tensorflow indexing convolution max-pooling


【解决方案1】:

我找到了一个使用 tf.gather_nd 的解决方案,虽然它看起来不那么优雅,但它确实有效。我使用了unravel_argmax发布的函数here

def unravel_argmax(argmax, shape):
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.stack(output_list)

def max_pool(input, ksize, strides,padding):
    output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
    shape = input.get_shape()
    arg_max = tf.cast(arg_max,tf.int32)
    unraveld = unravel_argmax(arg_max,shape)
    indices = tf.transpose(unraveld,(1,2,3,4,0))
    channels = shape[-1]
    bs = tf.shape(iv.m)[0]
    t1 = tf.range(channels,dtype=arg_max.dtype)[None, None, None, :, None]
    t2 = tf.tile(t1,multiples=(bs,) + tuple(indices.get_shape()[1:-2]) + (1,1))
    t3 = tf.concat((indices,t2),axis=-1)
    t4 = tf.range(tf.cast(bs, dtype=arg_max.dtype))
    t5 = tf.tile(t4[:,None,None,None,None],(1,) + tuple(indices.get_shape()[1:-2].as_list()) + (channels,1))
    t6 = tf.concat((t5, t3), -1)    
    return tf.gather_nd(input,t6) 

如果有人有更优雅的解决方案,我仍然很想知道。

垫子

【讨论】:

    【解决方案2】:

    这个小 sn-p 有效:

    def get_results(data,other_tensor):
        pooled_data, indices = tf.nn.max_pool_with_argmax(data,ksize=[1,ksize,ksize,1],strides=[1,stride,stride,1],padding='VALID',include_batch_in_index=True)
        b,w,h,c = other_tensor.get_shape.as_list()
        other_tensor_pooled = tf.gather(tf.reshape(other_tensor,shape= [b*w*h*c,]),indices)
        return other_tensor_pooled
    

    上面的indices可以用来索引张量。此函数实际上返回扁平索引,并将其与 batch_size > 1 一起使用,您需要将 include_batch_in_index 传递为 True 以获得正确的结果。我在这里假设othertensor 你的批量大小与data. 相同

    【讨论】:

      【解决方案3】:

      我是这样做的:

      def max_pool(input, ksize, strides,padding):
          output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
          shape=tf.shape(output)
          output1=tf.reshape(tf.gather(tf.reshape(input,[-1]),arg_max),shape)
      
          err=tf.reduce_sum(tf.square(tf.subtract(output,output1)))
          return output1, err
      

      【讨论】:

      • 您必须包含参数include_batch_in_index=True 才能在多个图像上正常工作
      猜你喜欢
      • 2021-12-22
      • 2018-11-07
      • 1970-01-01
      • 2018-06-09
      • 2020-04-24
      • 1970-01-01
      • 1970-01-01
      • 2017-12-01
      • 1970-01-01
      相关资源
      最近更新 更多