【发布时间】:2019-08-18 23:27:55
【问题描述】:
我正在尝试沿第一个维度切割形状为 (?, 32, 32) 的张量。我必须选择两行索引存储在另一个形状为(1, 2) 的张量中。我想要像 array[list of indexes, :, :] 这样的东西。
我该怎么做?我需要这个操作来计算 model_fn 函数中的损失,传递给我的自定义 Tensorflow Estimator。
【问题讨论】:
标签: tensorflow
我正在尝试沿第一个维度切割形状为 (?, 32, 32) 的张量。我必须选择两行索引存储在另一个形状为(1, 2) 的张量中。我想要像 array[list of indexes, :, :] 这样的东西。
我该怎么做?我需要这个操作来计算 model_fn 函数中的损失,传递给我的自定义 Tensorflow Estimator。
【问题讨论】:
标签: tensorflow
我使用tf.gather_nd 解决了这个问题。我重塑了包含索引的张量:
ids = tf.reshape(tensor_with_indexes, shape=(-1, 1))
然后我申请了:
new_tensor = tf.gather_nd(original_tensor, ids)
【讨论】: