原创文章~转载请注明出处哦。其他部分内容参见以下链接~

GraphSAGE 代码解析(一) - unsupervised_train.py

GraphSAGE 代码解析(二) - layers.py

GraphSAGE 代码解析(三) - aggregators.py

1. 类及其继承关系

     Model 
     /   \
    /     \
  MLP   GeneralizedModel
          /  \
         /    \
Node2VecModel  SampleAndAggregate

首先看Model, GeneralizedModel, SampleAndAggregate这三个类的联系。

其中Model与 GeneralizedModel的区别在于,Model的build()函数中搭建了序列层模型,而在GeneralizedModel中被删去。self.ouput必须在GeneralizedModel的子类build()中被赋值。

class Model(object) 中的build()函数如下:

 1 def build(self):
 2     """ Wrapper for _build() """
 3     with tf.variable_scope(self.name):
 4         self._build()
 5 
 6     # Build sequential layer model
 7     self.activations.append(self.inputs)
 8     for layer in self.layers:
 9         hidden = layer(self.activations[-1])
10         self.activations.append(hidden)
11     self.outputs = self.activations[-1]
12     # 这部分sequential layer model模型在GeneralizedModel的build()中被删去
13 
14     # Store model variables for easy access
15     variables = tf.get_collection(
16         tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
17     self.vars = {var.name: var for var in variables}
18 
19     # Build metrics
20     self._loss()
21     self._accuracy()
22 
23     self.opt_op = self.optimizer.minimize(self.loss)
View Code

相关文章:

  • 2022-01-04
  • 2021-05-23
  • 2021-07-04
  • 2022-12-23
  • 2022-12-23
  • 2021-06-16
  • 2021-12-19
猜你喜欢
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-09-12
  • 2021-07-23
  • 2022-12-23
相关资源
相似解决方案