原文:https://www.cnblogs.com/denny402/p/7520063.html

原文:https://www.jianshu.com/p/84f72791806f

原文:https://blog.csdn.net/lee813/article/details/89609691

 

 

1、下载fashion-mnist数据集

  地址:https://github.com/zalandoresearch/fashion-mnist

  下面这四个都要下载,下载完成后,解压到同一个目录,我是解压到“E:/fashion_mnist/”这个目录里面,好和下面的代码目录一致

Python 10 训练模型

 

 

 

2、在Geany中执行下面这段代码。

  这段代码里面,需要先用pip安装skimage、torch、torchvision,前两篇文章有安装步骤。

  这段代码的作用:将下载下来的 二进制文件 转换为 图片,会在目录中生成两个文件夹和两个文本。

          文件夹里面全是图片,图片的内容是数字,N多数字。

          文本的内容主要是图片和真实数字的一个关联。

 

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
        )
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
        )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
    if(train):
        f=open(root+'train.txt','w')
        data_path=root+'/train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            io.imsave(img_path,img.numpy())
            f.write(img_path+' '+str(label)+'\n')
        f.close()
    else:
        f = open(root + 'test.txt', 'w')
        data_path = root + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
            img_path = data_path+ str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label) + '\n')
        f.close()

convert_to_img(True)
convert_to_img(False)
View Code

相关文章: