【问题标题】:how to implement Grad-CAM on your own network?如何在自己的网络上实施 Grad-CAM?
【发布时间】:2021-02-08 20:31:53
【问题描述】:

我想在自己的网络上实现 Grad-CAM,我应该保存我的模型并加载它,然后像 VGG-16 一样对待我保存的模型,然后做类似的操作吗?

我尝试在网上搜索,发现所有方法都是基于著名的模型,而不是他们自己的。

所以我想知道,也许我只需要将自己的模型视为 VGG-16,然后做类似的事情。

【问题讨论】:

  • 太忙了,没时间跟进,但我只是这样做了。在向前或向后运行之前,访问您要在其上应用 GradCam 的层,例如使用 c = list(self.model.children())[-3][2].conv3 for resnet。 c 上的 apply 前向和后向钩子,它存储 `def hook_feature(module, input, output): self.features = output.clone().detach()` 和 `def hook_gradient(module, grad_in, grad_out): self.gradients = grad_out[0].clone().detach()`
  • 稍后将添加此作为正确答案。请参阅此代码:-github.com/utkuozbulak/pytorch-cnn-visualizations/blob/master/…。这个仅在有一个顺序块的情况下适用于 VGG 类型的网络。与具有嵌套层块的 resnet 不同。编辑此代码以添加您的自定义图层选择并使用前进和后退挂钩而不是register_hook,您就完成了。

标签: python-3.x pytorch


【解决方案1】:

嗨,我在 pytorch 中有一个解决方案

import torch
import torch.nn as nn
from torch.utils import data
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np

# use the ImageNet transformation
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# define a 1 image dataset
dataset = datasets.ImageFolder(root='./data/Elephant/', transform=transform)

# define the dataloader to load that single image
dataloader = data.DataLoader(dataset=dataset, shuffle=False, batch_size=1)

vgg19 = Mymodel() ## create an object of your model
vgg19.load_state_dict(torch.load("your_vgg19_weights"))
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        
        # get the pretrained VGG19 network
        self.vgg = vgg19
        
        # disect the network to access its last convolutional layer
        self.features_conv = self.vgg.features[:36]  # 36th layer was my last conv layer
        
        # get the max pool of the features stem
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        # get the classifier of the vgg19
        self.classifier = self.vgg.classifier
        
        # placeholder for the gradients
        self.gradients = None
    
    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad
        
    def forward(self, x):
        x = self.features_conv(x)
        
        # register the hook
        h = x.register_hook(self.activations_hook)
        
        # apply the remaining pooling
        x = self.max_pool(x)
        x = x.view((1, -1))
        x = self.classifier(x)
        return x
    
    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients
    
    # method for the activation exctraction
    def get_activations(self, x):
        return self.features_conv(x)

vgg = VGG()

# set the evaluation mode
vgg.eval()

# get the image from the dataloader
img, _ = next(iter(dataloader))

# get the most likely prediction of the model
pred_class = vgg(img).argmax(dim=1).numpy()[0]
pred = vgg(img)

pred[:, pred_class].backward()

# pull the gradients out of the model
gradients = vgg.get_activations_gradient()

# pool the gradients across the channels
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# get the activations of the last convolutional layer
activations = vgg.get_activations(img).detach()

# weight the channels by corresponding gradients
for i in range(512):
    activations[:, i, :, :] *= pooled_gradients[i]
    
# average the channels of the activations
heatmap = torch.mean(activations, dim=1).squeeze()

# relu on top of the heatmap
# expression (2) in https://arxiv.org/pdf/1610.02391.pdf
heatmap = np.maximum(heatmap, 0)

# normalize the heatmap
heatmap /= torch.max(heatmap)
heatmap = heatmap.numpy()

import cv2
img = cv2.imread('./data/Elephant/data/05fig34.jpg')
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = heatmap * 0.4 + img
cv2.imwrite('./map.jpg', superimposed_img)  ###saves gradcam visualization image

【讨论】:

    猜你喜欢
    • 2021-05-16
    • 2021-06-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-01-07
    • 1970-01-01
    • 2020-07-18
    相关资源
    最近更新 更多