【问题标题】:what is the difference of torch.nn.Softmax, torch.nn.funtional.softmax, torch.softmax and torch.nn.functional.log_softmaxtorch.nn.Softmax、tor​​ch.nn.Functional.softmax、tor​​ch.softmax 和 torch.nn.functional.log_softmax 有什么区别
【发布时间】:2021-11-11 23:17:43
【问题描述】:

我试图查找文档,但找不到有关 torch.softmax 的任何内容。

torch.nn.Softmax、tor​​ch.nn.functional.softmax、tor​​ch.softmax和torch.nn.functional.log_softmax有什么区别?

欢迎提供示例。

【问题讨论】:

    标签: python pytorch torch softmax


    【解决方案1】:
    import torch
    
    x = torch.rand(5)
    
    x1 = torch.nn.Softmax()(x)
    x2 = torch.nn.functional.softmax(x)
    x3 = torch.nn.functional.log_softmax(x)
    
    print(x1)
    print(x2)
    print(torch.log(x1))
    print(x3)
    
    tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
    tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
    tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
    tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
    

    torch.nn.Softmaxtorch.nn.functional.softmax 给出相同的输出,一个是类(pytorch 模块),另一个是函数。 log_softmax 应用 softmax 后应用 log。

    NLLLoss 以对数概率 (log(softmax(x))) 作为输入。因此,NLLLoss 需要 log_softmax,log_softmax 在数值上更稳定,通常会产生更好的结果。

    【讨论】:

      【解决方案2】:
      
      import torch
      import torch.nn as nn
      
      
      class Network(nn.Module):
          def __init__(self):
              super().__init__()
              self.layer_1 = nn.LazyLinear(128)
              self.activation = nn.ReLU()
              self.layer_2 = nn.Linear(128, 10)
              self.output_function = nn.Softmax(dim=1)
      
          def forward(self, x, softmax="module"):
              y = self.layer_1(x)
              y = self.activation(y)
              y = self.layer_2(y)
              if softmax == "module":
                  return self.output_function(y)
      
              # OR
              if softmax == "torch":
                  return torch.softmax(y, dim=1)
      
              # OR (deprecated)
              if softmax == "functional":
                  return nn.functional.softmax(y, dim=1)
      
              # OR (careful, the reason why the log is there is to ensure
              # numerical stability so you should use torch.exp wisely)
              if softmax == "log":
                  return torch.exp(torch.log_softmax(y, dim=1))
      
              raise ValueError(f"Unknown softmax type {softmax}")
      
      
      x = torch.rand(2, 2)
      net = Network()
      
      for s in ["module", "torch", "log"]:
          print(net(x, softmax=s))
      

      基本上nn.Softmax() 创建一个模块,所以它返回一个函数,而其他都是纯函数。

      为什么需要 log softmax? nn.Softmax 的文档中有一个例子:

      此模块不能直接与 NLLLoss 一起使用, 它期望在 Softmax 和自身之间计算 Log。 请改用LogSoftmax(它更快并且具有更好的数值属性)。

      另见What is the difference between log_softmax and softmax?

      【讨论】:

        猜你喜欢
        • 2018-10-12
        • 2015-10-15
        • 2015-04-22
        • 1970-01-01
        • 1970-01-01
        • 2014-09-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多