Reference: https://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate
torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs.
First, a hand-written learning rate decay function:
def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 2 epochs"""
    lr *= (0.1 ** (epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
The optimizer manages its parameters through param_groups: each param_group stores a set of parameters together with its learning rate, momentum, and other hyperparameters.
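For instance, a minimal sketch (assuming a plain SGD optimizer over a toy linear layer) that inspects what a param_group holds:

import torch
import torch.optim as optim

# a tiny model just to have some parameters
model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# each param_group is a dict holding the parameters plus their hyperparameters
group = optimizer.param_groups[0]
print(sorted(k for k in group if k != 'params'))
# e.g. ['dampening', 'lr', 'momentum', 'nesterov', 'weight_decay']
print(group['lr'], group['momentum'])  # 0.1 0.9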
Usage:
import torch.optim as optim
from torchvision.models import AlexNet
import matplotlib.pyplot as plt

model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=10)
plt.figure()
x = list(range(10))
y = []
lr_init = optimizer.param_groups[0]['lr']
for epoch in range(10):
    adjust_learning_rate(optimizer, epoch, lr_init)
    lr = optimizer.param_groups[0]['lr']
    print(epoch, lr)
    y.append(lr)
plt.plot(x, y)
Output:
0 10.0
1 10.0
2 1.0
3 1.0
4 0.10000000000000002
5 0.10000000000000002
6 0.010000000000000002
7 0.010000000000000002
8 0.0010000000000000002
9 0.0010000000000000002
(Figure: learning rate vs. epoch for the step schedule above, lr divided by 10 every 2 epochs.)
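Incidentally, the same divide-by-10-every-2-epochs schedule is available built in; a minimal sketch using lr_scheduler.StepLR (assuming the same AlexNet setup; note that where scheduler.step() goes relative to optimizer.step() changed in PyTorch 1.1, so exact epoch alignment may differ by version):

import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet

model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=10)
# step_size=2, gamma=0.1 reproduces the manual schedule above
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
for epoch in range(10):
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])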
Before the scheduler examples, import the required libraries:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet
import matplotlib.pyplot as plt
1. LambdaLR
CLASS torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
Sets the learning rate of each parameter group to the initial lr times a given function (a small sketch of this rule follows the parameter list). When last_epoch=-1, sets the initial lr as lr.
Parameters:
- optimizer (Optimizer) – the wrapped optimizer.
- lr_lambda (function or list) – a function that takes an integer argument (typically the epoch index) and computes a multiplicative factor for the learning rate; or a list of such functions, one for each group in optimizer.param_groups. If the list length does not match the number of param groups, construction fails, as the error below shows.
- last_epoch (int) – the index of the last epoch. Default: -1.
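In effect, LambdaLR computes lr = base_lr * lr_lambda(epoch) for each group; a small plain-Python sketch of that rule, assuming base_lr = 0.05 and a 0.95 ** epoch factor:

base_lr = 0.05
lr_lambda = lambda epoch: 0.95 ** epoch

# what LambdaLR would set per epoch: base_lr * lr_lambda(epoch)
for epoch in range(3):
    print(epoch, base_lr * lr_lambda(epoch))
# roughly: 0.05, 0.0475, 0.045125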
For example:
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
lambda1 = lambda epoch: epoch // 10    # multiplicative factor for the lr: epoch // 10
lambda2 = lambda epoch: 0.95 ** epoch  # multiplicative factor for the lr: 0.95 ** epoch
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
This raises:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-c02d2d9ffc0d> in <module>
      4 lambda1 = lambda epoch:epoch // 10
      5 lambda2 = lambda epoch:0.95 ** epoch
----> 6 scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
      7 plt.figure()
      8 x = list(range(40))

/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, lr_lambda, last_epoch)
     83         if len(lr_lambda) != len(optimizer.param_groups):
     84             raise ValueError("Expected {} lr_lambdas, but got {}".format(
---> 85                 len(optimizer.param_groups), len(lr_lambda)))
     86         self.lr_lambdas = list(lr_lambda)
     87         self.last_epoch = last_epoch

ValueError: Expected 1 lr_lambdas, but got 2
This shows that only one lambda function is expected here, because the optimizer was constructed with a single param group.
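With one param group per lambda, the list form does work; a minimal sketch, assuming we split AlexNet into its features and classifier parameters:

model = AlexNet(num_classes=2)
# two param groups, so a list of two lambdas is accepted
optimizer = optim.SGD([
    {'params': model.features.parameters(), 'lr': 0.05},
    {'params': model.classifier.parameters(), 'lr': 0.05},
])
lambda1 = lambda epoch: epoch // 10
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
print(len(optimizer.param_groups))  # 2, matching the two lambdas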
Examples:
1) Using lambda2
model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
# Two candidate lambda functions (only lambda2 is used below):
# with lambda1, for epoch 0-9, epoch // 10 = 0, so lr = 0.05 * 0 = 0;
# for epoch 10-19, epoch // 10 = 1, so lr = 0.05 * 1 = 0.05
lambda1 = lambda epoch: epoch // 10
# with lambda2, at epoch 0, lr = 0.05 * (0.2 ** 0) = 0.05; at epoch 1, lr = 0.05 * (0.2 ** 1) = 0.01
lambda2 = lambda epoch: 0.2 ** epoch
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda2)
plt.figure()
x = list(range(40))
y = []
for epoch in range(40):
    scheduler.step()  # note: since PyTorch 1.1, scheduler.step() should follow optimizer.step()
    print(epoch, scheduler.get_lr()[0])
    y.append(scheduler.get_lr()[0])
plt.plot(x, y)
Output:
0 0.05
1 0.010000000000000002
2 0.0020000000000000005
3 0.00040000000000000013
4 8.000000000000002e-05
5 1.6000000000000006e-05
6 3.2000000000000015e-06
7 6.400000000000002e-07
8 1.2800000000000006e-07
9 2.5600000000000014e-08
10 5.120000000000003e-09
11 1.0240000000000006e-09
12 2.0480000000000014e-10
13 4.096000000000003e-11
14 8.192000000000007e-12
15 1.6384000000000016e-12
16 3.276800000000003e-13
17 6.553600000000007e-14
18 1.3107200000000014e-14
19 2.621440000000003e-15
20 5.242880000000006e-16
21 1.0485760000000013e-16
22 2.0971520000000027e-17
23 4.194304000000006e-18
24 8.388608000000012e-19
25 1.6777216000000025e-19
26 3.355443200000005e-20
27 6.71088640000001e-21
28 1.3421772800000022e-21
29 2.6843545600000045e-22
30 5.368709120000009e-23
31 1.0737418240000018e-23
32 2.147483648000004e-24
33 4.294967296000008e-25
34 8.589934592000016e-26
35 1.7179869184000033e-26
36 3.435973836800007e-27
37 6.871947673600015e-28
38 1.3743895347200028e-28
39 2.748779069440006e-29
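These values are just the closed form 0.05 * 0.2 ** epoch; a quick check:

base_lr, gamma = 0.05, 0.2
for epoch in range(5):
    print(epoch, base_lr * gamma ** epoch)  # matches the schedule printed above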