【问题标题】:Practice assignment AWS Computer Vision : get_Cifar10_dataset练习作业 AWS 计算机视觉:get_Cifar10_dataset
【发布时间】:2020-06-21 00:22:30
【问题描述】:

我对这个方法有疑问,它应该返回训练和验证数据集并检查它以返回与CIFAR10 中每个类的第一次出现相对应的索引。

这是代码: def get_cifar10_dataset(): """应该创建cifar 10网络并识别每个新类第一次的数据集索引 出现

:return: tuple of training and validation dataset as well as label indices
:rtype: (gluon.data.Dataset, 'dict_values' object is not subscriptable, gluon.data.Dataset, 
 dict[int:int])
"""

train_data = None
val_data = None
# YOUR CODE HERE
train_data = datasets.CIFAR10(train=True, root=M5_IMAGES)
val_data = datasets.CIFAR10(train=False, root=M5_IMAGES)

【问题讨论】:

  • 您能否添加有关您在何处运行此代码的详细信息?
  • 我在 aws 云上运行这段代码

标签: amazon-web-services mxnet


【解决方案1】:

您被要求返回带有标签和相应索引的字典。使用以下功能可以解决您的问题。

def get_idx_dict(data):

    lis = []
    idx = []
    indices = {}
    
    for i in range(len(data)):
        if data[i][1] not in lis:
            lis.append(data[i][1])
            idx.append(i)
            
    indices = {lis[i]: idx[i] for i in range(len(lis))}
    return indices

该函数返回具有所需输出的字典。对来自训练集和验证集的数据使用此函数。

train_indices = get_idx_dict(train_data)
val_indices = get_idx_dict(val_data)

【讨论】:

    【解决方案2】:

    你可以这样做

    def get_cifar10_dataset():
        """
        Should create the cifar 10 network and identify the dataset index of the first time each new class appears
        
        :return: tuple of training and validation dataset as well as label indices
        :rtype: (gluon.data.Dataset, dict[int:int], gluon.data.Dataset, dict[int:int])
        """
        train_data = None
        val_data = None
        train_indices = {}
        val_indices = {}
        
        # Use `root=M5_IMAGES` for your dataset
        train_data = gluon.data.vision.datasets.CIFAR10(train=True, root=M5_IMAGES)
        val_data   = gluon.data.vision.datasets.CIFAR10(train=False, root=M5_IMAGES)
        
        #for train
        for i in range(len(train_data)):
            if train_data[i][1] not in train_indices:
                train_indices[train_data[i][1]] = i
        #for valid
        for i in range(len(val_data)):
            if val_data[i][1] not in val_indices:
                val_indices[val_data[i][1]] = i
        
        #raise NotImplementedError()
        
        return train_data, train_indices, val_data, val_indices
    
    
    

    【讨论】:

      猜你喜欢
      • 2011-02-09
      • 2020-09-16
      • 1970-01-01
      • 2014-07-07
      • 2018-11-07
      • 1970-01-01
      • 2011-10-18
      • 2017-08-26
      • 1970-01-01
      相关资源
      最近更新 更多