【发布时间】:2021-03-06 20:15:04
【问题描述】:
在 PyTorch 中定义模型架构时,我们需要指定 CNN 输出(展平后)的特征数量,才能把它传给 `nn.Linear` 层。如何在 `__init__` 方法中(而不是在 `forward()` 中)求出这个大小?
class model(nn.Module):
    """Multimodal classifier that fuses a dense text encoder with a CNN image encoder.

    Text branch: two ReLU-activated linear layers (word_count -> 4096 -> 512).
    Image branch: conv/pool/conv/pool/conv, flatten, dropout, linear -> 512.
    The two 512-d feature vectors are concatenated and passed through two more
    ReLU-activated linear layers and a sigmoid output layer.

    Args:
        word_count: Size of the last dimension of the text input.
        img_channel: Number of channels of the image input.
        n_out: Number of output units.
        img_size: Optional ``(height, width)`` of the input images. When given,
            the flattened size of the conv stack's output is computed here in
            ``__init__`` by running a dummy tensor through the conv layers, so
            ``cnn_fc`` always matches the real feature map. When ``None`` (the
            default), the original hard-coded ``32 * 24 * 12`` is used, which
            matches e.g. 144x96 inputs with the kernel sizes below.
    """

    def __init__(self, word_count, img_channel, n_out, img_size=None):
        # The original code called super(multimodal, self), but no class named
        # `multimodal` exists here, so construction raised NameError.
        super().__init__()
        # CNN image encoding hyperparameters
        conv1_channel_out = 8
        conv1_kernel = 5
        pool1_size = 2
        conv2_channel_out = 16
        conv2_kernel = 16
        pool2_size = 2
        conv3_channel_out = 32
        conv3_kernel = 4
        dropout_rate = 0.1
        cnn_fc_out = 512
        comb_fc1_out = 512
        comb_fc2_out = 128
        # FNN text encoding hyperparameters
        text_fc1_out = 4096
        text_fc2_out = 512
        # Text encoding
        self.text_fc1 = nn.Linear(word_count, text_fc1_out)
        self.text_fc2 = nn.Linear(text_fc1_out, text_fc2_out)
        # Image encoding
        self.conv1 = nn.Conv2d(img_channel, conv1_channel_out, conv1_kernel)
        self.max_pool1 = nn.MaxPool2d(pool1_size)
        self.conv2 = nn.Conv2d(conv1_channel_out, conv2_channel_out, conv2_kernel)
        self.max_pool2 = nn.MaxPool2d(pool2_size)
        self.conv3 = nn.Conv2d(conv2_channel_out, conv3_channel_out, conv3_kernel)
        self.cnn_dropout = nn.Dropout(dropout_rate)
        # The conv layers are registered above, so their output size can be
        # measured now instead of being hard-coded.
        if img_size is None:
            cnn_flat_features = 32 * 24 * 12  # legacy default (e.g. 144x96 input)
        else:
            cnn_flat_features = self._cnn_flat_size(img_channel, img_size)
        self.cnn_fc = nn.Linear(cnn_flat_features, cnn_fc_out)
        # Concat layer
        concat_feat = cnn_fc_out + text_fc2_out
        self.combined_fc1 = nn.Linear(concat_feat, comb_fc1_out)
        self.combined_fc2 = nn.Linear(comb_fc1_out, comb_fc2_out)
        self.output_fc = nn.Linear(comb_fc2_out, n_out)

    def _conv_features(self, img):
        """Run the convolutional stack and return the (unflattened) feature map."""
        x = F.relu(self.conv1(img))
        x = self.max_pool1(x)
        x = F.relu(self.conv2(x))
        x = self.max_pool2(x)
        return F.relu(self.conv3(x))

    def _cnn_flat_size(self, img_channel, img_size):
        """Return the per-sample flattened size of the conv stack for ``img_size``."""
        height, width = img_size
        # A single zero image is enough: only the output shape matters, and
        # no_grad keeps this probe out of autograd.
        with torch.no_grad():
            probe = torch.zeros(1, img_channel, height, width)
            return self._conv_features(probe).numel()

    def forward(self, text, img):
        """Compute sigmoid outputs of shape (batch, n_out).

        Args:
            text: 2-D tensor whose last dimension is ``word_count``.
            img: 4-D tensor (batch, img_channel, height, width) as required by Conv2d.
        """
        # Image Encoding
        x = self._conv_features(img)
        # flatten(1) keeps the batch dimension intact; a size mismatch now
        # surfaces as a clear error in cnn_fc instead of silently reshaping
        # features into extra "samples" as view(-1, N) could.
        x = torch.flatten(x, 1)
        x = self.cnn_dropout(x)
        img = F.relu(self.cnn_fc(x))
        # Text Encoding
        text = F.relu(self.text_fc1(text))
        text = F.relu(self.text_fc2(text))
        # Concat the features
        concat_inp = torch.cat((text, img), 1)
        out = F.relu(self.combined_fc1(concat_inp))
        out = F.relu(self.combined_fc2(out))
        return torch.sigmoid(self.output_fc(out))
如你所见,我手动将 CNN 输出(展平后)的大小写死为 `32*24*12`:
self.cnn_fc = nn.Linear(32*24*12, cnn_fc_out)
我怎样才能避免这种硬编码?我知道在 `forward()` 中可以通过 `[model_name].[layer_name].in_features` 获取该值,但在 `__init__()` 中却不行。
【问题讨论】:
标签: python machine-learning deep-learning neural-network pytorch