【问题标题】:MxNet: label_shapes don't match names specified by label_namesMxNet:label_shapes 与 label_names 指定的名称不匹配
【发布时间】:2017-07-06 11:01:57
【问题描述】:

我编写了一个脚本,使用我用 MxNet 训练的模型对单个输入图像进行分类。为了对传入的图像进行分类,我通过网络将它们前馈。

简而言之,这就是我正在做的事情:

symbol, arg_params, aux_params = mx.model.load_checkpoint('model-prefix', 42)
model = mx.mod.Module(symbol=symbol, context=mx.cpu())
model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)
model.set_params(arg_params, aux_params)

# ... loading the image & resizing ...
# img is the image to classify as numpy array of shape (3, 244, 244)

Batch = namedtuple('Batch', ['data'])
self._model.forward(Batch(data=[mx.nd.array(img)]))
probabilities = self._model.get_outputs()[0].asnumpy()

print(str(probabilities))

这工作正常,除了我收到以下警告

UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])

我应该改变什么以避免收到此警告?我不清楚 label_shapeslabel_names 参数的用途,以及我期望用什么来填充它们。

注意:我找到了一些关于它们的线程,但没有一个能够让我解决问题。同样,MxNet 文档也没有提供有关这些参数是什么以及应该如何填充它们的详细信息。

【问题讨论】:

    标签: python machine-learning computer-vision deep-learning mxnet


    【解决方案1】:

    设置label_names=Noneallow_missing=True。那应该消除警告。

    model = mx.mod.Module(symbol=symbol, context=mx.cpu(), label_names=None)
    ...
    model.set_params(arg_params, aux_params, allow_missing=True)
    

    如果您好奇为什么会首先打印警告,

    每个模块都有相关的标签。当这个模型被训练时,softmax_label被用作标签(很可能是因为输出层是一个名为'softmax'的softmax层)。从文件加载模型时,创建的模块将softmax_label 作为模块的标签。

    >>>print(model.label_names)
    ['softmax_label']
    

    model.bind 然后在不提供 label_shapes 的情况下被调用。

    model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)
    

    MXNet 发现模块中有一个在绑定期间未提供的标签并抱怨它 - 这是您看到的警告消息。

    我认为如果使用 for_training=False 调用 bind,MXNet 不应该抱怨缺少标签。我创建了这个问题:https://github.com/dmlc/mxnet/issues/6958

    但是,对于我们从磁盘加载模型的这种特殊情况,我们可以使用 None 作为标签来加载它,这样 MXNet 以后不会在 bind 不提供标签时抱怨 - 这是建议的修复会的。

    【讨论】:

    • 感谢您的帮助。我确实尝试过。它删除了警告,但脚本不再起作用。使用label_name=None 时,脚本现在以RuntimeError: softmax_label is not presented 失败。它来自File "/usr/local/lib/python2.7/site-packages/mxnet-0.9.5-py2.7.egg/mxnet/module/module.py", line 264, in _impl raise RuntimeError("%s is not presented" % name) 知道发生了什么吗?我对这些各种标签参数的含义有点不知所措。
    • 可以为模块设置allow_missing=True吗? model.set_params(arg_params, aux_params, allow_missing=True)
    • 感谢@indhu-bharathi 的回答和解释。还将 allow_missing 添加到 true 使其工作。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多