写了个demo,只有两个文件。
定义了一个叫 Square 的类,相当于pytorch中的 Function,包含了forward和backward方法。实现 y = a*x^2。
Square.m
1 classdef Square < handle 2 properties 3 input; 4 a; 5 6 grad_input; 7 d_a; 8 end 9 10 methods 11 function self = Square(a) 12 self.a = a; 13 self.grad_input = 0; 14 self.d_a = 0; 15 end 16 17 function out = forward(self, input) 18 out = self.a*input.^2; % y = a*x^2 这个相当于model 19 self.input = input; 20 end 21 22 function grad_input = backward(self, grad_output) 23 % dy = 2*a*dx 以x为变量,根据dy 求 dx 24 % dy = 2*x*da 以待训练权重a为变量,根据dy 求 d_a,不求出d_a来, 25 % a_iter是没法更新的。 26 % 27 % 之前没有自己实现前后向的计算,真的没意识到要求两次微分。 28 grad_input = grad_output/(2*self.a); 29 self.d_a = grad_output/(2*self.input); 30 end 31 32 end 33 end