1,数据集简介

  SVHN(Street View House Number)Dateset 来源于谷歌街景门牌号码,原生的数据集1也就是官网的 Format 1 是一些原始的未经处理的彩色图片,如下图所示(不含有蓝色的边框),下载的数据集含有 PNG 的图像和 digitStruct.mat  的文件,其中包含了边框的位置信息,这个数据集每张图片上有好几个数字,适用于 OCR 相关方向。

  这里采用 Format2, Format2 将这些数字裁剪成32x32的大小,如图所示,并且数据是 .mat 文件。

TFlearn——(2)SVHN    TFlearn——(2)SVHN

2,数据处理

  数据集含有两个变量 X 代表图像, 训练集 X 的 shape 是  (32,32,3,73257) 也就是(width, height, channels, samples),  tensorflow 的张量需要 (samples, width, height, channels),所以需要转换一下,由于直接调用 cifar 10 的网络模型,数据只需要先做个归一化,所有像素除于255就 OK,另外原始数据 0 的标签是 10,这里要转化成 0,并提供 one_hot 编码。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 19 09:55:36 2017

@author: cheers
"""

import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np

image_size = 32
num_labels = 10

def display_data():
    print 'loading Matlab data...'
    train = sio.loadmat('train_32x32.mat')
    data=train['X']
    label=train['y']
    for i in range(10):
        plt.subplot(2,5,i+1)
        plt.title(label[i][0])
        plt.imshow(data[...,i])
        plt.axis('off')
    plt.show()

def load_data(one_hot = False):
    
    train = sio.loadmat('train_32x32.mat')
    test = sio.loadmat('test_32x32.mat')
    
    train_data=train['X']
    train_label=train['y']
    test_data=test['X']
    test_label=test['y']
    
    
    train_data = np.swapaxes(train_data, 0, 3)
    train_data = np.swapaxes(train_data, 2, 3)
    train_data = np.swapaxes(train_data, 1, 2)
    test_data = np.swapaxes(test_data, 0, 3)
    test_data = np.swapaxes(test_data, 2, 3)
    test_data = np.swapaxes(test_data, 1, 2)
    
    test_data = test_data / 255.
    train_data =train_data / 255.
    
    for i in range(train_label.shape[0]):
         if train_label[i][0] == 10:
             train_label[i][0] = 0
                        
    for i in range(test_label.shape[0]):
         if test_label[i][0] == 10:
             test_label[i][0] = 0

    if one_hot:
        train_label = (np.arange(num_labels) == train_label[:,]).astype(np.float32)
        test_label = (np.arange(num_labels) == test_label[:,]).astype(np.float32)

    return train_data,train_label, test_data,test_label

if __name__ == '__main__':
    load_data(one_hot = True)
    display_data()
View Code

相关文章:

  • 2021-11-20
  • 2021-05-15
  • 2021-11-19
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
猜你喜欢
  • 2022-03-04
  • 2022-01-08
  • 2021-10-28
  • 2021-09-26
  • 2021-04-09
  • 2021-08-17
  • 2022-02-15
相关资源
相似解决方案