【问题标题】:What is self referring to in this PyTorch derived nn.Module class method?这个 PyTorch 派生的 nn.Module 类方法中的 self 指的是什么?
【发布时间】:2021-12-30 18:42:37
【问题描述】:

我正在关注 Pytorch 的 tutorial,在 nn.Module 类的派生类 MnistModule 方法 training_step 中有一行代码对我来说毫无意义。

这条线是 out = self(images)

请有人向我解释一下这里发生了什么?这是否正确,如果这是要遵循的约定。

谢谢

这是sn-p


class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)
        
    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out
    
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        print(type(out))
        return loss

【问题讨论】:

标签: python class pytorch


【解决方案1】:

它引用MnistModel 的实例,与该类定义的任何其他方法相同。唯一奇怪的是self调用,但这可以通过nn.Module 定义__call__ 来解释,所以MnistModel 的所有实例本身都是可调用的。

out = self(images) 等价于out = self.__call__(images)

【讨论】:

    猜你喜欢
    • 2020-08-18
    • 2021-11-07
    • 2013-01-12
    • 2018-08-10
    • 2019-12-06
    • 2012-07-13
    • 1970-01-01
    • 2017-08-28
    • 1970-01-01
    相关资源
    最近更新 更多