【Question title】: Imputing axes by index in Tensorflow
【Posted】: 2020-09-01 17:49:30
【Question】:

I have a 3D input tensor of shape [batch, n_classes - k, 5] and a 2D index tensor of shape [batch, n_classes - k], where k can be anywhere in [0, n_classes). Taking n_classes=3, k=1 as an example:

import tensorflow as tf

X = tf.constant([
    [[0.36636186, 0.45606998, 0.785176  , 0.19967379, 0. ],
     [0.2799339 , 0.9548653 , 0.7378969 , 0.5543541 , 1. ]],

    [[0.07455064, 0.9868869 , 0.77224475, 0.19871569, 0. ],
     [0.19579114, 0.0693613 , 0.100778  , 0.01822183, 1. ]],

    [[0.684233  , 0.4401525 , 0.12203824, 0.4951769 , 0. ],
     [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0. ]]
])

idcs = tf.constant([
    [0, 2],
    [0, 1],
    [1, 2]
])

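The elements of idcs are class values (indices). I am trying to impute X along axis 1 by filling every index that is missing from the range [0, n_classes) with a zero vector, i.e.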

tf.constant([
    [[0.36636186, 0.45606998, 0.785176  , 0.19967379, 0. ],
     [0.        , 0.        , 0.        , 0.        , 0. ],  # missing 1 in `idcs`
     [0.2799339 , 0.9548653 , 0.7378969 , 0.5543541 , 1. ]],

    [[0.07455064, 0.9868869 , 0.77224475, 0.19871569, 0. ],
     [0.19579114, 0.0693613 , 0.100778  , 0.01822183, 1. ],
     [0.        , 0.        , 0.        , 0.        , 0. ]], # missing 2 in `idcs`

    [[0.        , 0.        , 0.        , 0.        , 0. ],  # missing 0 in `idcs`
     [0.684233  , 0.4401525 , 0.12203824, 0.4951769 , 0. ],
     [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0. ]]
])

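I'm not sure how to express this in TensorFlow. I considered creating a zero tensor of shape [batch, n_classes, 5] and assigning the rows of X to the indices that are present along axis 1, but tensors do not support item assignment. Is there a simple way to do this in TensorFlow?

For example, if I were to express this in NumPy, I might try something like: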

import numpy as np

X = np.array([
    [[0.36636186, 0.45606998, 0.785176  , 0.19967379, 0. ],
     [0.2799339 , 0.9548653 , 0.7378969 , 0.5543541 , 1. ]],

    [[0.07455064, 0.9868869 , 0.77224475, 0.19871569, 0. ],
     [0.19579114, 0.0693613 , 0.100778  , 0.01822183, 1. ]],

    [[0.684233  , 0.4401525 , 0.12203824, 0.4951769 , 0. ],
     [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0. ]]
])

idcs = np.array([
    [0, 2],
    [0, 1],
    [1, 2]
])

n_classes = 3
batch_size = 3

# selectors
x = np.repeat(np.arange(idcs.shape[0]), idcs.shape[1])  # [0, 0, 1, 1, 2, 2]
y = idcs.ravel()  # [0, 2, 0, 1, 1, 2]

z = np.zeros((batch_size, n_classes, 5))
z[x, y] = np.reshape(X, [x.shape[0], 5])
z

# array([[[0.36636186, 0.45606998, 0.785176  , 0.19967379, 0.        ],
#         [0.        , 0.        , 0.        , 0.        , 0.        ],
#         [0.2799339 , 0.9548653 , 0.7378969 , 0.5543541 , 1.        ]],
# 
#        [[0.07455064, 0.9868869 , 0.77224475, 0.19871569, 0.        ],
#         [0.19579114, 0.0693613 , 0.100778  , 0.01822183, 1.        ],
#         [0.        , 0.        , 0.        , 0.        , 0.        ]],
# 
#        [[0.        , 0.        , 0.        , 0.        , 0.        ],
#         [0.684233  , 0.4401525 , 0.12203824, 0.4951769 , 0.        ],
#         [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0.        ]]])

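【Comments】:

  • I think this is the function you are looking for: tensorflow.org/api_docs/python/tf/scatter_nd, although I find its usage unintuitive.
  • Thanks @jakub. I had some trouble getting the indices to line up with scatter_nd the way I wanted, but I figured out how to do it with a sparse tensor.

Tags: python tensorflow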
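
For readers who get stuck on the same index alignment, here is a minimal sketch of how the tf.scatter_nd call suggested above might be set up. It is not the approach the asker ended up using. The names batch_idx, scatter_idx, n_features and Z are illustrative only, and it assumes X and idcs are shaped as in the question.

```python
import tensorflow as tf

X = tf.constant([
    [[0.36636186, 0.45606998, 0.785176, 0.19967379, 0.],
     [0.2799339, 0.9548653, 0.7378969, 0.5543541, 1.]],
    [[0.07455064, 0.9868869, 0.77224475, 0.19871569, 0.],
     [0.19579114, 0.0693613, 0.100778, 0.01822183, 1.]],
    [[0.684233, 0.4401525, 0.12203824, 0.4951769, 0.],
     [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0.]],
])
idcs = tf.constant([[0, 2], [0, 1], [1, 2]])
n_classes = 3

batch_size, n_inputs, n_features = X.shape.as_list()

# Batch index of every row of X, shaped like idcs: [[0, 0], [1, 1], [2, 2]].
batch_idx = tf.tile(tf.range(batch_size)[:, tf.newaxis], [1, n_inputs])

# Pair each batch index with its class index -> shape [batch, n_inputs, 2].
scatter_idx = tf.stack([batch_idx, idcs], axis=-1)

# Each (batch, class) pair addresses one length-5 row of the output;
# rows that are never addressed stay zero.
Z = tf.scatter_nd(scatter_idx, X, shape=[batch_size, n_classes, n_features])
```

Because the last dimension of scatter_idx has size 2, each index selects a whole feature row, so X can be passed as the updates without any reshaping.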


【Solution 1】:

I solved this by converting the tensor to a sparse tensor and then immediately back to a dense one:

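import tensorflow as tf

batch_size, n_inputs, n_features = X.shape.as_list()
n_classes = 3

# SparseTensor indices must be int64, whereas tf.constant defaults to int32
idcs64 = tf.cast(idcs, tf.int64)

# one (batch, class, feature) coordinate per scalar of X, in row-major order
sparse_indices = tf.concat([
        tf.reshape(tf.repeat(tf.range(batch_size, dtype=tf.int64), n_inputs * n_features), [-1, 1]),
        tf.reshape(tf.repeat(idcs64, n_features), [-1, 1]),
        tf.reshape(tf.tile(tf.range(n_features, dtype=tf.int64), [n_inputs * batch_size]), [-1, 1]),
    ],
    axis=1
)

# ravel X to 1d, create a sparse tensor for non-zero indices and then
# expand back to dense as a hack for filling in the zeros
# (to_dense validates index order by default, so this assumes each row of
# idcs is sorted ascending as in the example; otherwise apply
# tf.sparse.reorder to the sparse tensor first)
X_ravel = tf.reshape(X, shape=[-1])
tf.sparse.to_dense(
    tf.sparse.SparseTensor(
        sparse_indices,
        X_ravel,
        dense_shape=[batch_size, n_classes, n_features],
    ),
)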

As expected, the result is:

<tf.Tensor: shape=(3, 3, 5), dtype=float32, numpy=
array([[[0.36636186, 0.45606998, 0.785176  , 0.19967379, 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.2799339 , 0.9548653 , 0.7378969 , 0.5543541 , 1.        ]],

       [[0.07455064, 0.9868869 , 0.77224475, 0.19871569, 0.        ],
        [0.19579114, 0.0693613 , 0.100778  , 0.01822183, 1.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.684233  , 0.4401525 , 0.12203824, 0.4951769 , 0.        ],
        [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0.        ]]],
      dtype=float32)>
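
If building the sparse indices feels heavy, the same zero-filling can probably also be written with a one-hot mask. This is only a sketch of an alternative, not part of the answer above. The names one_hot and Z are illustrative, and it assumes every value in idcs lies in [0, n_classes) and that no class index repeats within a row.

```python
import tensorflow as tf

X = tf.constant([
    [[0.36636186, 0.45606998, 0.785176, 0.19967379, 0.],
     [0.2799339, 0.9548653, 0.7378969, 0.5543541, 1.]],
    [[0.07455064, 0.9868869, 0.77224475, 0.19871569, 0.],
     [0.19579114, 0.0693613, 0.100778, 0.01822183, 1.]],
    [[0.684233, 0.4401525, 0.12203824, 0.4951769, 0.],
     [0.47417384, 0.09783416, 0.49161586, 0.47347176, 0.]],
])
idcs = tf.constant([[0, 2], [0, 1], [1, 2]])
n_classes = 3

# one_hot[b, i, k] is 1.0 when input row i of batch b belongs to class k.
one_hot = tf.one_hot(idcs, depth=n_classes, dtype=X.dtype)  # [batch, n_inputs, n_classes]

# Broadcast to [batch, n_inputs, n_classes, 5] and sum over the input rows:
# each class slot receives its own row, or stays zero if the class is absent.
Z = tf.reduce_sum(one_hot[:, :, :, tf.newaxis] * X[:, :, tf.newaxis, :], axis=1)
```

This trades extra multiplications and a temporary [batch, n_inputs, n_classes, 5] tensor for not having to construct explicit indices, so it is mainly attractive when n_classes is small.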
