解析minist数据集

image数据集是二进制存储,前32个字节是4个int,分别表示magic(没啥用?), num(图片个数), rows(图像行数), cols(列)。之后每个字节就是图像的每个像素。

labels数据集第前16个字节是2个int, 分别表示magic,, num,后面每个字节表示图像的数字,对应images里的每个图片

解析minst的方式很多,用python简单的实现以下

# -*-coding:utf8-*-
import numpy as np
import struct
import cv2

# 读取minist数据集,image, label读读取方式不同
def load_minist(path, kind="image"):
    """"
    :param path: 文件的路径 
    :param kind: 读取文件的种类,分为image, label
    :return: 
    """
    if kind == 'image':
        with open(path, mode='rb') as img_read:
            magic, num, row, col = struct.unpack('>IIII', img_read.read(16))
            images = np.fromfile(img_read, dtype=np.uint8).reshape(-1, row*col)
        return images

    if kind == 'label':
        with open(path, 'rb') as label_read:
            magic, num = struct.unpack('<II', label_read.read(8))
            labels = np.fromfile(label_read, dtype=np.uint8).reshape(-1, 1)
        return labels


# 从traim_img中获取一个图像做测试,看读取是否正确
def get_image(train_image, index):
    """

    :param train_image: img数组
    :param index:  图像索引 0 - 5999
    :return:
    """
    return train_images[index % 6000].reshape(28, 28)


if __name__ == '__main__':
    train_images = load_minist('train-images-idx3-ubyte', kind='image')
    train_labels = load_minist('train-labels-idx1-ubyte', kind='label')
    index = 3000
    img = get_image(train_images, index)
    cv2.imshow(str(train_labels[index][0]), img)
    cv2.waitKey()

 

相关文章:

  • 2021-11-17
  • 2021-08-19
  • 2021-05-31
  • 2021-08-18
  • 2022-12-23
  • 2022-12-23
  • 2021-12-04
  • 2022-12-23
猜你喜欢
  • 2021-09-05
  • 2022-01-19
  • 2021-04-27
  • 2021-06-10
  • 2021-08-07
  • 2021-11-07
  • 2022-03-03
相关资源
相似解决方案