SethDeng

FSL-GNN代码解读

main.py(主函数)

1、加载数据集:

train_loader = generator.Generator(args.dataset_root, args, partition=\'train\', dataset=args.dataset)

2、初始化或加载模型:

enc_nn = models.load_model(\'enc_nn\', args, io)
metric_nn = models.load_model(\'metric_nn\', args, io)

if enc_nn is None or metric_nn is None:
	enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

models.create_models(args=args) : in models.py

def create_models(args):
    print (args.dataset)
    if \'omniglot\' == args.dataset:
        enc_nn = EmbeddingOmniglot(args, 64)
    elif \'mini_imagenet\' == args.dataset:
        enc_nn = EmbeddingImagenet(args, 128)
    else:
        raise NameError(\'Dataset \' + args.dataset + \' not knows\')
    return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size)

class EmbeddingOmniglot():				# 特征提取
class EmbeddingImagenet():				# 略

class MetricNN(nn.Module):
	if self.metric_network == \'gnn_iclr_nl\':……		# 正常的网络
	self.gnn_obj = gnn_iclr.GNN_nl()			# in gnn_iclr.py
	
	elif self.metric_network == \'gnn_iclr_active\':……	# 主动学习
	self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py
	
class SoftmaxModule():		# 线性分类

class GNN_nl(nn.Module) & class GNN_active(nn.Module) : in gnn_iclr.py

class GNN_nl(nn.Module):		# 图网络主要部分
	class Wcompute(nn.Module)	# W邻接矩阵计算
    class Gconv(nn.Module)		# 组图
		def gmul(input)		# 更新图节点特征,W直接返回

3、训练

# 权重衰减
weight_decay = 1e-6

# 优化器
opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay)

# 梯度置零,也就是把loss关于weight的导数变成0
opt_enc_nn.zero_grad()
opt_metric_nn.zero_grad()

# 训练
loss_d_metric = train_batch(
	model=[enc_nn, metric_nn, 
	softmax_module],
	data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])

# 更新参数
opt_enc_nn.step()
opt_metric_nn.step()

# 自适应参数
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)

# 显示训练中loss的更新
if batch_idx % args.log_interval == 0:
	display_str = \'Train Iter: {}\'.format(batch_idx)
	display_str += \'\tLoss_d_metric: {:.6f}\'.format(total_loss/counter)
	io.cprint(display_str)

# 测试
def test_one_shot(args, model, test_samples=5000, partition=\'test\') 定义于 test.py 中
val_acc_aux = test.test_one_shot	# 验证集上测试
test_acc_aux = test.test_one_shot	# 测试集上测试
test.test_one_shot(					# 训练集上测试
	args, 
	model=[enc_nn, metric_nn, softmax_module],
	test_samples=test_samples, 
	partition=\'train\')				

# 测试完毕,将模型设置回训练状态
enc_nn.train()
metric_nn.train()

# 若在验证集上的效果继续变好,则更新
if val_acc_aux is not None and val_acc_aux >= val_acc:

# 保存模型
torch.save(enc_nn, \'checkpoints/%s/models/enc_nn.t7\' % args.exp_name)
torch.save(metric_nn, \'checkpoints/%s/models/metric_nn.t7\' % args.exp_name)

# 全部训练完毕后进行测试
test.test_one_shot

分类:

技术点:

相关文章:

  • 2021-07-04
  • 2021-08-31
  • 2021-10-28
  • 2022-01-21
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-04-17
猜你喜欢
  • 2022-12-23
  • 2022-12-23
  • 2021-11-11
  • 2022-01-04
  • 2021-04-03
  • 2021-05-21
  • 2022-01-06
相关资源
相似解决方案