【发布时间】:2021-05-26 19:42:34
【问题描述】:
我正在尝试编写一个自定义 train_step 以在 tf.keras.Model.fit() 函数中使用。我关注tensor flow tutorial。根据我的理解,在 train_step 函数中,输入参数数据应该是我即将传入 Model.fit() 函数的训练数据集。我的数据集是 TFRecordDataset。我的数据集给出了三个特定的特征,即图像、标签和框。因此,在 train_step 函数中,我首先尝试从传递的数据参数中获取 img、标签和框参数。
def train_step(self, data):
print("printing data fed to train_step")
print(data)
img, label, gt_boxes = data
if self.DEBUG:
if(img == None):
print("img input in train step is none")
with tf.GradientTape() as tape:
rpn_classification, rpn_regression = self(img, training=True)
self.tf_rpn_target_generation_layer(gt_boxes, rpn_regression)
loss = self.rpn_loss_function(rpn_classification)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
loss_tracker.update_state(loss)
#mae_metric.update_state()
return [loss_tracker]
以上是我用于自定义 train_step 函数的代码。当我运行 fit 时,出现以下错误
OperatorNotAllowedInGraphError:不允许迭代tf.Tensor:AutoGraph 确实转换了此函数。这可能表明您正在尝试使用不受支持的功能。
我在训练数据集上使用了随机播放、缓存和重复操作。谁能帮我理解为什么会出现这个错误?
根据我之前的经验,我通常为数据集创建一个迭代器,然后通过 get_next 操作来获取特征。
编辑: 我尝试了以下程序,但没有产生任何结果
-
由于发送到 train_step 的数据是一个数据集对象,我使用 tf.raw_ops.IteratorGetNext 方法来访问返回错误的迭代器的元素 “TypeError: 'IteratorGetNext' Op 的输入'iterator' 的类型字符串与预期的资源类型不匹配。”
-
为了修复这个错误,我假设它很可能是 tensorflow 返回迭代器图,因此无法访问元素,所以我在 model.compile() 函数中添加了 run_eagerly=True 参数,该函数返回了正在打印的乱码,并且同样的错误。
Epoch 1/5
printing data fed to train_step
Tensor("Shape:0", shape=(0,), dtype=int32)
Tensor("IteratorGetNext:0", shape=(), dtype=string)
【问题讨论】:
-
我认为索引也可以工作
-
我已经尝试过索引,但它不起作用。据我了解,错误可能与使用 model.fit 时作为数据返回的迭代器对象有关。使用 tf.raw_ops 进一步调试时的错误,我了解到提供的迭代器对象是字符串,因此 tf.raw_ops.IteratorGetNext 不起作用,因为它不是预期的输入资源类型。
标签: conv-neural-network tensorflow2.0 tensorflow-datasets tfrecord faster-rcnn