Changing the way I created my TF records worked for me. Take a look at the code below:
import tensorflow as tf
from object_detection.utils import dataset_util

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            # Image-level metadata and the encoded image bytes
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename_str.encode('utf-8')),
            'image/source_id': dataset_util.bytes_feature(filename_str.encode('utf-8')),
            'image/format': dataset_util.bytes_feature(image_format),
            'image/encoded': dataset_util.bytes_feature(image_data),
            # One entry per object: box coordinates
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
            # One entry per object: class name and class id
            'image/object/class/text': dataset_util.bytes_list_feature(labels_text),
            'image/object/class/label': dataset_util.int64_list_feature(labels),
        }
    )
)
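Make sure your TF records use the same keys as shown above. The model you are using looks features up by these key names, so a record written with different keys will not be read correctly.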
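One way to catch a wrong key before training is to compare the keys you write against the set from the working record above. This is only a minimal sketch: `EXPECTED_KEYS` and `find_missing_keys` are names I made up for illustration, not part of the Object Detection API, and the key set is simply copied from the working example, so adjust it if your model reads something different.

```python
# Hypothetical helper, not part of the Object Detection API.
# EXPECTED_KEYS is copied from the working record; it is an assumption
# that your model needs exactly this set.
EXPECTED_KEYS = frozenset({
    "image/height",
    "image/width",
    "image/filename",
    "image/source_id",
    "image/format",
    "image/encoded",
    "image/object/bbox/xmin",
    "image/object/bbox/xmax",
    "image/object/bbox/ymin",
    "image/object/bbox/ymax",
    "image/object/class/text",
    "image/object/class/label",
})


def find_missing_keys(feature_keys):
    """Return the expected keys that are absent from feature_keys, sorted."""
    return sorted(EXPECTED_KEYS - set(feature_keys))


# Simulate a record that nests the class fields under bbox/ by mistake.
bad_keys = (
    EXPECTED_KEYS - {"image/object/class/label", "image/object/class/text"}
) | {"image/object/bbox/class/label", "image/object/bbox/class/text"}

print(find_missing_keys(bad_keys))
# prints ['image/object/class/label', 'image/object/class/text']
```

With a real record, the keys can be read from `example.features.feature.keys()` and passed to the helper.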
Earlier, I was using the following approach, which did not work:
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'image/height': dataset_util.int64_feature(shape[0]),
            'image/width': dataset_util.int64_feature(shape[1]),
            'image/channels': dataset_util.int64_feature(shape[2]),
            'image/shape': dataset_util.int64_list_feature(shape),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
            # Wrong key: should be 'image/object/class/label'
            'image/object/bbox/class/label': dataset_util.int64_list_feature(labels),
            'image/object/bbox/class/text': dataset_util.bytes_list_feature(labels_text),
            'image/object/bbox/difficult': dataset_util.int64_list_feature(difficult),
            'image/object/bbox/truncated': dataset_util.int64_list_feature(truncated),
            'image/format': dataset_util.bytes_feature(image_format),
            'image/encoded': dataset_util.bytes_feature(image_data),
            'image/filename': dataset_util.bytes_feature(filename_str.encode('utf-8')),
            'image/source_id': dataset_util.bytes_feature(filename_str.encode('utf-8'))
        }
    )
)
As you can see, I had written image/object/bbox/class/label instead of image/object/class/label. The class text key had the same extra bbox/ segment. I hope this helps.
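A near-miss key like this is easy to overlook by eye. As a rough sketch, Python's standard `difflib` module can flag keys that are close to, but not exactly, an expected name. `suggest_key_fixes` and the 0.85 similarity cutoff are my own illustrative choices, and the expected list is again just the keys from the working record.

```python
import difflib

# Assumed list of expected keys, copied from the working record.
EXPECTED_KEYS = [
    "image/height",
    "image/width",
    "image/filename",
    "image/source_id",
    "image/format",
    "image/encoded",
    "image/object/bbox/xmin",
    "image/object/bbox/xmax",
    "image/object/bbox/ymin",
    "image/object/bbox/ymax",
    "image/object/class/text",
    "image/object/class/label",
]


def suggest_key_fixes(feature_keys, cutoff=0.85):
    """Map each unrecognised key to the closest expected key, if one is similar enough."""
    suggestions = {}
    for key in feature_keys:
        if key in EXPECTED_KEYS:
            continue  # exact match, nothing to fix
        close = difflib.get_close_matches(key, EXPECTED_KEYS, n=1, cutoff=cutoff)
        if close:
            suggestions[key] = close[0]
    return suggestions


written_keys = [
    "image/object/bbox/xmin",
    "image/object/bbox/class/label",
    "image/object/bbox/class/text",
]
for wrong, right in suggest_key_fixes(written_keys).items():
    print(f"{wrong} -> did you mean {right}?")
```

Keys with no close match (such as image/channels in my earlier attempt) are not reported by this helper.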
For reference, see: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md