原创文章~转载请注明出处哦。其他部分内容参见以下链接~
GraphSAGE 代码解析(一) - unsupervised_train.py
GraphSAGE 代码解析(三) - aggregators.py
1. 类及其继承关系
Model / \ / \ MLP GeneralizedModel / \ / \ Node2VecModel SampleAndAggregate
首先看Model, GeneralizedModel, SampleAndAggregate这三个类的联系。
其中Model与 GeneralizedModel的区别在于,Model的build()函数中搭建了序列层模型,而在GeneralizedModel中被删去。self.ouput必须在GeneralizedModel的子类build()中被赋值。
class Model(object) 中的build()函数如下:
1 def build(self): 2 """ Wrapper for _build() """ 3 with tf.variable_scope(self.name): 4 self._build() 5 6 # Build sequential layer model 7 self.activations.append(self.inputs) 8 for layer in self.layers: 9 hidden = layer(self.activations[-1]) 10 self.activations.append(hidden) 11 self.outputs = self.activations[-1] 12 # 这部分sequential layer model模型在GeneralizedModel的build()中被删去 13 14 # Store model variables for easy access 15 variables = tf.get_collection( 16 tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 17 self.vars = {var.name: var for var in variables} 18 19 # Build metrics 20 self._loss() 21 self._accuracy() 22 23 self.opt_op = self.optimizer.minimize(self.loss)