【发布时间】:2021-02-06 10:53:45
【问题描述】:
我正在使用 EfficentDet 和 Tensorflow 对象检测 API,但遇到了一些问题 更改配置文件。
这很好用:
config_dic = config_util.get_configs_from_pipeline_file(fpath)
config_dic["model"].ssd.num_classes = len(LabelMap)
config_dic["model"].ssd.image_resizer.keep_aspect_ratio_resizer.min_dimension = 512
config_dic["train_config"].batch_size = 1
config_dic["train_config"].fine_tune_checkpoint = os.path.join(path, model_name, "checkpoint/ckpt-0")
config_dic["train_config"].fine_tune_checkpoint_type = "detection"
config_dic["train_config"].use_bfloat16 = False # Set to True if training on a TPU
config_dic["train_config"].num_steps = 10000
config_dic["train_input_config"].label_map_path = path_label
config_dic["train_input_config"].tf_record_input_reader.input_path[:] = train_data
config_dic["eval_input_configs"][0].label_map_path = path_label
config_dic["eval_input_configs"][0].tf_record_input_reader.input_path[:] = valid_data
config_dic["model"].ssd.image_resizer.keep_aspect_ratio_resizer.pad_to_max_dimension = False
config_dic["model"].ssd.image_resizer.keep_aspect_ratio_resizer.max_dimension = 1024
但是运行它会给我一个错误:
config_dic["train_config"].data_augmentation_options.random_horizontal_flip = False
config_dic["train_config"].data_augmentation_options.random_adjust_brightness = 0.4
config_dic["train_config"].data_augmentation_options.random_adjust_contrast = [0.6, 1.5]
config_dic["train_config"].data_augmentation_options.random_jitter_boxes = 0.1
config_dic["train_config"].data_augmentation_options.random_rotation90 = 0.5
'google.protobuf.pyext._message.RepeatedCompositeContainer' object has no attribute 'random_horizontal_flip'
以及其他类似的错误。
(我尝试使用整数而不是 False 并得到相同的错误消息)
这很奇怪,因为我在配置文件中有这个:
train_config {
batch_size: 1
data_augmentation_options {
random_horizontal_flip {
}
}
有人知道怎么解决吗?
编辑(添加了一些要求的信息):
我正在使用 tensorflow 2.4.1 和 protobuf 3.12.2
配置文件更改后出现的唯一代码是:
# Save changes
config = config_util.create_pipeline_proto_from_configs(config_dic)
config_util.save_pipeline_config(config, dst)
# Train
!python workspace/model_main_tf2.py --model_dir=$dst --pipeline_config_path=$fpath
【问题讨论】:
标签: python tensorflow object-detection