这是我目前的解决方案。它的做法是将网络输出拆分为一个表(table),然后使用 nn.SelectTable():backward() 得到完整的梯度:
require 'nn'
-- The input: a batch_sz x 3 x 100 x 100 tensor filled uniformly from [-1, 1].
local batch_sz = 2
local x = torch.Tensor(batch_sz, 3, 100, 100):uniform(-1,1)

-- The model: two strided convolutions, then layers that reshape the result
-- into a table so that individual output locations can be addressed.
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 128, 9, 9, 9, 9, 1, 1))
net:add(nn.SpatialConvolution(128, 1, 3, 3, 3, 3, 1, 1))
net:add(nn.Squeeze(1, 3))
-- Convert the output into a table format.
net:add(nn.View(1, -1)) -- vectorize
net:add(nn.SplitTable(1, 1)) -- split all outputs into table elements
print(net)

-- The loss.
local loss = nn.SmoothL1Criterion()

-- Forwarding x through the convolutions would result in a (2)x4x4 output;
-- the trailing View/SplitTable layers hand it back as a flat table with one
-- entry per output location.
-- Declared `local` so the result does not leak into the global environment.
local y = net:forward(x)
print(y)
--- Return the output table's index belonging to a specific location.
-- Samples are laid out one after another, each occupying feat_h * feat_w
-- consecutive entries in row-major order; all indices are 1-based.
-- @tparam number feat_h height of the output feature map
-- @tparam number feat_w width of the output feature map
-- @tparam number smpl_idx index of the sample within the batch
-- @tparam number feat_r row of the location within the feature map
-- @tparam number feat_c column of the location within the feature map
-- @treturn number index into the flattened output table
function get_sample_idx(feat_h, feat_w, smpl_idx, feat_r, feat_c)
  -- Entries taken up by all samples that come before this one.
  local samples_before = smpl_idx - 1
  local sample_offset = samples_before * feat_h * feat_w
  -- Entries taken up by the complete rows above the requested row.
  local rows_before = feat_r - 1
  local row_offset = rows_before * feat_w
  return sample_offset + feat_c + row_offset
end
-- I want to back-propagate the loss of this sample at this feature location.
local smpl_idx = 2
local feat_r = 3
local feat_c = 4
-- Get the actual index location in the output table (for a 4x4 output feature map).
local out_idx = get_sample_idx(4, 4, smpl_idx, feat_r, feat_c)
-- The (fake) ground truth.
local gt = torch.rand(1)
-- Compute the loss on the selected feature map location for the selected sample.
local err = loss:forward(y[out_idx], gt)
-- Compute the loss gradient, as if there was only this one location.
local dE_dy = loss:backward(y[out_idx], gt)
-- Now convert it into the full loss gradient (zeroing out irrelevant losses):
-- SelectTable's backward yields a table shaped like y that carries dE_dy at
-- out_idx.
local full_dE_dy = nn.SelectTable(out_idx):backward(y, dE_dy)
-- backward() accumulates into the parameter-gradient buffers, so clear them
-- first; otherwise they would also contain whatever was stored there before.
net:zeroGradParameters()
-- Do back-prop through the whole network.
net:backward(x, full_dE_dy)
print("The full dE/dy")
-- `table.unpack` only exists on Lua 5.2+ (or LuaJIT built with 5.2 compat);
-- plain Lua 5.1 / LuaJIT expose it as the global `unpack`.
local unpack = table.unpack or unpack
print(unpack(full_dE_dy))
如果有人能指出更简单或更高效的方法,我将不胜感激。