Reference: https://pytorch.org/docs/stable/nn.html

Containers

Module

CLASS torch.nn.Module

Base class for all neural network modules.

Any model you define must be a subclass of this class, i.e. it must inherit from it.

Modules can also contain other modules, allowing them to be nested in a tree structure. You can assign submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

In this example, the nn.Conv2d(20, 20, 5) assigned to self.conv2 is a submodule.

Submodules assigned in this way are registered, and their parameters will also be converted when you call functions such as to().
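A minimal sketch of what this registration buys you: because the two convolutions above are assigned as attributes, they automatically show up in the module's children and parameter iterators.

```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)   # assigned as attributes, so both
        self.conv2 = nn.Conv2d(20, 20, 5)  # are registered as submodules

model = Model()
# The registration makes the submodules (and their parameters) visible
# to the Module machinery:
print([name for name, _ in model.named_children()])   # ['conv1', 'conv2']
print(sum(p.numel() for p in model.parameters()))     # 10540
```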

 

Methods:

1)

cpu()

Moves all of the model's parameters and buffers to the CPU.
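A quick sketch; cpu() modifies the module in place and also returns the module itself, so calls can be chained.

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)
linear.cpu()                 # moves parameters/buffers in place, returns self
print(linear.weight.device)  # cpu
```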

 

2)

cuda(device=None)

Moves all of the model's parameters and buffers to the GPU. Because this may make the associated parameters and buffers different objects, if the module will live on the GPU while being optimized, this method must be called before constructing the optimizer.

Parameters:

device (int, optional) – if specified, all parameters will be copied to that device
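A sketch of the ordering constraint described above; the `is_available()` guard is only there so the snippet also runs on a machine without a GPU.

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)
if torch.cuda.is_available():       # guard: only move when a GPU exists
    linear.cuda()                   # equivalent to linear.to(torch.device("cuda"))
    # Construct the optimizer only AFTER the move, so that it references
    # the CUDA parameter objects rather than the old CPU ones:
    optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)
    print(linear.weight.device)     # cuda:0
```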

 

3)

double()

Casts all floating point parameters and buffers to double (float64).

 

float()

Casts all floating point parameters and buffers to float (float32).

 

half()

Casts all floating point parameters and buffers to half (float16).

Example:

import torch
from torch import nn
linear = nn.Linear(2, 2)
print(linear.weight)

linear.double()
print(linear.weight)

Output:

Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], requires_grad=True)
Parameter containing:
tensor([[-0.0890,  0.2313],
        [-0.4812, -0.3514]], dtype=torch.float64, requires_grad=True)

 

4)

type(dst_type)

Casts all parameters and buffers to the given dst_type.

Parameters:

dst_type (type or string) – the desired type

Example (note: this example actually demonstrates the analogous torch.Tensor.type() on a plain tensor; a Module example is given in 6) below):

import torch
input = torch.FloatTensor([-0.8728,  0.3632, -0.0547])
print(input)
print(input.type(torch.double))

Output:

tensor([-0.8728,  0.3632, -0.0547])
tensor([-0.8728,  0.3632, -0.0547], dtype=torch.float64)

 

5)

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

It can be called in three ways:

  • to(device=None, dtype=None, non_blocking=False)
  • to(dtype, non_blocking=False)
  • to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but it only accepts floating point dtypes, i.e. dtype can only be set to a floating point type such as float, double, or half. If given, this method will only cast the floating point parameters and buffers to the specified dtype. The integral parameters and buffers will be moved to the given device with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g. moving a CPU tensor with pinned memory to a CUDA device.

Parameters:

  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point type of the floating point parameters and buffers in this module

  • tensor (torch.Tensor) – a tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

Example:

import torch
from torch import nn
linear = nn.Linear(2, 2)
print(linear.weight)

linear.to(torch.double)
print(linear.weight)

gpu1 = torch.device("cuda:0")
linear.to(gpu1, dtype=torch.half, non_blocking=True)
print(linear.weight)

cpu = torch.device("cpu")
linear.to(cpu)
print(linear.weight)

Output:

Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5912]], dtype=torch.float64, requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], device='cuda:0', dtype=torch.float16,
       requires_grad=True)
Parameter containing:
tensor([[0.4604, 0.5215],
        [0.5981, 0.5913]], dtype=torch.float16, requires_grad=True)

 

6)

type(dst_type)

Casts all parameters and buffers to dst_type; this time type() is called on the Module itself.

Parameters:

dst_type (type or string) – the desired type

Example:

import torch
from torch import nn
linear = nn.Linear(2, 2)
print(linear.weight)

linear.type(torch.double)
print(linear.weight)

Output:

Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], requires_grad=True)
Parameter containing:
tensor([[ 0.4370, -0.6806],
        [-0.4628, -0.4366]], dtype=torch.float64, requires_grad=True)

 

7)

forward(*input)

Defines the computation performed at every call.
Should be overridden by all subclasses.

⚠️ Although the recipe for the forward pass needs to be defined within this function, you should call the Module instance afterwards instead of calling forward() directly, since the former takes care of running the registered hooks while the latter silently ignores them.

This is exactly the function we define when writing a module:

def forward(self, x):

When you call the model, this function is invoked:

import torchvision.models as models

alexnet = models.alexnet()
output  = alexnet(input_data)  # this call invokes forward()
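The warning above can be demonstrated directly: a forward hook registered on a module fires when the instance is called, but not when forward() is called by hand.

```python
import torch
from torch import nn

calls = []

def hook(module, inputs, output):
    # forward hooks run after forward() whenever the module is CALLED
    calls.append("hook ran")

linear = nn.Linear(2, 2)
linear.register_forward_hook(hook)

x = torch.randn(1, 2)
linear(x)          # __call__ runs forward() AND the registered hook
linear.forward(x)  # computes the same result but silently skips the hook
print(calls)       # ['hook ran']  (only the first call triggered it)
```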

 

8)

apply(fn)

Applies the function fn recursively to every submodule (i.e. the modules returned by .children()) as well as self. Typical use is initializing the parameters of a model (see also torch.nn.init).

Parameters:

  • fn (Module -> None): function to be applied to each submodule

Returns:

  • self

Return type:

  • Module

Example: see the separate post on initializing model parameters in PyTorch.
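A minimal initialization sketch in the spirit of that post; the @torch.no_grad() decorator is needed because fill_() edits leaf parameters in place.

```python
import torch
from torch import nn

@torch.no_grad()   # in-place edits of leaf parameters must not be traced
def init_weights(m):
    # fn is called on every submodule returned by .children(), then on self
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)   # returns net itself, so the call can be chained
print(net[0].weight)      # all ones
```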

 

9)

named_parameters(prefix='', recurse=True)

Returns an iterator over the module's parameters, yielding both the name of the parameter and the parameter itself.

Parameters:

  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, also yields the parameters of this module's submodules, recursively; if False, yields only parameters that are direct members of this module.
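A small sketch of both parameters on a single Linear layer:

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)
# prefix= is prepended to every name; with recurse=True (the default),
# parameters of submodules would be included as well.
for name, param in linear.named_parameters(prefix="fc"):
    print(name, tuple(param.shape))
# fc.weight (2, 2)
# fc.bias (2,)
```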

An earlier example already used this method; as the returned names show, a parameter's value can then be fetched directly by name:

e.models.Conv2_3_64.weight.data

Output (truncated):

tensor([[[[ 1.8686e-02, -1.1276e-02,  1.0743e-02, -3.7258e-03],
          [ 1.7356e-02, -4.6002e-03, -1.5800e-02,  1.4272e-03],
          [-8.9406e-03,  2.8417e-02,  7.3844e-03, -2.0131e-02],
          [ 2.7378e-02, -1.3940e-02, -9.2417e-03, -1.3656e-02]],

         [[-2.6638e-02,  2.6307e-02, -2.9532e-02,  2.6932e-02],
          [-7.9886e-03,  3.4983e-03, -5.5121e-02,  1.8271e-02],
          [-4.3825e-02,  4.7733e-02, -3.5117e-02, -1.0677e-02],
          [-2.6437e-02, -4.5605e-03,  1.1901e-02, -1.9924e-02]],

         [[ 1.2108e-02, -2.0034e-02, -4.3065e-02, -4.4073e-03],
          [ 2.4294e-02,  2.0997e-04,  2.0511e-02,  4.0354e-02],
          [-7.4128e-03,  1.2180e-02,  2.1586e-02, -3.2092e-02],
          [-1.0036e-02, -1.3512e-02,  2.8016e-03,  1.7150e-02]]],

        ...]])
