wupiao

前言:

class DaGMM(nn.Module):  # 自定义的模型需要继承nn.Module(固定写法)
    """Residual Block(残块)."""

    def __init__(self, n_gmm=2, latent_dim=3):  # n_gmm=gmm_k=4
        super(DaGMM, self).__init__()  # (固定写法)

        layers = []
        layers += [nn.Linear(118, 60)]
        layers += [nn.Tanh()]  # 激活函数并不改变数据的形状
        layers += [nn.Linear(60, 30)]
        layers += [nn.Tanh()]
        layers += [nn.Linear(30, 10)]
        layers += [nn.Tanh()]
        layers += [nn.Linear(10, 1)]
        self.encoder = nn.Sequential(*layers) #使用Sequential类来自定义顺序连接模型

1、Pytorch实现简单的自动编码器autoencoder

 自动编码器包括Encoder和Decoder两部分,Encoder和Decoder都可以是任意的模型,目前神经网络模型用得较多。输入的数据经过神经网络降维到一个编码(coder),然后又通过一个神经网络去解码得到一个与原始输入数据一模一样shape的生成数据,然后通过比较这两个数据,最小化它们之间的差异来训练这个网络中的Encoder和Decoder的参数,当这个过程训练完之后,拿出这个解码器,随机传入一个编码,通过解码器能够生成一个和原始数据差不多的数据。

  1 import torch
  2 import torch.nn as nn
  3 import torch.utils.data as Data
  4 import torchvision
  5 from torch.autograd import Variable
  6 import matplotlib.pyplot as plt
  7 from mpl_toolkits.mplot3d import Axes3D
  8 from matplotlib import cm
  9 import numpy as np
 10  
 11 # 超参数
 12 EPOCH = 10
 13 BATCH_SIZE = 64
 14 LR = 0.005
 15 DOWNLOAD_MNIST = False
 16 N_TEST_IMG = 5
 17  
 18 # 下载MNIST数据
 19 train_data = torchvision.datasets.MNIST(
 20     root=\'./mnist/\',
 21     train=True,
 22     transform=torchvision.transforms.ToTensor(),
 23     download=DOWNLOAD_MNIST,
 24 )
 25  
 26 # 输出一个样本
 27 # print(train_data.train_data.size())
 28 # print(train_data.train_labels.size())
 29 # plt.imshow(train_data.train_data[2].numpy(), cmap=\'gray\')
 30 # plt.title(\'%i\' % train_data.train_labels[2])
 31 # plt.show()
 32  
 33 # Dataloader
 34 train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 35  
 36  
 37 class AutoEncoder(nn.Module):
 38     def __init__(self):
 39         super(AutoEncoder, self).__init__()
 40         self.encoder = nn.Sequential(
 41             nn.Linear(28 * 28, 128),
 42             nn.Tanh(),
 43             nn.Linear(128, 64),
 44             nn.Tanh(),
 45             nn.Linear(64, 12),
 46             nn.Tanh(),
 47             nn.Linear(12, 3),
 48         )
 49  
 50         self.decoder = nn.Sequential(
 51             nn.Linear(3, 12),
 52             nn.Tanh(),
 53             nn.Linear(12, 64),
 54             nn.Tanh(),
 55             nn.Linear(64, 128),
 56             nn.Tanh(),
 57             nn.Linear(128, 28 * 28),
 58             nn.Sigmoid(),
 59         )
 60  
 61     def forward(self, x):
 62         encoded = self.encoder(x)
 63         decoded = self.decoder(encoded)
 64         return encoded, decoded
 65  
 66 autoencoder = AutoEncoder()
 67 optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
 68 loss_func = nn.MSELoss()
 69  
 70 # initialize figure
 71 f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
 72 plt.ion()   # continuously plot
 73  
 74 # original data (first row) for viewing
 75 view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
 76 for i in range(N_TEST_IMG):
 77     a[0][i].imshow(np.reshape(Variable(view_data).data.numpy()[i], (28, 28)), cmap=\'gray\'); a[0][i].set_xticks(()); a[0][i].set_yticks(())
 78  
 79  
 80 for epoch in range(EPOCH):
 81     for step, (x, y) in enumerate(train_loader):
 82         b_x = Variable(x.view(-1, 28 * 28))
 83         b_y = Variable(x.view(-1, 28 * 28))
 84         b_label = Variable(y)
 85  
 86         encoded, decoded = autoencoder(b_x)
 87  
 88         loss = loss_func(decoded, b_y)
 89         optimizer.zero_grad()
 90         loss.backward()
 91         optimizer.step()
 92  
 93         if step % 100 == 0:
 94             print(\'Epoch: \', epoch, \'| train loss: %.4f\' % loss.data.numpy())
 95  
 96             # plotting decoded image (second row)
 97             _, decoded_data = autoencoder(Variable(view_data))
 98             for i in range(N_TEST_IMG):
 99                 a[1][i].clear()
100                 a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap=\'gray\')
101                 a[1][i].set_xticks(());
102                 a[1][i].set_yticks(())
103             plt.draw();
104             plt.pause(0.05)
105  
106 plt.ioff()
107 plt.show()
108  
109 # visualize in 3D plot
110 view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
111 encoded_data, _ = autoencoder(Variable(view_data))
112 fig = plt.figure(2);
113 ax = Axes3D(fig)
114 X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
115  
116 values = train_data.train_labels[:200].numpy()
117 for x, y, z, s in zip(X, Y, Z, values):
118     c = cm.rainbow(int(255 * s / 9));
119     ax.text(x, y, z, s, backgroundcolor=c)
120 ax.set_xlim(X.min(), X.max());
121 ax.set_ylim(Y.min(), Y.max());
122 ax.set_zlim(Z.min(), Z.max())
123 plt.show()

 


 参考:

参考1:PyTorch实现简单的自动编码器autoencoder

参考2:Pytorch学习(13)-自编码(AutoEncoder)


 

分类:

技术点:

相关文章: