【Posted】: 2018-08-03 13:05:39
【Question】:
I'm trying to implement a neural network inspired by the ideas behind Xception, but I can't figure out what's wrong with my model.
```lua
local torch = require 'torch'
local nn = require 'nn'
-- Custom module file that defines nn.GlobalAveragePooling (not part of stock nn)
dofile('GlobalAveragePooling.lua')

local model = nn.Sequential()

-- Entry convolution: 3 x 16 x 8 input -> 64 x 8 x 4
model:add( nn.SpatialConvolution(3, 64, 3, 3, 2, 2, 1, 1) )
model:add( nn.SpatialBatchNormalization(64) )
model:add( nn.ReLU() )

-- Xception unit with a "skip path"
-- Main path: depthwise 3x3 conv, pointwise 1x1 conv, BN, max pooling -> 512 x 4 x 2
local seq = nn.Sequential()
seq:add( nn.SpatialDepthWiseConvolution(64, 1, 3, 3, 1, 1, 1, 1) )
seq:add( nn.SpatialConvolution(64, 512, 1, 1, 1, 1, 0, 0) )
seq:add( nn.SpatialBatchNormalization(512) )
seq:add( nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1) )

-- Skip path: strided 1x1 conv -> 512 x 4 x 2; the two branches are summed
local con = nn.ConcatTable()
con:add( nn.SpatialConvolution(64, 512, 1, 1, 2, 2, 0, 0) )
con:add( seq )
model:add( con )
model:add( nn.CAddTable() )
model:add( nn.ReLU() )

-- Exit: global average pooling and a fully-connected layer for a 3-class (log-)softmax output
model:add( nn.GlobalAveragePooling() )
model:add( nn.Reshape(512) )
model:add( nn.Linear(512, 3) )
model:add( nn.LogSoftMax() )
print(tostring(model))

-- Random batch of 10 three-channel 16 x 8 inputs with labels in {1, 2, 3}
local X = torch.randn(10, 3, 16, 8)
local Y = torch.LongTensor(10):random(1,3)

local criterion = nn.ClassNLLCriterion()
local Yhat = model:forward(X)
local loss = criterion:forward(Yhat, Y)
local gradLoss = criterion:backward(Yhat, Y)
model:backward(X, gradLoss)  -- fails here
```
The model works fine in the forward() step, but it fails at model:backward(X, gradLoss) with this error:
```
/nn/THNN.lua:110: Need gradOutput of dimension 5 and gradOutput.size[3] == 8 but got gradOutput to be of shape: [10 x 64 x 1 x 4 x 8] at ../THNN/generic/SpatialDepthWiseConvolution.c:53
stack traceback:
[C]: in function 'v'
../nn/THNN.lua:110: in function 'SpatialDepthWiseConvolution_updateGradInput'
../nn/SpatialDepthWiseConvolution.lua:80: in function 'updateGradInput'
../Module.lua:31: in function <../nn/Module.lua:29>
[C]: in function 'xpcall'
../nn/Container.lua:63: in function 'rethrowErrors'
../nn/Sequential.lua:88: in function <../nn/Sequential.lua:78>
[C]: in function 'xpcall'
../Container.lua:63: in function 'rethrowErrors'
../nn/ConcatTable.lua:66: in function <../ConcatTable.lua:30>
[C]: in function 'xpcall'
../nn/Container.lua:63: in function 'rethrowErrors'
../nn/Sequential.lua:84: in function 'backward'
test.lua:45: in main chunk
[C]: at 0x00405d50
```
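To double-check the spatial sizes, the usual output-size formula floor((size + 2 * pad - kernel) / stride) + 1 can be applied layer by layer. The following is a minimal pure-Lua sketch, assuming floor rounding (the default for nn.SpatialConvolution and nn.SpatialMaxPooling); the helper `convOut` is hypothetical and not part of nn. It shows that both ConcatTable branches produce 4 x 2 maps, so the CAddTable is consistent, and that the depthwise layer receives an 8 x 4 (non-square) feature map.

```lua
-- Hypothetical helper (not part of nn): output size of a conv/pool layer
-- along one spatial axis, assuming floor rounding.
local function convOut(size, k, stride, pad)
  return math.floor((size + 2 * pad - k) / stride) + 1
end

-- Input X is 10 x 3 x 16 x 8, i.e. height 16, width 8.
local h, w = 16, 8

-- Entry convolution: 3x3 kernel, stride 2, padding 1.
h, w = convOut(h, 3, 2, 1), convOut(w, 3, 2, 1)                -- 8 x 4

-- Skip path: 1x1 convolution, stride 2, no padding.
local skipH, skipW = convOut(h, 1, 2, 0), convOut(w, 1, 2, 0)  -- 4 x 2

-- Main path: depthwise 3x3 (stride 1, pad 1) and pointwise 1x1 keep 8 x 4,
-- then 3x3 max pooling with stride 2, padding 1 halves it.
local seqH, seqW = convOut(h, 3, 1, 1), convOut(w, 3, 1, 1)
seqH, seqW = convOut(seqH, 1, 1, 0), convOut(seqW, 1, 1, 0)
seqH, seqW = convOut(seqH, 3, 2, 1), convOut(seqW, 3, 2, 1)    -- 4 x 2

print(string.format("depthwise input: %d x %d", h, w))
print(string.format("skip path: %d x %d, main path: %d x %d", skipH, skipW, seqH, seqW))
```

Running it prints `depthwise input: 8 x 4` and `skip path: 4 x 2, main path: 4 x 2`. In the error message, the check expects size 8 in dimension 3 (0-based), while the received gradOutput has 4 there and 8 in the last dimension, so height and width of the 8 x 4 map appear to be swapped. That would be consistent with the backward pass of SpatialDepthWiseConvolution mishandling non-square inputs, although this is an inference from the shapes, not something confirmed here.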
【Discussion】:
- After some investigation, I narrowed the problem down to the SpatialDepthWiseConvolution implementation in Torch. See the issue here: github.com/torch/nn/issues/1307
Tags: lua neural-network deep-learning torch