【问题标题】:Iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function. Function to dataset在 Graph 执行中不允许迭代 `tf.Tensor`。使用 Eager 执行或使用 @tf.function 修饰此函数。数据集的功能
【发布时间】:2022-01-13 14:44:18
【问题描述】:

我正在处理以下数据集:

import tensorflow_datasets as tfds
x, y = tfds.load('uc_merced', split=['train[:70%]', 'train[70%:]'], as_supervised=True)

数据集由图像组成,我正在尝试填充其中一些,因为它们的形状并不完全相同。

def PaddingCorrection(a, b):
  if any(tf.shape(a) != tf.constant([256,256,3])):
      a = tf.pad(a,
                   [[256 - tf.shape(a)[0], 0] ,
                    [256 - tf.shape(a)[1], 0], 
                    [0,0]],
                    "CONSTANT")

  return a, b

x = x.map(PaddingCorrection)

但是当我使用地图功能时,我得到了这个错误:

OperatorNotAllowedInGraphError: iterating over tf.Tensor is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我该如何解决?

【问题讨论】:

  • 回答有用吗?

标签: python tensorflow tensorflow-datasets


【解决方案1】:

尝试仅使用Tensorflow 操作,如tf.cond

import tensorflow_datasets as tfds
import tensorflow as tf

(x, y) = tfds.load('uc_merced', split=['train[:70%]', 'train[70%:]'], as_supervised=True)

def padding_corretion(a, b):
  a = tf.cond(tf.math.reduce_all(tf.not_equal(tf.shape(a), tf.constant([256, 256, 3]))), 
                              lambda: a,
                              lambda: tf.pad(a,
                                          [[256 - tf.shape(a)[0], 0] ,
                                            [256 - tf.shape(a)[1], 0], 
                                            [0,0]], "CONSTANT"))
  return a, b

x = x.map(padding_corretion)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2020-10-11
    • 2020-05-04
    • 2021-09-30
    • 1970-01-01
    • 2021-07-03
    • 2018-06-20
    • 2022-01-23
    • 2019-11-21
    相关资源
    最近更新 更多