RNN训练mnist数据

本文通过用 RNN 做图像分类,来了解 PyTorch 中 RNN 用于训练图像数据时的用法。

MNIST 数据集中的每张图像都由 28×28 个像素点组成,如下图所示:
(图:MNIST 示例图像)
因此可以把每张图像看作长度为 28 的序列,序列中每个元素的特征维度为 28。下面的代码是把图像的每一列(28 个像素)作为一个时间步的输入。

RNN 的结构如下图所示:
(图:RNN 结构示意图)
首先导入所需的库并处理数据:

import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader

from torchvision import transforms as tfs
from torchvision.datasets import MNIST

定义数据

# Location of the MNIST files. It is machine-specific, so it lives in one
# constant instead of being repeated in every dataset call. No download is
# requested, so the files must already exist under this directory.
DATA_ROOT = r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data'

data_tf = tfs.Compose([
    tfs.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    tfs.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

train_set = MNIST(DATA_ROOT, train=True, transform=data_tf)
test_set = MNIST(DATA_ROOT, train=False, transform=data_tf)

# NOTE: num_workers > 0 starts worker processes. On Windows (spawn start method)
# any code that iterates these loaders must run under `if __name__ == '__main__':`.
train_data = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_data = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

定义模型

class rnn_classify(nn.Module):
    """Classify single-channel images with a stacked LSTM and a linear head.

    Each image is read as a sequence: ``forward`` moves the last (width) axis
    to the front, so every time step is one column of pixels.

    Args:
        in_feature: size of each sequence element (28 for MNIST: one column).
        hidden_feature: LSTM hidden size.
        num_class: number of output classes.
        num_layers: number of stacked LSTM layers (2 by default).
    """

    def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
        super(rnn_classify, self).__init__()
        # Default nn.LSTM layout (batch_first=False): input is (seq, batch, feature).
        self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)
        # Maps the output of the last time step to one score per class.
        self.classifier = nn.Linear(hidden_feature, num_class)

    def forward(self, x):
        """Return class scores of shape (batch, num_class).

        Args:
            x: image batch of shape (batch, 1, 28, 28).
        """
        # Drop only the channel axis. A bare squeeze() would also drop the
        # batch axis when batch == 1, and the 3-axis permute below would fail.
        x = x.squeeze(1)  # (batch, 28, 28)
        x = x.permute(2, 0, 1)  # (28, batch, 28): last axis becomes the sequence axis
        out, _ = self.rnn(x)  # default initial state; out is (28, batch, hidden_feature)
        out = out[-1]  # last time step: (batch, hidden_feature)
        return self.classifier(out)  # (batch, num_class)

定义网络、损失函数、优化器

# Loss: takes the raw (un-softmaxed) scores produced by the classifier head.
criterion = nn.CrossEntropyLoss()
net = rnn_classify()
# NOTE: `optimzier` (sic) is the name the training call passes to train(); keep the spelling.
optimzier = torch.optim.Adadelta(params=net.parameters(), lr=1e-1)

定义训练函数

from datetime import datetime

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable


def get_acc(output, label):
    """Return the fraction of rows of ``output`` whose arg-max equals ``label``.

    Args:
        output: score tensor; axis 0 is the batch, axis 1 holds one score per class.
        label: tensor of target class indices, one per row of ``output``.

    Returns:
        Accuracy of the batch as a Python float.
    """
    predicted = output.max(dim=1)[1]
    hits = int((predicted == label).sum())
    batch_size = output.size(0)
    return hits / batch_size


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` and print loss/accuracy after every epoch.

    Args:
        net: the model; moved to the GPU when CUDA is available.
        train_data: iterable (with ``len``) of (images, labels) batches used
            for optimisation.
        valid_data: iterable (with ``len``) of (images, labels) batches
            evaluated after each epoch, or None to skip validation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss called as ``criterion(output, label)``.

    Reported losses/accuracies are means of the per-batch values. Each printed
    line ends with the wall-clock time elapsed since the previous line.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            if use_cuda:
                im = im.cuda()
                label = label.cuda()
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        # total_seconds() keeps whole days; timedelta.seconds wraps after 24 h.
        h, remainder = divmod(int((cur_time - prev_time).total_seconds()), 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # Evaluation only: disable gradient tracking so no graph is built.
            with torch.no_grad():
                for im, label in valid_data:
                    if use_cuda:
                        im = im.cuda()
                        label = label.cuda()
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)

开始训练

# Train for 10 epochs, evaluating on the test split after each one. The
# __main__ guard lets DataLoader worker processes (num_workers=4) re-import this
# script on Windows (spawn) without starting a second training run.
if __name__ == "__main__":
    train(net, train_data, test_data, 10, optimzier, criterion)

完整代码(注意:下面的代码只包含定义,需在末尾加上 `train(net, train_data, test_data, 10, optimzier, criterion)` 才会真正开始训练):

    import torch
    from torch.autograd import Variable
    from torch import nn
    from torch.utils.data import DataLoader
    from datetime import datetime   
    import torch.nn.functional as F
    from torchvision import transforms as tfs
    from torchvision.datasets import MNIST


# 定义数据

    # Location of the MNIST files. It is machine-specific, so it lives in one
    # constant instead of being repeated in every dataset call. No download is
    # requested, so the files must already exist under this directory.
    DATA_ROOT = r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data'

    data_tf = tfs.Compose([
        tfs.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
        tfs.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
    ])

    train_set = MNIST(DATA_ROOT, train=True, transform=data_tf)
    test_set = MNIST(DATA_ROOT, train=False, transform=data_tf)

    # NOTE: num_workers > 0 starts worker processes. On Windows (spawn start method)
    # any code that iterates these loaders must run under `if __name__ == '__main__':`.
    train_data = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
    test_data = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)


# 定义模型

    class rnn_classify(nn.Module):
        """Classify single-channel images with a stacked LSTM and a linear head.

        Each image is read as a sequence: ``forward`` moves the last (width) axis
        to the front, so every time step is one column of pixels.

        Args:
            in_feature: size of each sequence element (28 for MNIST: one column).
            hidden_feature: LSTM hidden size.
            num_class: number of output classes.
            num_layers: number of stacked LSTM layers (2 by default).
        """

        def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
            super(rnn_classify, self).__init__()
            # Default nn.LSTM layout (batch_first=False): input is (seq, batch, feature).
            self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)
            # Maps the output of the last time step to one score per class.
            self.classifier = nn.Linear(hidden_feature, num_class)

        def forward(self, x):
            """Return class scores of shape (batch, num_class).

            Args:
                x: image batch of shape (batch, 1, 28, 28).
            """
            # Drop only the channel axis. A bare squeeze() would also drop the
            # batch axis when batch == 1, and the 3-axis permute below would fail.
            x = x.squeeze(1)  # (batch, 28, 28)
            x = x.permute(2, 0, 1)  # (28, batch, 28): last axis becomes the sequence axis
            out, _ = self.rnn(x)  # default initial state; out is (28, batch, hidden_feature)
            out = out[-1]  # last time step: (batch, hidden_feature)
            return self.classifier(out)  # (batch, num_class)

# 定义网络、损失函数、优化器

    # Loss: takes the raw (un-softmaxed) scores produced by the classifier head.
    criterion = nn.CrossEntropyLoss()
    net = rnn_classify()
    # NOTE: `optimzier` (sic) — spelling kept so existing references to this name keep working.
    optimzier = torch.optim.Adadelta(params=net.parameters(), lr=1e-1)

# 定义训练函数


    
    
    def get_acc(output, label):
        """Return the fraction of rows of ``output`` whose arg-max equals ``label``.

        Args:
            output: score tensor; axis 0 is the batch, axis 1 holds one score per class.
            label: tensor of target class indices, one per row of ``output``.

        Returns:
            Accuracy of the batch as a Python float.
        """
        predicted = output.max(dim=1)[1]
        hits = int((predicted == label).sum())
        batch_size = output.size(0)
        return hits / batch_size
    
    
    def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
        """Train ``net`` and print loss/accuracy after every epoch.

        Args:
            net: the model; moved to the GPU when CUDA is available.
            train_data: iterable (with ``len``) of (images, labels) batches used
                for optimisation.
            valid_data: iterable (with ``len``) of (images, labels) batches
                evaluated after each epoch, or None to skip validation.
            num_epochs: number of passes over ``train_data``.
            optimizer: optimizer over ``net``'s parameters.
            criterion: loss called as ``criterion(output, label)``.

        Reported losses/accuracies are means of the per-batch values. Each printed
        line ends with the wall-clock time elapsed since the previous line.
        """
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            net = net.cuda()
        prev_time = datetime.now()
        for epoch in range(num_epochs):
            train_loss = 0
            train_acc = 0
            net = net.train()
            for im, label in train_data:
                if use_cuda:
                    im = im.cuda()
                    label = label.cuda()
                # forward
                output = net(im)
                loss = criterion(output, label)
                # backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                train_acc += get_acc(output, label)

            cur_time = datetime.now()
            # total_seconds() keeps whole days; timedelta.seconds wraps after 24 h.
            h, remainder = divmod(int((cur_time - prev_time).total_seconds()), 3600)
            m, s = divmod(remainder, 60)
            time_str = "Time %02d:%02d:%02d" % (h, m, s)
            if valid_data is not None:
                valid_loss = 0
                valid_acc = 0
                net = net.eval()
                # Evaluation only: disable gradient tracking so no graph is built.
                with torch.no_grad():
                    for im, label in valid_data:
                        if use_cuda:
                            im = im.cuda()
                            label = label.cuda()
                        output = net(im)
                        loss = criterion(output, label)
                        valid_loss += loss.item()
                        valid_acc += get_acc(output, label)
                epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
            else:
                epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                             (epoch, train_loss / len(train_data),
                              train_acc / len(train_data)))
            prev_time = cur_time
            print(epoch_str + time_str)

注:参考廖星宇大神的《深度学习入门之PyTorch》

相关文章: