【Question title】: Question about the g attribute in fast.ai course lesson 8
【Posted】: 2020-04-25 16:42:05
【Question】:

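In lesson 8 of the fast.ai 2019 course, backpropagation uses a strange g attribute. I checked, and torch.Tensor has no such attribute. I tried printing inp.g/out.g inside the __call__ method, but got AttributeError: 'Tensor' object has no attribute 'g'. Yet in backward I can read inp.g/out.g before anything there assigns it. How does this g attribute work?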

class Linear():
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, inp):
        print('in lin call')
        self.inp = inp
        self.out = inp@self.w + self.b
        try:
            print('out.g', self.out.g)
        except Exception as e:
            print('out.g dne yet')
        return self.out

    def backward(self):
        print('out.g', self.out.g)
        self.inp.g = self.out.g @ self.w.t()
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)

link to full code from the course

-Update-

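I worked out that the value of self.out.g is exactly the same as self.inp.g in the MSE cost function, but I still can't figure out how that value gets passed to the last linear layer.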

class MSE():
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze() - targ).pow(2).mean()
        return self.out

    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) \
                        / self.targ.shape[0]
        print('in mse backward', self.inp.g)

class Model():
    def __init__(self, w1, b1, w2, b2):
        # Relu is the course's ReLU layer (not shown in this post)
        self.layers = [Linear(w1, b1), Relu(), Linear(w2, b2)]
        self.loss = MSE()

    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x, targ)

    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

【Question comments】:

    Tags: python pytorch fast-ai


    【Solution 1】:

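    This comes down to how Python assignment works: a name is a reference to an object (loosely like a C pointer), so assigning or returning a tensor binds another name to the same object instead of copying it. g is not a built-in torch.Tensor attribute; it is an ordinary Python attribute that gets attached to the tensor object when a line such as self.inp.g = ... runs. That is why reading out.g in __call__ fails (the freshly created output has no g yet) but works in backward. After tracking the variables with id(variable), I was able to figure out where the g attribute comes from.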

    # ... in Model.__call__ (forward pass) ...
        x = layer(x)  # the linear layer returns self.out, which is bound to x

    # ...
        return self.loss(x, targ)  # same x (same id) as the last layer's self.out;
                                   # the loss stores it as self.inp

    # ========

    # ... in Model.backward (backward pass) ...
        self.loss.backward()  # this is where self.inp.g gets set on that shared object

    # ... in Linear.backward ...
        self.inp.g = self.out.g @ self.w.t()
        # this self.out.g is the same attribute the loss set as its self.inp.g
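
The trace above can be reproduced without PyTorch. The following is a minimal sketch, not the course's code: FakeTensor, Scale and SquareLoss are hypothetical stand-ins (a plain Python object in place of torch.Tensor, a one-number "layer" and a toy loss). It only illustrates that the layer's out and the loss's inp are the same object, so a g set through one name is visible through the other.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: a plain object that, like a
    tensor, accepts arbitrary new attributes such as `.g`."""
    def __init__(self, value):
        self.value = value


class Scale:
    """Toy layer (hypothetical): out = inp * k."""
    def __init__(self, k):
        self.k = k

    def __call__(self, inp):
        self.inp = inp
        self.out = FakeTensor(inp.value * self.k)  # brand-new object, no .g yet
        return self.out

    def backward(self):
        # Reads the .g that the next stage stored on this same object.
        self.inp.g = self.out.g * self.k


class SquareLoss:
    """Toy loss (hypothetical): out = inp ** 2."""
    def __call__(self, inp):
        self.inp = inp
        self.out = FakeTensor(inp.value ** 2)
        return self.out

    def backward(self):
        # d(inp ** 2) / d(inp) = 2 * inp
        self.inp.g = 2 * self.inp.value


layer, loss = Scale(3.0), SquareLoss()
x = FakeTensor(2.0)

y = layer(x)   # y and layer.out are two names for one object
loss(y)        # loss.inp becomes a third name for that object

same_object = layer.out is loss.inp
has_g_before = hasattr(layer.out, "g")  # nothing has assigned .g yet

loss.backward()                         # sets loss.inp.g, i.e. layer.out.g
has_g_after = hasattr(layer.out, "g")
layer.backward()                        # uses layer.out.g to set x.g

print(same_object, has_g_before, has_g_after)  # True False True
print(layer.out.g, x.g)                        # 12.0 36.0
```

The has_g_before check corresponds to the AttributeError seen in __call__: each forward pass creates a new output object, and no g exists on it until the loss's backward has run.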
    

    【Comments】:
