用Tensorflow Object Detection API 训练自己的mask R-CNN模型
前言
工作有这个需求但是网上似乎找不到用Object Detection API 训练Mask R-CNN的完整教程,所以把它写出来希望能帮到有需要的人。
训练使用Mask R-CNN Inception V2模型,这篇博客Building a Custom Mask RCNN model with Tensorflow Object Detection介绍了完整的步骤,但是它提供的数据和脚本有错误,会产生错误的record文件导致无法完成训练。因此我fork了这个库【github】并做了一些修改,已经能够产生可用的record文件了。
准备训练数据
Jpg图片是必须准备的了,另外还会用到描述图片基本信息的xml文件(我修改的脚本中bonding box由png图片产生)和表示mask区域的png图片。产生xml的方法很多,比如labelImg;png推荐labelme,通过它可以生成json文件,再通过命令转化为png。如果不想准备自己的数据,fork的git库中有现成的。
产生record文件
首先当然是要配置好tensorflow object detection API的环境。如果使用git库中的默认数据,只需要运行一个python脚本。
$ cd Deep-Learning/Custom_Mask_RCNN
$ python create_pet_tf_record.py
这样就产生了pet_train.record文件和pet_val.record。把它们和目录中原有的label_map.pbtxt都保存到训练目录中。
配置config文件
在object detection API目录中,复制samples/configs/mask_rcnn_inception_v2_coco.config文件到训练目录。修改文件中的num_classes为1;所有的PATH_TO_BE_CONFIGURED修改为合适的目标路径。
开始训练
python legacy/train.py --logtostderr --train_dir='训练目录/工作目录' --pipeline_config_path='训练目录/mask_rcnn_inception_v2_coco.config'
用tensorboard观察,稳定下来的速度很快。
测试模型
把训练好的模型export之后可以用test.py测试训练结果(复制脚本到object detection API目录中并修改模型和图片相应的路径)。