【个人笔记】迁移学习:tensorflow利用inception_v3模型和retrain实现图像分类训练
数据集
这里以Visual Geometry Group Home Page上的动物分类图片为例。
在下载的7000张动物图片中手动挑选了五种分类,分别创建了五个文件夹装入,剩下的6000余张照片丢弃。
注意,文件夹及其内图片的名称全小写,绝对路径名中不能包含中文
文件夹中只需包含图片文件,损坏的图片文件和mat类型的文件都需要去除,可以手动或者利用程序自动去除
inception_v3模型下载
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
无需解压
github tensorflow及retrain文件下载
从github上克隆tensorflow的源码和retraining的源码到本地并进行解压缩
在~hub-master\examples\image_retraining中找到retrain.py文件,待会需要定位文件位置
创建文件夹框架
新建一个文件夹,包含如下几个文件(夹)
bottleneck文件夹为空
test_images中为检测模型效果的图片文件(在数据集代表的对象范围内),不能从数据集中抽取,需要从网上下载。
data文件夹中包含一个包含images子文件夹的train文件夹,images中放入数据集图片文件夹
下一步,新建retrain.bat批处理文件,内容如下:
下图为即将开始训练时后的文件布局
运行retrain.bat开始训练。到这一步为止会出现一些小错误,比如无法使用指定的本地inception_v3模型,GPU或者tensorflow环境没有配置好等,在百度或者google上全部可以搜索解决
训练结果
通过训练后可以看到在文件夹中产生了output_graph.pb,output_labels等文件
bottleneck文件夹中会产生对应图片分类的文件夹
打开后可以看到每个图片对应一个的txt文件,这是每一张图片在inception_v3模型中除去softmax层的固化模型中计算得到的特征数据,不用理会
output_labels中为标签数据,不用理会
调用模型进行分类检测
调用训练好的.pb文件和output_labels文件,输入准备好的检测文件的路径,检测成果
代码源自学习的网络教程
import tensorflow as tf
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
lines = tf.gfile.GFile(r'E:\Python\Tensorflow\retrain\output_labels.txt').readlines()
uid_to_human = {}
for uid, line in enumerate(lines):
line = line.strip('\n')
uid_to_human[uid] = line
def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id]
with tf.gfile.FastGFile(r'E:\Python\Tensorflow\retrain\output_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
for root, dirs, files in os.walk(r'E:\Python\Tensorflow\retrain\test_images'):
for file in files:
image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
image_path = os.path.join(root, file)
print(image_path)
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()
top_k = predictions.argsort()[::-1]
print(top_k)
for node_id in top_k:
human_string = id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
print()
检测效果感人,模型训练成功
经验+5