Pytorch_Model to Caffe (Part 1): Parsing caffemodel and prototxt
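1. Introduction to Caffe
2. Object detection with Caffe
- Object detection with SSD; the main steps are as follows (the key part is porting the model)
3. The five major components of Caffe
4. caffemodel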
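- A caffemodel contains the network definition from the prototxt (everything except solver.prototxt) plus the weights and biases; the prototxt itself stores the network structure as text
- Create a caffe_pb2.NetParameter() object and parse the caffemodel into it to get its contents: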
import caffe.proto.caffe_pb2 as caffe_pb2

model = caffe_pb2.NetParameter()
with open(caffemodel_filename, 'rb') as f:  # a caffemodel is binary protobuf
    model.ParseFromString(f.read())
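- Loop over the layers to read each layer's parameters; model.layer holds the information for every layer (one LayerParameter entry per layer)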
# Parse the prototxt contents field by field (workable, but somewhat tedious)
for i, layer in enumerate(model.layer):
    top_str = ''
    bottom_str = ''
    transform_param_str = ''
    data_param_str = ''
    annotated_data_param_str = ''
    for top in layer.top:
        top_str += '\ttop: "{}"\n'.format(top)
    for bottom in layer.bottom:
        bottom_str += '\tbottom: "{}"\n'.format(bottom)
    # transform_param: str() of an unset sub-message is '', so it is skipped
    if str(layer.transform_param) != '':
        new_str_trans = ''
        for item in str(layer.transform_param).split('\n'):
            new_str_trans += '\t\t' + item + '\n' if item != '' else ''
        transform_param_str = '\ttransform_param {\n' + new_str_trans + '\t}\n'
    # data_param
    if str(layer.data_param) != '':
        new_str_data_param = ''
        for item in str(layer.data_param).split('\n'):
            new_str_data_param += '\t\t' + item + '\n' if item != '' else ''
        data_param_str = '\tdata_param {\n' + new_str_data_param + '\t}\n'
    # annotated_data_param (this field comes from the SSD fork of Caffe)
    if str(layer.annotated_data_param) != '':
        new_str_annotated_data_param = ''
        for item in str(layer.annotated_data_param).split('\n'):
            new_str_annotated_data_param += '\t\t' + item + '\n' if item != '' else ''
        annotated_data_param_str = ('\tannotated_data_param {\n'
                                    + new_str_annotated_data_param + '\t}\n')
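The three branches above differ only in the field name, so they can be collapsed into one helper. The following is a minimal sketch. `format_sub_message` is a hypothetical name, and it takes the already-stringified sub-message so that it runs without Caffe installed. Inside the loop it would be called as `format_sub_message('data_param', str(layer.data_param))`.

```python
def format_sub_message(name, text):
    """Wrap the text form of a protobuf sub-message as 'name { ... }'.

    `text` is what str(message) returns; an empty string means the field
    is unset, in which case nothing is emitted (same as the loop above).
    """
    if text == '':
        return ''
    # Indent every non-empty line by two tabs, as the original loop does
    body = ''.join('\t\t' + line + '\n' for line in text.split('\n') if line != '')
    return '\t' + name + ' {\n' + body + '\t}\n'


# A hand-written string standing in for str(layer.data_param)
sample = 'source: "train_lmdb"\nbatch_size: 32\n'
print(format_sub_message('data_param', sample), end='')
```

Passing strings rather than protobuf objects keeps the helper independent of `caffe_pb2`, which also makes it easy to check on its own.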
- Part of the parsed result:
# train.prototxt: convolution layer
layer {
  name: "conv1_2"
  type: "Convolution"
  bottom: "conv1_1"
  top: "conv1_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 64
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
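The dump above is in protobuf text format. A line such as `name {` opens a nested message, `}` closes it, and every other line is a `key: value` pair. The following minimal, line-based sketch in plain Python (no Caffe or protobuf needed) makes that nesting concrete. `parse_prototxt` is a hypothetical helper, not a full text-format parser. For real files, the protobuf library's `text_format.Merge` (or `text_format.Parse`) into a `caffe_pb2.NetParameter` is the proper tool.

```python
def parse_prototxt(text):
    """Parse a simple protobuf text-format string into nested dicts.

    Assumptions: one field or brace per line (as in the dump above), no
    '{' or '}' inside string values, and '#' only at the start of a line.
    Every key maps to a list, because any protobuf field may be repeated.
    """
    root = {}
    stack = [root]
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith('#'):
            continue  # skip blank lines and comment lines
        if line.endswith('{'):
            # Start of a nested message such as "layer {" or "param {"
            key = line[:-1].rstrip().rstrip(':').strip()
            child = {}
            stack[-1].setdefault(key, []).append(child)
            stack.append(child)
        elif line == '}':
            stack.pop()  # end of the current nested message
        else:
            # Scalar field such as 'name: "conv1_2"' or 'kernel_size: 3'
            key, value = line.split(':', 1)
            value = value.strip()
            if value.startswith('"') and value.endswith('"'):
                value = value[1:-1]  # drop the surrounding quotes
            stack[-1].setdefault(key.strip(), []).append(value)
    return root


sample = '''layer {
  name: "conv1_2"
  type: "Convolution"
  param {
    lr_mult: 1.0
  }
  param {
    lr_mult: 2.0
  }
  convolution_param {
    num_output: 64
  }
}'''
net = parse_prototxt(sample)
conv = net['layer'][0]
print(conv['name'][0], len(conv['param']), conv['convolution_param'][0]['num_output'][0])
# prints: conv1_2 2 64
```

Scalar values are kept as strings. The point of the sketch is the shape of the data: a layer is a message, and `param` shows up twice because it is a repeated field.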
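5. Recovering train.prototxt from the caffemodel
- The goal is to understand how data is stored inside a caffemodel
Approach: elimination. First dump every layer, then delete the blobs and other unneeded information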
import caffe.proto.caffe_pb2 as caffe_pb2

# src_path: directory containing the model (defined elsewhere)
caffemodel_filename = src_path + '/***.caffemodel'
Tarpa_model = caffe_pb2.NetParameter()
with open(caffemodel_filename, 'rb') as f:
    Tarpa_model.ParseFromString(f.read())
print(Tarpa_model.name)
print(Tarpa_model.input)
# print(Tarpa_model.layer)
# print(type(Tarpa_model.layer))
with open('_caffemodel_.log', 'w') as f:
    f.write('name: "{}"'.format(Tarpa_model.name) + '\n')
    for i, layer in enumerate(Tarpa_model.layer):
        layer_lines = str(layer).split('\n')
        new_str_trans = ''
        continue_flag = 0
        for item in layer_lines:
            # Drop the top-level 'phase: TRAIN' carried by each saved layer
            if item == 'phase: TRAIN':
                continue
            # Still inside a skipped block: drop its closing braces as well
            if continue_flag and '}' in item:
                continue
            continue_flag = 0
            # Drop the weights: blobs blocks and their data/shape/dim lines
            if 'blobs' in item or 'data:' in item or 'shape' in item or 'dim:' in item:
                continue_flag = 1
                continue
            new_str_trans += '\t' + item + '\n' if item != '' else ''
        layer_str = 'layer {\n' + new_str_trans + '}\n'
        f.write(layer_str)
        print(i)
        # if i == 2:
        #     break
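The flag-based filter above works by keyword matching, so it can misfire. For example, `'shape' in item` also matches a line such as `reshape_param {`, which would drop that whole block. A sturdier variant is sketched below. It assumes the input is the text form of one layer (`str(layer)`) and tracks brace depth, skipping everything inside a `blobs { ... }` block. `strip_blobs` is a hypothetical helper name.

```python
def strip_blobs(layer_text):
    """Remove every 'blobs { ... }' block from the text form of a layer.

    Brace depth is tracked, so all lines inside a blobs block (data:,
    shape {, dim:, closing braces) are skipped however deeply nested.
    The unindented top-level 'phase: TRAIN' line is dropped too.
    """
    kept = []
    skip_depth = 0  # > 0 while inside a blobs block
    for line in layer_text.split('\n'):
        stripped = line.strip()
        if skip_depth:
            skip_depth += stripped.count('{') - stripped.count('}')
            continue
        if stripped.startswith('blobs {'):
            skip_depth = stripped.count('{') - stripped.count('}')
            continue
        if line == 'phase: TRAIN':  # top level only, not inside include { }
            continue
        if stripped:
            kept.append(line)
    return '\n'.join(kept)


sample = '\n'.join([
    'name: "conv1"', 'type: "Convolution"', 'bottom: "data"', 'top: "conv1"',
    'phase: TRAIN',
    'blobs {', '  data: 0.1', '  data: -0.2', '  shape {', '    dim: 2', '  }', '}',
    'convolution_param {', '  num_output: 2', '}', '',
])
print(strip_blobs(sample))
```

With the real protobuf objects, a cheaper route is likely to work on a copy of the layer. `CopyFrom` it into a new `caffe_pb2.LayerParameter`, call `ClearField('blobs')` on the copy, and only then take `str()`. This avoids turning every weight value into text in the first place.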
6. Remaining issues with parsing the caffemodel
The generated .prototxt contains many Split layers; this has not been resolved yet:
layer {
  name: "data_data_0_split"
  type: "Split"
  bottom: "data"
  top: "data_data_0_split_0"
  top: "data_data_0_split_1"
  top: "data_data_0_split_2"
  top: "data_data_0_split_3"
  top: "data_data_0_split_4"
  top: "data_data_0_split_5"
  top: "data_data_0_split_6"
  top: "data_data_0_split_7"
}
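The names in this layer follow the pattern `<blob>_<layer>_<index>_split`. That matches how Caffe names the Split layers it inserts on its own (in `insert_splits.cpp`) whenever one top blob feeds several bottoms. Because the trained net is saved from its in-memory layers, those inserted layers end up in the caffemodel. Assuming that is their origin, one possible clean-up is to drop every Split layer and point its consumers back at the original blob. Below is a minimal sketch on a simplified, hypothetical representation that uses plain dicts instead of `LayerParameter`. The function and layer names are illustrative only.

```python
def remove_split_layers(layers):
    """Drop Split layers and reconnect their consumers to the original blob.

    `layers` is a list of dicts with 'name', 'type', 'bottom' and 'top'
    keys (a hypothetical stand-in for LayerParameter). The input is not
    modified; a new list of layer dicts is returned.
    """
    # Map every split output blob back to the blob that was split
    alias = {}
    for layer in layers:
        if layer['type'] == 'Split':
            for top in layer['top']:
                alias[top] = layer['bottom'][0]

    def resolve(blob):
        # Follow the mapping in case a split output was split again
        while blob in alias:
            blob = alias[blob]
        return blob

    cleaned = []
    for layer in layers:
        if layer['type'] == 'Split':
            continue
        new_layer = dict(layer)
        new_layer['bottom'] = [resolve(b) for b in layer['bottom']]
        cleaned.append(new_layer)
    return cleaned


net = [
    {'name': 'data', 'type': 'AnnotatedData', 'bottom': [], 'top': ['data', 'label']},
    {'name': 'data_data_0_split', 'type': 'Split', 'bottom': ['data'],
     'top': ['data_data_0_split_0', 'data_data_0_split_1']},
    {'name': 'conv1_1', 'type': 'Convolution',
     'bottom': ['data_data_0_split_0'], 'top': ['conv1_1']},
    {'name': 'conv4_3_norm_mbox_priorbox', 'type': 'PriorBox',
     'bottom': ['conv1_1', 'data_data_0_split_1'], 'top': ['prior']},
]
for layer in remove_split_layers(net):
    print(layer['name'], layer['bottom'])
```

Applied to real LayerParameter objects, the same idea means skipping layers whose `type` is `"Split"` when writing the prototxt and rewriting the remaining layers' `bottom` entries through the alias map.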