【问题标题】:Customise train_step in model.fit() "OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function"在 model.fit() 中自定义 train_step “OperatorNotAllowedInGraphError:不允许迭代 `tf.Tensor`:AutoGraph 确实转换了此函数”
【发布时间】:2021-05-26 19:42:34
【问题描述】:

我正在尝试编写一个自定义 train_step 以在 tf.keras.Model.fit() 函数中使用。我关注tensor flow tutorial。根据我的理解,在 train_step 函数中,输入参数数据应该是我即将传入 Model.fit() 函数的训练数据集。我的数据集是 TFRecordDataset。我的数据集给出了三个特定的特征,即图像、标签和框。因此,在 train_step 函数中,我首先尝试从传递的数据参数中获取 img、标签和框参数。

def train_step(self, data):
        print("printing data fed to train_step")
        print(data)
        img, label, gt_boxes = data
        if self.DEBUG:
            if(img == None):
                print("img input in train step is none")
        with tf.GradientTape() as tape:
            rpn_classification, rpn_regression = self(img, training=True)
            self.tf_rpn_target_generation_layer(gt_boxes, rpn_regression)
            loss = self.rpn_loss_function(rpn_classification)
        
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        #mae_metric.update_state()
        return [loss_tracker]

以上是我用于自定义 train_step 函数的代码。当我运行 fit 时,出现以下错误 OperatorNotAllowedInGraphError:不允许迭代tf.Tensor:AutoGraph 确实转换了此函数。这可能表明您正在尝试使用不受支持的功能。

我在训练数据集上使用了随机播放、缓存和重复操作。谁能帮我理解为什么会出现这个错误?

根据我之前的经验,我通常为数据集创建一个迭代器,然后通过 get_next 操作来获取特征。

编辑: 我尝试了以下程序,但没有产生任何结果

  1. 由于发送到 train_step 的数据是一个数据集对象,我使用 tf.raw_ops.IteratorGetNext 方法来访问返回错误的迭代器的元素 “TypeError: 'IteratorGetNext' Op 的输入'iterator' 的类型字符串与预期的资源类型不匹配。”

  2. 为了修复这个错误,我假设它很可能是 tensorflow 返回迭代器图,因此无法访问元素,所以我在 model.compile() 函数中添加了 run_eagerly=True 参数,该函数返回了正在打印的乱码,并且同样的错误。

Epoch 1/5
printing data fed to train_step
Tensor("Shape:0", shape=(0,), dtype=int32)
Tensor("IteratorGetNext:0", shape=(), dtype=string)

【问题讨论】:

  • 我认为索引也可以工作
  • 我已经尝试过索引,但它不起作用。据我了解,错误可能与使用 model.fit 时作为数据返回的迭代器对象有关。使用 tf.raw_ops 进一步调试时的错误,我了解到提供的迭代器对象是字符串,因此 tf.raw_ops.IteratorGetNext 不起作用,因为它不是预期的输入资源类型。

标签: conv-neural-network tensorflow2.0 tensorflow-datasets tfrecord faster-rcnn


【解决方案1】:

我找到了解决方案。传递给我的步进函数的数据是一个迭代器,因此我必须使用 tf.raw_ops.IteratorGetNext 方法来访问迭代器的内容。

执行此操作时,我最初收到另一个错误,指出迭代器类型与预期的资源类型不匹配,经过仔细调试后,我了解到我必须对数据集执行的 read_tfrecords 映射不成功,导致数据集仍然包含格式为 tf.string 的未映射 tfrecord,这不是 train_Step 的预期资源类型。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2020-07-31
    • 1970-01-01
    • 1970-01-01
    • 2021-09-30
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-04-02
    相关资源
    最近更新 更多