【发布时间】:2021-04-26 02:10:23
【问题描述】:
Sequential 块的示例代码是
self._encoder = nn.Sequential(
# 1, 28, 28
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=3, padding=1),
# 32, 10, 10 = 16, (1//3)(28 + 2 * 1 - 3) + 1, (1//3)(28 + 2*1 - 3) + 1
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 32, 5, 5
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
# 64, 3, 3
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=1),
# 64, 2, 2
)
是否有像nn.Sequential 这样的结构将模块并行放入其中?
我现在想定义类似的东西
self._mean_logvar_layers = nn.Parallel(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
)
其输出应该是两个数据管道 - self._mean_logvar_layers 中的每个元素一个管道,然后可馈送到网络的其余部分。有点像多头网络。
我目前的实现:
self._mean_layer = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0)
self._logvar_layer = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0)
和
def _encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
for i, layer in enumerate(self._encoder):
x = layer(x)
mean_output = self._mean_layer(x)
logvar_output = self._logvar_layer(x)
return mean_output, logvar_output
我想将并行构造视为一个层。
这在 PyTorch 中可行吗?
【问题讨论】:
-
并行是并行运行(同时)还是只输出两个值但按顺序工作?如果是第一种情况,这种并行化究竟意味着什么(因为您可以通过一些
.to调用和不同的设备轻松完成,请参阅here)。 -
@SzymonMaszke 我的意思是不要考虑后端技术问题。我想要类似于我当前实现的行为,但是将它包装在一个允许我在网络架构中定义“拆分”的构造中。最好并行运行会很好,但我对 GPU 并行最佳实践的了解要少得多,因此希望让 pytorch 为我做魔法。
-
只需使用一层,后跟
torch.split。请参阅此处:pytorch.org/docs/stable/generated/torch.split.html 在这种情况下,您将使用具有 128 个输出通道的 conv,并沿轴 1 分成两个大小为 64 的部分(假设为 channels_first 数据格式)。这可能不是一个通用解决方案,所以我不愿将其作为答案发布,但它适用于大多数常见情况,如 VAE。 -
感谢添加。我也可以尝试写一个答案,但我对 pytorch 不是很熟悉——我使用的是 tensorflow,这就是我要做的。
tf.split很好,因为它还允许您简单地指定拆分的 number (在本例中为 2)并计算输出大小。看起来火炬不支持这个。至于 Gulzars 最后的评论:实际上,一个大小为 128 的 FC 层在数学上与接收相同输入的两个大小为 64 的层的串联完全相同(卷积类似)。这就是为什么这样做“没问题”。 -
@Gulzar 你有两次
N输入和M输出(均值和方差)。每个M参见N。在N->2M的情况下,每个M也会看到N输入。和@xdurch0 说的完全一样,Ms 之间也没有联系。对于 FC 来说,它就像多类逻辑回归(没有激活),每个权重向量都独立于其他权重向量(与所指出的卷积相同)。
标签: python machine-learning deep-learning pytorch pytorch-lightning