【发布时间】:2021-08-30 19:10:43
【问题描述】:
我对面向对象编程比较陌生,目前正在从事生成对抗网络项目。我遇到了 maxout 激活函数。该函数是通过 maxout 类定义的。请看下面的代码:
class maxout(torch.nn.Module):
def __init__(self, num_pieces):
super(maxout, self).__init__()
self.num_pieces = num_pieces
def forward(self, x):
assert x.shape[1] % self.num_pieces == 0 # 625 % 5 = 0
ret = x.view(*x.shape[:1], # batch_size
x.shape[1] // self.num_pieces,
self.num_pieces, # num_pieces
*x.shape[2:] )
ret, _ = ret.max(dim=2)
return ret
这个 maxout 函数后来在鉴别器类中使用。以下是判别器类的代码。
class discriminator(torch.nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.fcn = torch.nn.Sequential(
# Fully connected layer 1
torch.nn.Linear(
in_features = 784,
out_features=240,
bias = True
),
maxout(5),
# Fully connected layer 2
torch.nn.Linear(
in_features = 48,
out_features=1,
bias = True
) )
def forward(self, batch):
inputs = batch.view(batch.size(0), -1)
outputs = self.fcn(inputs)
outputs = outputs.mean(0)
return outputs.view(1) # it will return a single value
代码运行良好,但根据我对面向对象编程的天真理解,maxout 类的 forward() 函数中的“x”值应通过 init() 函数提供。
我的问题是:maxout 类的 forward() 函数如何接收输入“x”,而不通过 init() 函数获取输入。
这个问题的另一种说法是:判别器类中线性层的输出如何作为'x'传递给maxout函数?
【问题讨论】:
-
"代码运行良好,但根据我对面向对象编程的天真理解,maxout 类的 forward() 函数中的 'x' 值应该通过 init() 函数提供。不,它是
forward方法的参数。 -
discriminator.forward具有相同的参数;它被命名为batch而不是x,但使用方式完全相同。
标签: python oop object-oriented-analysis