【问题标题】:How do I generate sharp images Cifar-10如何生成清晰的图像 Cifar-10
【发布时间】:2017-07-21 15:36:40
【问题描述】:

我正在使用 tensorflow,并尝试在 Cifar-10 上可视化自动编码器的输入/输出。

我在这里关注这个答案:Why CIFAR-10 images are not displayed properly using matplotlib?

这是运行他们的代码稍作修改的结果(将 figsize 更改为 5,5):

但是,这仍然不如原始页面中的图像清晰:https://www.cs.toronto.edu/~kriz/cifar.html

我怎样才能做得更好?

【问题讨论】:

    标签: python python-2.7 numpy matplotlib tensorflow


    【解决方案1】:

    这里可能有两个问题:

    问题 1:

    看起来您的颜色通道(红色、绿色、蓝色)是混合的。这可以解释为什么颜色如此奇怪。如果是这种情况,您将需要交换阵列中的颜色通道,如下所示。

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.cbook import get_sample_data
    
    rgb_image = plt.imread(get_sample_data("grace_hopper.png", asfileobj=False))
    
    # correct color channels (R, G, B)
    plt.figure()
    plt.imshow(rgb_image)
    plt.axis('off')
    

    # swapped color channels (R, B, G)
    rgb_image = rgb_image[:, :, [0, 2, 1]]
    plt.figure()
    plt.imshow(rgb_image)
    plt.axis('off')
    

    问题 2:

    Matplotlib 的 plt.imshow 有一个关键字参数 interpolation,如果未指定,则默认为 None。 Matplotlib 然后参考您的本地样式表来确定默认的插值行为。根据您的样式表,这可能会导致应用插值,从而导致图像失真。请参阅documentation for imshow for more details

    如果你想保证 Matplotlib 不会对你的图像进行插值,你应该在 plt.imshow 中指定 interpolation="none"。这是令人困惑的,因为None 的默认 NoneType 值与"none" 的字符串值产生不同的行为。

    red = np.zeros((100, 100, 3), dtype=np.uint8)
    red[:, :, 0] = 255
    red[40:60, 40:60, :] = 255
    
    # with interpolation
    plt.figure()
    plt.imshow(red, interpolation='bicubic') 
    plt.axis('off')
    

    # without interpolation
    plt.figure()
    plt.imshow(red, interpolation='none') 
    plt.axis('off')
    

    【讨论】:

    • 在 Matplotlib 2.0 中,默认插值已更改为 'image.interpolation': 'nearest'
    【解决方案2】:

    也许你应该做这样的事情。图像非常小,高度和宽度均为 32 像素,因此只有在缩略图大小时它们才会更清晰。我在这里使用双三次变换对其进行了插值。但是您可以将其更改为“无”,这样您将获得像素化图像,而不是模糊。

    def unpickle(file):
        with open(file, 'rb') as fo:
            dict1 = pickle.load(fo, encoding='bytes')
        return dict1
    
    pd_tr = pd.DataFrame()
    tr_y = pd.DataFrame()
    
    for i in range(1,6):
        data = unpickle('data/data_batch_' + str(i))
        pd_tr = pd_tr.append(pd.DataFrame(data[b'data']))
        tr_y = tr_y.append(pd.DataFrame(data[b'labels']))
        pd_tr['labels'] = tr_y
    
    tr_x = np.asarray(pd_tr.iloc[:, :3072])
    tr_y = np.asarray(pd_tr['labels'])
    ts_x = np.asarray(unpickle('data/test_batch')[b'data'])
    ts_y = np.asarray(unpickle('data/test_batch')[b'labels'])    
    labels = unpickle('data/batches.meta')[b'label_names']
    
    def plot_CIFAR(ind):
        arr = tr_x[ind]
        R = arr[0:1024].reshape(32,32)/255.0
        G = arr[1024:2048].reshape(32,32)/255.0
        B = arr[2048:].reshape(32,32)/255.0
    
        img = np.dstack((R,G,B))
        title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]]))
        fig = plt.figure(figsize=(3,3))
        ax = fig.add_subplot(111)
        ax.imshow(img,interpolation='bicubic')
        ax.set_title('Category = '+ title,fontsize =15)
    
    plot_CIFAR(4)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2017-05-15
      • 2016-08-26
      • 1970-01-01
      • 2017-10-27
      • 2012-10-21
      • 1970-01-01
      • 2017-12-19
      相关资源
      最近更新 更多