centerloss,顾名思义,中心损失函数,它的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离,centerloss只是一个辅助损失函数,softmaxloss才是主打,但softmaxloss只能简单的将类分开,还得加上centerloss这一个强力辅助才能保证特征之间不仅具有可分性,同时也具有可判别性。
我们都知道,对于分类来说,希望类内距小,类间距大,那centerloss+softmaxloss就有这种功能。
简单复习一下softmax函数:
关于softmax的这个函数,有一些基本特性:是归一化指数函数,本质是离散概率分布,常用于多分类,值域为[0,1],输出结果之和为1。
那接着就来看看softmaxloss这个损失函数:
其中Sj为sigmoid输出的值,yj为标签对应独热编码的值(0或者1)
因此softmaxloss可以化简为:
log函数大家都知道,是一个定义域为[0,+∞],值域在[-∞,∞]的增函数,那么softmaxloss定义域在[0,1],取log就是在[-∞,1],那么取-log整个函数最终就变成了定义域在[0,1],值域在[0,+∞]的减函数,并且过(1,0)这个点,这一点正好符合我们梯度下降(当概率为1,损失下降到0),因此我们就可以使用softmaxloss来一步步降低分类的损失。
关于cneterloss,可以先看看公式:
N表示mini-batch的大小,xi表示输出特征,C表示对应的i个类中心,因此centerloss就是希望一个batch中的每个样本的feature离feature 的中心的距离的平方和要越小越好,也就是类内距离要越小越好。
反向传播:
α是学习率,也就是步长,设置一般取值0.5。
这里有一个问题就是centerloss学习率取值为0.5,那如果用同一个优化器进行优化,必然会造成softmaxloss的梯度爆炸,导致整个模型崩溃。
因此这里我们想到用两个优化器进行优化,分别优化centerloss和softmaxloss,这一点可以在代码里看到。
两个损失函数共同作用,softmaxloss负责大致分开各数据,centerloss使类内距越来越小,各司其职,达到把特征区分到最佳效果。其中λ是一个超参数,表示训练时更加倾向于哪个的损失,我在训练时候,λ选择2。
下面看看训练的效果吧:
我只训练了39轮,其实还是能看出来效果还是挺好的。
大概提一下训练过程中的坑吧,因为这些中心点是随机的,有可能随机到的中心点不好,数据点久久不能分开,建议中止训练重新开始或者直接删除参数重新训练。还有就是λ的值对结果影响挺大的,小心调参。
代码:
import torch
import torch.nn as nn
class CLNet(nn.Module):
def __init__(self):
super().__init__()
self.center = nn.Parameter(torch.randn(100, 2), requires_grad=True) # (10, 2)
def forward(self, feature, label, lambdas=2):
center_exp = self.center.index_select(dim=0, index=label.long()) # (100, 2)
count = torch.histc(label, bins=int(max(label).item() + 1), min=int(min(label).item()), max=int(max(label).item())) # (10,)
count_exp = count.index_select(dim=0, index=label.long()) # (100,)
loss = lambdas / 2 * torch.mean(torch.div(torch.sum(torch.pow(feature - center_exp, 2), dim=1), count_exp))
return loss
import torch
from Net_Model import Net
from centerloss import CLNet
import torch.nn as nn
from torchvision import transforms, datasets
import os
class Trainer:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.s_net = Net().to(self.device)
self.c_net = CLNet().to(self.device)
self.s_save_path = "models/softmax_net.pth"
self.c_save_path = "models/center_net.pth"
self.nll_loss = nn.NLLLoss()
self.s_optimizer = torch.optim.SGD(self.s_net.parameters(), lr=0.0005, momentum=0.9, weight_decay=0.0005)
self.c_optimizer = torch.optim.SGD(self.c_net.parameters(), lr=0.5)
self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.s_optimizer, gamma=0.95, last_epoch=-1)
self.mean, self.std = self.mean_std()
self.dataLoader = self.data_loader()
def mean_std(self):
sets = datasets.MNIST("./MNIST", train=True, download=False, transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(sets, batch_size=len(sets), shuffle=True)
data = next(iter(loader))[0]
mean = round(torch.mean(data, dim=(0, 2, 3)).item(), 3)
std = round(torch.std(data, dim=(0, 2, 3)).item(), 3)
return mean, std
def data_loader(self):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((self.mean,), (self.std,))
])
dataSet = datasets.MNIST("./MNIST", train=True, download=False, transform=transform)
dataLoader = torch.utils.data.DataLoader(dataSet, batch_size=100, shuffle=True, num_workers=4)
return dataLoader
def train_test(self):
if os.path.exists(self.s_save_path) and os.path.exists(self.c_save_path):
self.s_net.load_state_dict(torch.load(self.s_save_path))
self.c_net.load_state_dict(torch.load(self.c_save_path))
else:
print("NO Param")
epoch = 0
while True:
feature_loader = []
label_loader = []
for i, (x, y) in enumerate(self.dataLoader):
x = x.to(self.device)
y = y.to(self.device)
feature, output = self.s_net(x)
nll_loss = self.nll_loss(output, y)
y = y.float()
center_loss = self.c_net(feature, y, 2)
loss = nll_loss + center_loss
self.s_optimizer.zero_grad()
self.c_optimizer.zero_grad()
loss.backward()
self.s_optimizer.step()
self.c_optimizer.step()
feature_loader.append(feature)
label_loader.append(y)
if i % 100 == 0:
print("epoch:", epoch, "i:", i, "loss:", loss.item(), "softmax_loss:", nll_loss.item(),
"center_loss:", center_loss.item())
features = torch.cat(feature_loader, dim=0)
labels = torch.cat(label_loader, dim=0)
self.s_net.visualize(features.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
torch.save(self.s_net.state_dict(), self.s_save_path)
torch.save(self.c_net.state_dict(), self.c_save_path)
self.scheduler.step(None)
epoch += 1
if epoch == 40:
break
if __name__ == '__main__':
Trainer=Trainer()
Trainer.train_test()