【问题标题】:iterating over `tf.Tensor` is not allowed不允许迭代 `tf.Tensor`
【发布时间】:2021-09-30 13:47:09
【问题描述】:

我正在尝试将此函数用于 @tf.function 装饰:

h 和 h2 是一个形状为 [3,3] 的张量

def fn(h,i):
    print(h[i])
    return h[i]

tensor = [fn(h,i) for i in tf.range(tf.cast(tf.shape(h)[0],tf.int32)) if  tf.reduce_all(tf.equal(h[i],h2[i])) ]
tf.print(tensor)

但我得到了这个错误:

main_coat_rds.py:139 train_step  *
        pseudo_label_1,images_discard_rede1=predict_aug_images(rede_2,rede_1,img_rede1_aug_1,img_rede1_aug_2,img_rede1_aug_3,img_rede1_aug_4,img_rede1_aug_5,img_rede1_aug_6,img_rede1_aug_7,img_rede1_aug_8,images_discard_rede1,Correct_labels)
    /vitor/codigo_noise_label/codigo_rds/utils_loss_function.py:289 predict_aug_images  *
        pred_match = [get_value_labels(all_predics,i) for i in tf.range(tf.cast(tf.shape(all_predics)[0],tf.int32)) if  tf.reduce_all(tf.equal(all_predics[i],all_predics_aj[i])) ]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:503 __iter__
        self._disallow_iteration()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:496 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

还有什么方法可以做到?

【问题讨论】:

    标签: python tensorflow tensor tensorflow2


    【解决方案1】:

    当您想根据条件选择张量的某些部分时,一个不错的选择是使用tf.gathertf.where 的组合。

    在这里,例如,要选择hh2 之间相等的行,您可以使用:

    tf.gather_nd(h, tf.where(tf.reduce_all(tf.equal(h, h2),axis=1)))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-07-03
      • 1970-01-01
      • 2018-08-17
      • 1970-01-01
      • 1970-01-01
      • 2017-08-30
      • 2021-05-26
      • 2018-11-16
      相关资源
      最近更新 更多