【问题标题】:How to get labels from minibatch?如何从小批量中获取标签?
【发布时间】:2017-01-05 09:31:01
【问题描述】:

我正在编写本教程:

https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_201B_CIFAR-10_ImageHandsOn.ipynb

测试/训练数据文件是简单的制表符分隔的文本文件,包含图像文件名和正确的标签,如下所示:

...\data\CIFAR-10\test\00000.png    3
...\data\CIFAR-10\test\00001.png    8
...\data\CIFAR-10\test\00002.png    8

如何从小批量中提取原始标签?

我已尝试使用此代码:

reader_test = MinibatchSource(ImageDeserializer('test_map.txt', StreamDefs(
    features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image'
    labels   = StreamDef(field='label', shape=num_classes)      # and second as 'label'
)))

test_minibatch = reader_test.next_minibatch(10)
labels_stream_info = reader_test['labels']
orig_label = test_minibatch[labels_stream_info].value
print(orig_label)

<cntk.cntk_py.Value; proxy of <Swig Object of type 'CNTK::ValuePtr *' at 0x0000000007A32C00> >

但是,正如您在上面看到的,结果不是带有标签的数组。

获取标签的正确代码是什么?

此代码有效,但它使用不同的文件格式而不是 ImageDeserializer。

文件格式:

|labels 0 0 1 0 0 0 |features 0
|labels 1 0 0 0 0 0 |features 457

工作代码:

mb_source = text_format_minibatch_source('test_map2.txt', [
    StreamConfiguration('features', 1),
    StreamConfiguration('labels', num_classes)])

test_minibatch = mb_source.next_minibatch(2)

labels_stream_info = mb_source['labels']
orig_label = test_minibatch[labels_stream_info].value
print(orig_label)

[[[ 0.  0.  1.  0.  0.  0.]]
 [[ 1.  0.  0.  0.  0.  0.]]]

使用 ImageDeserializer 时如何获取输入中的标签?

【问题讨论】:

    标签: cntk


    【解决方案1】:

    你可以尝试使用:

    orig_label = test_minibatch[labels_stream_info].value
    

    【讨论】:

    • 我尝试了您的建议,但仍然得到相同的结果。
    【解决方案2】:

    我只是试图重现 - 我认为这里潜伏着一些奇怪的错误。我的直觉是,事实上labels 对象没有作为有效的numpy 数组返回。我在教程CNTK_201B中的train_and_evaluate函数中插入了以下调试输出:

    for epoch in range(max_epochs):       # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:  # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map) # fetch minibatch.
            print("Features:")
            print(data[input_var].shape)
            print(data[input_var].value.shape)
            print("Labels:")
            print(data[label_var].shape)
            print(data[label_var].value.shape)
    

    输出:

    Training 116906 parameters in 10 parameter tensors.
    Features:
    (64, 1, 3, 32, 32)
    (64, 1, 3, 32, 32)
    Labels:
    (64, 1, 10)
    ()
    

    标签显示为numpy.ndarray,但它没有有效的shape

    我称之为错误。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-09-29
      • 1970-01-01
      • 2016-01-09
      • 2013-10-29
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多