【发布时间】:2018-06-21 07:24:10
【问题描述】:
我正在尝试使用 tf.nn.max_pool_with_argmax() 的 argmax 结果来索引另一个张量。为简单起见,假设我正在尝试实现以下内容:
output, argmax = tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
tf.assert_equal(input[argmax],output)
现在我的问题是如何实现必要的索引操作input[argmax] 以达到预期的结果?我猜这涉及tf.gather_nd() 和相关调用的一些用法,但我无法弄清楚。如有必要,我们可以假设输入具有[BatchSize, Height, Width, Channel] 维度。
感谢您的帮助!
垫子
【问题讨论】:
-
您是否检查过 input[argmax] 本身不起作用? Tensorflow 支持相对先进的跨步切片(如 numpy),因此这可能会起作用(尽管您可能需要将其应用于填充输入)。
-
是的,我查过了。那行不通……
标签: python tensorflow indexing convolution max-pooling