写了个demo,只有两个文件。 

定义了一个叫 Square 的类,相当于pytorch中的 Function,包含了forward和backward方法。实现 y = a*x^2。

Square.m

 1 classdef Square < handle
 2     properties
 3         input;
 4         a;
 5         
 6         grad_input;
 7         d_a;
 8     end
 9     
10     methods
11         function self = Square(a)
12             self.a = a;
13             self.grad_input = 0;
14             self.d_a = 0;
15         end
16         
17         function out = forward(self, input)
18             out = self.a*input.^2; % y = a*x^2  这个相当于model
19             self.input = input;
20         end
21         
22         function grad_input = backward(self, grad_output)
23             % dy = 2*a*dx 以x为变量,根据dy 求 dx   
24             % dy = 2*x*da 以待训练权重a为变量,根据dy 求 d_a,不求出d_a来,
25             % a_iter是没法更新的。
26             % 
27             % 之前没有自己实现前后向的计算,真的没意识到要求两次微分。
28             grad_input = grad_output/(2*self.a);
29             self.d_a = grad_output/(2*self.input);
30         end
31     
32     end
33 end
View Code

相关文章: