【发布时间】:2021-03-03 22:17:10
【问题描述】:
考虑两个张量:Ta = [[1,2,3],[1,2,4]], Tb = [[True, False, True],[False, True, False]] 在张量流中。 Tb 表示允许值的 Ta 的位置。我需要将 Ta 的所有允许值带到左侧,例如 Ta_ordered = [[1,3,2],[2,1,4]]。
【问题讨论】:
标签: tensorflow2.0 tensorflow2.x
考虑两个张量:Ta = [[1,2,3],[1,2,4]], Tb = [[True, False, True],[False, True, False]] 在张量流中。 Tb 表示允许值的 Ta 的位置。我需要将 Ta 的所有允许值带到左侧,例如 Ta_ordered = [[1,3,2],[2,1,4]]。
【问题讨论】:
标签: tensorflow2.0 tensorflow2.x
试试这个方法。它基于以必要的顺序对元素进行排序,然后使用tf.gather():
import tensorflow as tf
Ta = [[1,2,3],[1,2,4]]
Tb = [[True, False, True],[False, True, False]]
Ta, Tb = (tf.convert_to_tensor(t) for t in (Ta, Tb))
X, Y = Ta.shape
inds = tf.range(X * Y)
inds = tf.reshape(inds, (X, Y))
adj = tf.cast(Tb, tf.int32) * (X + 1)
inds -= adj # guarantees minimums for marked elements
inds = tf.argsort(inds)
output = tf.gather(Ta, inds, batch_dims=1)
【讨论】: