1.ipython 示例 读取vgg网络的relu7的输出:
主要函数有
get_symbol(num_layers,num_classes):
得到网络所有的结构。
list_outputs():(给出符号的输出变量)的说明
list_arguments (给出当前符号的输入变量)的说明
get_internals():
获取中间层结果
完整代码:
import mxnet as mx
from importlib import import_module
def get_bonenet(num_classes, bonenet):
"""调用基础网络作为输入"""
if bonenet.startswith('vgg19'):
net = import_module('network.symbols.vgg')
sym = net.get_symbol(num_classes, num_layers=19)
internals = sym.get_internals()
# print(internals.list_outputs())
bonenet_layer = internals['drop7_output']
return bonenet_layer
def nbc_network(num_classes, bonenet):
"""n binary classifiers network"""
bonenet_layer = get_bonenet(num_classes,bonenet)
fc = mx.sym.FullyConnected(data=bonenet_layer, num_hidden=num_classes, name="fc")
label = mx.sym.Variable(name='label')
symbol = mx.symbol.LogisticRegressionOutput(data=fc, label=label, name='LRO_1')
return symbol