Reference: https://pytorch.org/docs/stable/nn.html
Containers
Module
CLASS torch.nn.Module
Base class for all neural network modules.
Any model you define must be a subclass of this class, i.e. inherit from it.
Modules can also contain other modules, allowing them to be nested in a tree structure. You can assign submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
In this example, nn.Conv2d(20, 20, 5) is itself a submodule.
Submodules assigned in this way are registered, and their parameters will also be converted when you call methods such as to().
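To see that this registration actually happens, one can enumerate the submodules and parameters of the Model class defined above (a minimal sketch):

```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

m = Model()
# Submodules assigned as attributes show up in named_children()
print([name for name, _ in m.named_children()])  # ['conv1', 'conv2']
# ...and their parameters are registered too (weight + bias for each conv)
print(len(list(m.parameters())))                 # 4
```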
Methods:
1)
cpu()
Moves all model parameters and buffers to the CPU.
2)
cuda(device=None)
Moves all model parameters and buffers to the GPU. Because this makes the associated parameters and buffers different objects, this method must be called before constructing the optimizer if the module will live on the GPU while being optimized.
Parameters:
device (int, optional) – if specified, all parameters will be copied to that device
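The ordering constraint above can be sketched as follows (the call is guarded so the sketch also runs without a GPU; the model and learning rate are illustrative):

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

# Move the parameters FIRST (guarded, since a GPU may not be present)...
if torch.cuda.is_available():
    model.cuda()

# ...and only then build the optimizer, so it holds references to the
# final (possibly CUDA) Parameter objects rather than stale CPU copies.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The optimizer now tracks exactly the model's current parameters
print(optimizer.param_groups[0]['params'][0] is model.weight)  # True
```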
3)
double()
Casts all floating point parameters and buffers to the double datatype.
float()
Casts all floating point parameters and buffers to the float datatype.
half()
Casts all floating point parameters and buffers to the half datatype.
Example:
import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)
linear.double()
print(linear.weight)
Output:
Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], requires_grad=True)
Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], dtype=torch.float64, requires_grad=True)
4)
type(dst_type)
Casts all parameters and buffers to the given dst_type.
Parameters:
dst_type (type or string) – the desired type
Example (shown here on a plain tensor via torch.Tensor.type, which behaves analogously):
import torch

input = torch.FloatTensor([-0.8728, 0.3632, -0.0547])
print(input)
print(input.type(torch.double))
Output:
tensor([-0.8728,  0.3632, -0.0547])
tensor([-0.8728,  0.3632, -0.0547], dtype=torch.float64)
5)
to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
This can be called in three forms:
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
Its signature is similar to torch.Tensor.to(), but it only accepts floating point dtypes, i.e. dtype can only be set to a floating point type such as float, double, or half. If given, this method will cast only the floating point parameters and buffers to that dtype. Integer parameters and buffers will be moved to the given device with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g. moving a CPU tensor with pinned memory to a CUDA device.
Parameters:
- device (torch.device) – the desired device of the parameters and buffers in this module
- dtype (torch.dtype) – the desired floating point type the floating point parameters and buffers in this module should be cast to
- tensor (torch.Tensor) – a tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
Example:
import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)
linear.to(torch.double)
print(linear.weight)
gpu1 = torch.device("cuda:0")
linear.to(gpu1, dtype=torch.half, non_blocking=True)
print(linear.weight)
cpu = torch.device("cpu")
linear.to(cpu)
print(linear.weight)
Output:
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], device='cuda:0', dtype=torch.float16, requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], dtype=torch.float16, requires_grad=True)
6)
type(dst_type)
Casts all parameters and buffers to dst_type. (This is the same type() method as in 4); the example below applies it to a module rather than a plain tensor.)
Parameters:
dst_type (type or string) – the desired type
Example:
import torch
from torch import nn

linear = nn.Linear(2, 2)
print(linear.weight)
linear.type(torch.double)
print(linear.weight)
Output:
Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], requires_grad=True)
Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], dtype=torch.float64, requires_grad=True)
7)
forward(*input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
⚠️ Although the recipe for the forward pass needs to be defined within this function, you should call the module instance afterwards instead of calling forward() directly, since the former takes care of running the registered hooks while the latter silently ignores them.
This function is exactly the one we define when writing a module:
def forward(self, x):
When you call the model, this function is invoked:
import torchvision.models as models

alexnet = models.alexnet()
output = alexnet(input_data)  # this call invokes forward()
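The distinction between calling the instance and calling forward() directly becomes visible once a hook is registered (a minimal sketch; the module and input here are illustrative):

```python
import torch
from torch import nn

calls = []
linear = nn.Linear(2, 2)
# A forward hook fires after forward() whenever the module itself is called
linear.register_forward_hook(lambda mod, inp, out: calls.append('hook'))

x = torch.randn(1, 2)
linear(x)           # runs forward() AND the registered hook
linear.forward(x)   # runs forward() only; the hook is silently skipped
print(calls)        # ['hook']
```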
8)
apply(fn)
Applies the function fn recursively to every submodule (i.e. the modules returned by .children()) as well as to the module itself. Typical use is initializing the parameters of a model (see torch.nn.init).
Parameters:
- fn (Module -> None): function to be applied to each submodule
Returns:
- self
Return type:
- Module
Example: see initializing model parameters in PyTorch
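A minimal sketch of using apply() for parameter initialization (the init_weights function, the network, and the constant values are illustrative, not from the source):

```python
import torch
from torch import nn

def init_weights(m):
    # fn receives every submodule (and, last, the root module itself)
    if isinstance(m, nn.Linear):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)  # recurses over children, then net itself
print(net[0].weight)
```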
9)
named_parameters(prefix='', recurse=True)
Returns an iterator over the model's parameters, yielding both the name of each parameter and the parameter itself.
Parameters:
- prefix (str) – prefix to prepend to all parameter names
- recurse (bool) – if True, also yields the parameters of all submodules; if False, only direct members of this module
The example above already uses this; as the returned result shows, we can fetch a parameter's value directly by its name:
e.models.Conv2_3_64.weight.data
Output (the Conv2d weight is a 4-D tensor of shape [out_channels, in_channels, kH, kW]; the printout is abridged here to the first filter):
tensor([[[[ 1.8686e-02, -1.1276e-02,  1.0743e-02, -3.7258e-03],
          [ 1.7356e-02, -4.6002e-03, -1.5800e-02,  1.4272e-03],
          [-8.9406e-03,  2.8417e-02,  7.3844e-03, -2.0131e-02],
          [ 2.7378e-02, -1.3940e-02, -9.2417e-03, -1.3656e-02]],
         ...]])
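A minimal sketch of iterating named_parameters() (the model here is illustrative; in an nn.Sequential the names mirror the child indices):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
# Each entry is a (name, Parameter) pair; the name is the attribute path
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# 0.weight (8, 3, 3, 3) / 0.bias (8,) / 1.weight (8, 8, 3, 3) / 1.bias (8,)
```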