【问题标题】:Python OOP issue: How a function in a class is getting input value, when it is not mentioned in __init__() function as a parameter?Python OOP 问题:当 __init__() 函数中未将其作为参数提及时,类中的函数如何获取输入值?
【发布时间】:2021-08-30 19:10:43
【问题描述】:

我对面向对象编程比较陌生,目前正在从事生成对抗网络项目。我遇到了 maxout 激活函数。该函数是通过 maxout 类定义的。请看下面的代码:

class maxout(torch.nn.Module):
    def __init__(self, num_pieces):
        
        super(maxout, self).__init__()
        
        self.num_pieces = num_pieces
        
    def forward(self, x):
        
        assert x.shape[1] % self.num_pieces == 0  # 625 % 5 = 0
        
        ret = x.view(*x.shape[:1], # batch_size
                      x.shape[1] // self.num_pieces, 
                      self.num_pieces, # num_pieces
                      *x.shape[2:]    )
                     
        ret, _ = ret.max(dim=2)
        
        return ret

这个 maxout 函数后来在鉴别器类中使用。以下是判别器类的代码。

class discriminator(torch.nn.Module):
    
    def __init__(self):
        
        super(discriminator, self).__init__()
        
        self.fcn = torch.nn.Sequential(
            # Fully connected layer 1
            torch.nn.Linear(
                in_features = 784,
                out_features=240,
                bias = True
            ),
            maxout(5),
            
            # Fully connected layer 2
            torch.nn.Linear(
                in_features = 48,
                out_features=1,
                bias = True
            )  )
        
    def forward(self, batch):
        inputs = batch.view(batch.size(0), -1)
        outputs = self.fcn(inputs)
        outputs = outputs.mean(0)
        return outputs.view(1)  # it will return a single value

代码运行良好,但根据我对面向对象编程的天真理解,maxout 类的 forward() 函数中的“x”值应通过 init() 函数提供。

我的问题是:maxout 类的 forward() 函数如何接收输入“x”,而不通过 init() 函数获取输入。

这个问题的另一种说法是:判别器类中线性层的输出如何作为'x'传递给maxout函数?

【问题讨论】:

  • "代码运行良好,但根据我对面向对象编程的天真理解,maxout 类的 forward() 函数中的 'x' 值应该通过 init() 函数提供。不,它是forward 方法的参数
  • discriminator.forward 具有相同的参数;它被命名为batch 而不是x,但使用方式完全相同。

标签: python oop object-oriented-analysis


【解决方案1】:

您将层传递给分配给self.fcnSequential 模型的构造函数,而不是discriminator.forward 中的您调用此模型。它是__call__ 方法,而不是调用它所包含的层的所有前向函数。

你可以想象这样的事情正在发生

...
def forward(self, batch):
   return torch.nn.Linear(
       in_features = 48,
       out_features=1,
       bias = True
   ).forward(
       maxout(5).forward(
           torch.nn.Linear(
                in_features = 784,
                out_features=240,
                bias = True
           ).forward(batch.view(batch.size(0), -1)
       )
    ).mean(0).view(1)

【讨论】:

  • 嗨,我明白了。
  • 关于 discriminator.forward 有点混乱。 call 方法是否会在 discriminator.forward 被调用时自动被调用(因为我没有明确提到 call 方法)。如果是,那么它背后的理论/程序是什么。
【解决方案2】:

方法与常规函数的不同之处仅在于它们的第一个参数是定义方法的类的对象,并且第一个参数可以用点运算符提供。 只有程序员决定通过__init__ 传递哪些值以将它们分配给对象属性,以及在第一个self 参数之后作为参数传递哪些值。不需要首先通过__init__ 方法将我们提供给其他方法的所有值传递给其他方法。

回答您的问题: torch.nn.Sequential 保存您传递的函数列表,在本例中,它的长度为 3,并输出一些可调用对象(可以像函数一样工作的对象)。你把它保存到self.fcn。然后,当您调用self.fcn 时,它会调用第一个torch.nn.Linearforward 方法,然后将其返回值传递给maxout(5) 对象的forward 方法——这就是x 的位置! - 然后将maxout(5)对象的输出传递给第二个torch.nn.Linearforward方法并返回结果。

【讨论】:

  • 嗨,我很清楚这个过程。如果能回答我心中的以下困惑,那就太好了。关于 discriminator.forward 有点混乱。调用 discriminator.forward 时是否会自动调用 call 方法(因为我没有明确提到 call 方法)。如果是,那么它背后的理论/程序是什么。
猜你喜欢
  • 1970-01-01
  • 2023-03-18
  • 1970-01-01
  • 2021-07-22
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多