前言:
class DaGMM(nn.Module):
    """Encoder fragment of a Deep Autoencoding Gaussian Mixture Model.

    NOTE(review): the original docstring said "Residual Block(残块)",
    which was copied from unrelated code — this class builds a plain
    fully-connected encoder, not a residual block.

    Args:
        n_gmm: number of GMM components (not used in this fragment;
            presumably consumed by training code outside this snippet).
        latent_dim: latent dimensionality (not used in this fragment).
    """

    def __init__(self, n_gmm=2, latent_dim=3):  # n_gmm = gmm_k = 4 in the referenced config
        super(DaGMM, self).__init__()
        layers = []
        # 118 -> 60 -> 30 -> 10 -> 1 with Tanh between the Linear layers.
        # Activation functions do not change the tensor shape.
        layers += [nn.Linear(118, 60)]
        layers += [nn.Tanh()]
        layers += [nn.Linear(60, 30)]
        layers += [nn.Tanh()]
        layers += [nn.Linear(30, 10)]
        layers += [nn.Tanh()]
        layers += [nn.Linear(10, 1)]
        # Sequential chains the layers in declaration order.
        self.encoder = nn.Sequential(*layers)
1、Pytorch实现简单的自动编码器autoencoder
自动编码器包括Encoder和Decoder两部分,Encoder和Decoder都可以是任意的模型,目前神经网络模型用得较多。输入的数据经过Encoder网络降维得到一个编码(code),再通过Decoder网络解码,生成一个与原始输入数据shape完全相同的重构数据;通过最小化这两个数据之间的差异来训练Encoder和Decoder的参数。训练完成后,单独取出解码器,随机传入一个编码,解码器就能生成一个和原始数据相似的数据。
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

# Hyperparameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_MNIST = False  # set to True on the first run to download MNIST
N_TEST_IMG = 5          # number of sample images shown during training

# MNIST training set, stored under ./mnist/
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),  # HxW uint8 -> [0,1] float tensor
    download=DOWNLOAD_MNIST,
)

# Inspect one sample:
# print(train_data.train_data.size())
# print(train_data.train_labels.size())
# plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
# plt.title('%i' % train_data.train_labels[2])
# plt.show()

# DataLoader for shuffled mini-batch training
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    """Symmetric fully-connected autoencoder for 28x28 MNIST images.

    Encoder: 784 -> 128 -> 64 -> 12 -> 3 (Tanh activations).
    Decoder: 3 -> 12 -> 64 -> 128 -> 784, ending in Sigmoid so the
    reconstruction matches the [0, 1] pixel range produced by ToTensor().
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # 3-D code, convenient for 3-D visualization
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # squash outputs back to [0, 1]
        )

    def forward(self, x):
        """Return (encoded, decoded) for a flattened image batch ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure: row 0 = originals, row 1 = live reconstructions
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()  # interactive mode so the figure updates during training

# original data (first row) for viewing
# NOTE(review): .train_data / .train_labels are deprecated aliases of
# .data / .targets in recent torchvision — kept here for compatibility.
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
# Show the original images on the first row of the figure.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(Variable(view_data).data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())


# Training loop: minimize the MSE between the input image and its
# reconstruction (the input itself is the regression target).
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x.view(-1, 28 * 28))  # input: flattened image batch
        b_y = Variable(x.view(-1, 28 * 28))  # target: the same image
        b_label = Variable(y)                # labels (not used by the loss)

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)  # reconstruction MSE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())

            # plotting decoded image (second row)
            _, decoded_data = autoencoder(Variable(view_data))
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# Visualize the 3-D latent codes of the first 200 training images.
view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
encoded_data, _ = autoencoder(Variable(view_data))
fig = plt.figure(2)
# NOTE(review): Axes3D(fig) no longer auto-attaches the axes in
# matplotlib >= 3.4; use fig.add_subplot(projection='3d') there.
ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()

values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # color keyed to the digit label 0-9
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()
参考:
参考1:PyTorch实现简单的自动编码器autoencoder
参考2:Pytorch学习(13)-自编码(AutoEncoder)