image数据集是二进制存储,前32个字节是4个int,分别表示magic(没啥用?), num(图片个数), rows(图像行数), cols(列)。之后每个字节就是图像的每个像素。
labels数据集第前16个字节是2个int, 分别表示magic,, num,后面每个字节表示图像的数字,对应images里的每个图片
解析minst的方式很多,用python简单的实现以下
# -*-coding:utf8-*-
import numpy as np
import struct
import cv2
# 读取minist数据集,image, label读读取方式不同
def load_minist(path, kind="image"):
""""
:param path: 文件的路径
:param kind: 读取文件的种类,分为image, label
:return:
"""
if kind == 'image':
with open(path, mode='rb') as img_read:
magic, num, row, col = struct.unpack('>IIII', img_read.read(16))
images = np.fromfile(img_read, dtype=np.uint8).reshape(-1, row*col)
return images
if kind == 'label':
with open(path, 'rb') as label_read:
magic, num = struct.unpack('<II', label_read.read(8))
labels = np.fromfile(label_read, dtype=np.uint8).reshape(-1, 1)
return labels
# 从traim_img中获取一个图像做测试,看读取是否正确
def get_image(train_image, index):
"""
:param train_image: img数组
:param index: 图像索引 0 - 5999
:return:
"""
return train_images[index % 6000].reshape(28, 28)
if __name__ == '__main__':
train_images = load_minist('train-images-idx3-ubyte', kind='image')
train_labels = load_minist('train-labels-idx1-ubyte', kind='label')
index = 3000
img = get_image(train_images, index)
cv2.imshow(str(train_labels[index][0]), img)
cv2.waitKey()