【发布时间】:2019-08-03 03:15:07
【问题描述】:
我目前正在尝试使用 TensorFlow API (https://github.com/tensorflow/models) 训练分类网络。在为我的数据集创建 TFrecords(存储在 research/slim/data 中)之后,我使用以下命令训练网络:
python research/slim/train_image_classifier.py \
--train_dir=research/slim/training/current_model \
--dataset_name=my_dataset \
--dataset_split_name=train \
--dataset_dir=research/slim/data \
--model_name=vgg_16 \
--checkpoint_path=research/slim/training/vgg_16_2016_08_28/vgg_16.ckpt \
--checkpoint_exclude_scopes=vgg_16/fc7,vgg_16/fc8 \
--trainable_scopes=vgg_16/fc7,vgg_16/fc8 \
--batch_size=5 \
--log_every_n_steps=10 \
--max_number_of_steps=1000 \
这适用于几种分类网络(Inception、ResNet、MobileNet),但对 VGG-Net 不太适用。我微调了 VGG-Net 16 的以下模型: http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
一般来说,它可以训练这个模型,但是当我训练网络时,损失会增加而不是减少。也许是因为我选择了“checkpoint_exclude_scopes”。
使用最后一个全连接层作为 checkpoint_exclude_scopes 是否正确?
对于参数“output_node_names”,冻结图形也会出现同样的问题。例如,对于 InceptionV3,它适用于“output_node_names=InceptionV3/Predictions/Reshape_1”。但是如何为 VGG-Net 设置这个参数。我尝试了以下方法:
python research/slim/freeze_graph.py
--input_graph=research/slim/training/current_model/graph.pb
--input_checkpoint=research/slim/training/current_model/model.ckpt
--input_binary=true
--output_graph=research/slim/training/current_model/frozen_inference_graph.pb
--output_node_names=vgg_16/fc8
我在 VGG-Net 模型中没有找到任何包含“Predictions”或“Logits”的层,所以我不确定。
感谢您的帮助!
【问题讨论】:
-
它是否适用于 MobileNet,如果是,您在 trainable_scopes、checkpoint_exclude_scopes 中传递了哪些值以及您在 checkpoint_path(即)新数据集的检查点文件或 Mobilenet 的默认检查点文件中使用了哪个检查点文件?你能指导一下吗
-
为什么不给有问题的模型(VGG16)的脚本,而不是 InceptionV3?
-
@Anju Paul - Intel:我刚刚更新了帖子,给出了我用于 VGG16 的脚本命令。
-
@Dinesh:是的,它适用于 MobileNet。这里是我用于 MobileNet v1 的参数: --trainable_scopes=MobilenetV1/Logits --checkpoint_exclude_scopes=MobilenetV1/Logits --checkpoint_path=mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt ___为了冻结图形,我使用了 --output_node_names=MobilenetV1 /Predictions/Reshape_1
标签: python tensorflow classification conv-neural-network vgg-net