【问题标题】:tensor_scatter_nd_update ValueError: Shapes must be equal rank, but are 0 and 1tensor_scatter_nd_update ValueError: Shapes must be equal rank, but are 0 and 1
【发布时间】:2021-01-11 22:56:31
【问题描述】:

我一直能够使用tf.tensor_scatter_nd_update 写入张量而没有任何问题,但我无法弄清楚为什么它不适用于某些特定的张量。

举个简单的例子,假设我想根据布尔掩码mask=[[1 0 1]]input=[[0 0 0]] 中的某些值设置为update=[[1 2 3]]。 我会这样做:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)

期望操作的结果是input=[[1 0 3]]

相反,我得到了

ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].

我真的不知道出了什么问题;即使在更复杂的情况下,我也始终能够毫无问题地使用该功能。

【问题讨论】:

  • 粗略地说:这不是想把 3 个值塞进 2 个点吗?

标签: python tensorflow


【解决方案1】:

我想通了。

问题的一部分确实是tf.where() 返回了一个二维张量,但这是因为我使用它来生成updates 向量:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))

解决方案是通过以下方式删除多余的维度:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))

【讨论】:

  • something_else 是什么意思?如果是update,那就不行了
  • 使用tf.squeeze 不能成为答案 - 它只是将您的玩具示例转换为 1D。仅当第一个维度为 1 时才有效。
猜你喜欢
  • 1970-01-01
  • 2019-07-15
  • 1970-01-01
  • 2020-12-04
  • 2017-12-03
  • 2022-07-16
  • 2021-12-29
  • 1970-01-01
  • 2021-05-04
相关资源
最近更新 更多