Grad-CAM Visualisation of ResNet-50’s Decision Pathways
Tracing ResNet-50 Focus Trends with Grad-CAM.
Understanding How AI Sees the World🌏
Imagine an AI as an artist trying to paint a picture but first needing to decide what part of a scene is worth focusing on. In this article, we explore a tool called Grad-CAM (Gradient-weighted Class Activation Mapping), which helps us visualize what catches the AI’s attention when it looks at an image. This tool is particularly useful for understanding complex image recognition models like ResNet-50, a 50-layer convolutional neural network with residual (skip) connections that is widely used for identifying objects in images.
Getting Started🙌🏻
To start, we need an image. Think of it as the scene our AI artist is going to paint. We use a standard JPEG image for this purpose. Our AI, powered by ResNet-50, does not take in the image all at once: it passes it through a series of layers, each of which builds more abstract features from local regions of the image, before deciding what it sees.
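In the code that follows, the image is converted with transforms.ToTensor(), which turns an H×W×3 RGB image with values 0 to 255 into a 3×H×W float tensor with values in [0, 1]. To make that step concrete, here is a NumPy-only sketch of the same conversion. It is an illustration, not torchvision's implementation, and to_chw_float is a hypothetical helper name.

```python
import numpy as np


def to_chw_float(rgb_uint8: np.ndarray) -> np.ndarray:
    """Rough NumPy equivalent of torchvision's ToTensor for an RGB uint8 image.

    (H, W, 3) uint8 in [0, 255]  ->  (3, H, W) float32 in [0, 1]
    """
    if rgb_uint8.dtype != np.uint8 or rgb_uint8.ndim != 3:
        raise ValueError("expected an H x W x C uint8 array")
    # Move the channel axis to the front, then rescale 0..255 to 0..1
    return np.transpose(rgb_uint8, (2, 0, 1)).astype(np.float32) / 255.0


# A 4 x 6 all-white image becomes a (3, 4, 6) array of ones
x = to_chw_float(np.full((4, 6, 3), 255, dtype=np.uint8))
print(x.shape)  # (3, 4, 6)
```

The channel-first layout is what PyTorch convolution layers expect, which is why the tensor is later given an extra batch dimension with unsqueeze(0) or [None, :] before it reaches the model.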
Peeking Into the AI’s Mind (Grad-CAM)💡
To peek into what the AI is focusing on, we use Grad-CAM. For a chosen class and a chosen layer, Grad-CAM uses the gradients of that class’s score with respect to the layer’s feature maps to build a heatmap, which is then overlaid on the original image. Warmer colors mark the regions that contributed most to the prediction, so by looking at these heatmaps we can see which parts of the image matter most to the AI’s decision.
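Under the hood, Grad-CAM takes the activations A^k of the chosen layer (one map per channel k) and the gradients of the class score y^c with respect to them, averages each channel's gradient map into a weight α_k, forms the weighted sum Σ_k α_k A^k, and keeps only the positive part with a ReLU. Below is a NumPy sketch of that core computation for a single image. It assumes the activations and gradients have already been captured (pytorch-grad-cam does this with hooks on the target layer), and grad_cam_map is a hypothetical name, not part of the library.

```python
import numpy as np


def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map for one image and one target class.

    activations: feature maps A^k of the target layer, shape (K, H, W)
    gradients:   d y^c / d A^k for the target class c, same shape (K, H, W)
    returns:     an (H, W) map scaled to [0, 1]
    """
    if activations.shape != gradients.shape or activations.ndim != 3:
        raise ValueError("activations and gradients must both have shape (K, H, W)")
    # alpha_k: average each channel's gradient over all spatial positions
    weights = gradients.mean(axis=(1, 2))             # shape (K,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=1)  # shape (H, W)
    # ReLU: keep only regions that push the class score up
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] for display; an all-zero map is returned unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

With a toy two-channel input where the first channel has positive gradients and the second negative ones, only the first channel's active region survives the ReLU. The library additionally resizes the map to the input resolution so it can be laid over the image.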
Step-by-Step Through the Code:
.
├── ResNet50Vis.ipynb
├── sample.jpeg
├── result.gif      (the resulting animated GIF)
└── images
    └── (generated images will be saved here)
This is the project layout. At the start, only ResNet50Vis.ipynb and sample.jpeg need to be in the working directory; the generated frames go into the images folder and the animated GIF is written next to the notebook.
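Preparing Utilities and Wrapping the Model:
Here we create reusable helpers: a wrapper that makes the Hugging Face model return plain logits, a lookup from label names to class indices, a function that runs Grad-CAM on one layer and overlays the result on the image, and a function that prints the top predictions.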
import warnings
warnings.filterwarnings('ignore')

from typing import Callable, List, Optional

import cv2
import imageio
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Load the input image (forced to 3-channel RGB) and convert it to a
# C x H x W float tensor with values in [0, 1]
image = Image.open('./sample.jpeg').convert('RGB')
img_tensor = transforms.ToTensor()(image)


class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Hugging Face models return an output object; Grad-CAM needs a plain logits tensor."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x).logits


def category_name_to_index(model, category_name):
    """Map a label name such as "cairn, cairn terrier" to its class index."""
    name_to_index = {v: k for k, v in model.config.id2label.items()}
    return name_to_index[category_name]


def run_grad_cam_on_image(model: torch.nn.Module,
                          target_layer: torch.nn.Module,
                          targets_for_gradcam: List[Callable],
                          reshape_transform: Optional[Callable],
                          input_tensor: torch.Tensor = img_tensor,
                          input_image: Image.Image = image,
                          method: Callable = GradCAM):
    """Run a CAM method on one layer and return the heatmap overlays side by side."""
    with method(model=HuggingfaceToTensorModelWrapper(model),
                target_layers=[target_layer],
                reshape_transform=reshape_transform) as cam:
        # One copy of the image per target, so all targets run in a single batch
        repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
        batch_results = cam(input_tensor=repeated_tensor,
                            targets=targets_for_gradcam)

        results = []
        for grayscale_cam in batch_results:
            # Overlay the heatmap on the original image (expects floats in [0, 1])
            visualization = show_cam_on_image(np.float32(input_image) / 255,
                                              grayscale_cam,
                                              use_rgb=True)
            # Halve the resolution to keep the frames small
            visualization = cv2.resize(visualization,
                                       (visualization.shape[1] // 2, visualization.shape[0] // 2))
            results.append(visualization)
        return np.hstack(results)


def print_top_categories(model, img_tensor, top_k=5):
    """Print the model's top-k predicted classes for the image."""
    with torch.no_grad():
        logits = model(img_tensor.unsqueeze(0)).logits
    indices = logits.cpu()[0, :].numpy().argsort()[-top_k:][::-1]
    for i in indices:
        print(f"Predicted class {i}: {model.config.id2label[int(i)]}")
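run_grad_cam_on_image hands the final blending step to show_cam_on_image, which colors the heatmap with an OpenCV colormap (JET by default) and mixes it with the image. Below is a simplified NumPy sketch of that kind of blend. overlay_heatmap is a hypothetical name, and it paints the heat into the red channel instead of applying a real colormap, so its output will not match the library pixel for pixel.

```python
import numpy as np


def overlay_heatmap(rgb_img: np.ndarray, cam: np.ndarray, image_weight: float = 0.5) -> np.ndarray:
    """Blend a [0, 1] heatmap onto a float RGB image in [0, 1]; returns a uint8 RGB image."""
    if rgb_img.shape[:2] != cam.shape:
        raise ValueError("heatmap and image must have the same height and width")
    # Simplified 'colormap' (an assumption for illustration): heat goes into the red channel only
    heat = np.zeros_like(rgb_img, dtype=np.float64)
    heat[..., 0] = cam
    # Weighted mix of heatmap and image
    blended = (1.0 - image_weight) * heat + image_weight * rgb_img
    # Rescale so the brightest value becomes 1, then convert to 8-bit
    blended = blended / max(blended.max(), 1e-8)
    return np.uint8(255 * blended)
```

Because the blend is rescaled by its own maximum, the hottest region of the heatmap always ends up at full brightness, whatever the absolute Grad-CAM values were.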
Target Identification:
Now let’s set the target label for which we will generate heatmaps. Our input image shows a Cairn terrier, so the target label is "cairn, cairn terrier", the name used for that class in the model’s id2label mapping.
from transformers import ResNetForImageClassification
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
targets_for_gradcam = [ClassifierOutputTarget(category_name_to_index(model, "cairn, cairn terrier"))]
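category_name_to_index inverts id2label and looks the name up, so a small typo in the label string surfaces as a bare KeyError. The sketch below does the same inversion on a plain dictionary but suggests near matches on a miss using Python's difflib. category_name_to_index_checked is a hypothetical helper, and toy_id2label is a made-up three-entry mapping (its indices are not the real ImageNet ones) standing in for model.config.id2label.

```python
import difflib


def category_name_to_index_checked(id2label: dict, category_name: str) -> int:
    """Invert an id2label mapping and look up category_name, suggesting near matches on a miss."""
    name_to_index = {name: idx for idx, name in id2label.items()}
    if category_name not in name_to_index:
        suggestions = difflib.get_close_matches(category_name, list(name_to_index), n=3)
        raise KeyError(f"unknown label {category_name!r}; did you mean {suggestions}?")
    return name_to_index[category_name]


# Toy mapping with made-up indices, standing in for model.config.id2label
toy_id2label = {0: "tench, Tinca tinca", 1: "goldfish, Carassius auratus", 2: "cairn, cairn terrier"}
print(category_name_to_index_checked(toy_id2label, "cairn, cairn terrier"))  # 2
```

In the notebook you would pass model.config.id2label instead of the toy dictionary; Hugging Face configs also carry a label2id attribute that holds the reverse mapping directly.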
Running Grad-CAM:
Here we run Grad-CAM on every bottleneck block in every stage of ResNet-50 and store each heatmap image in list_of_images.
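list_of_images = []
# Visit every bottleneck block of every encoder stage and compute one heatmap per block
for stage in model.resnet.encoder.stages:
    for layer in stage.layers:
        heatmap = run_grad_cam_on_image(model=model,
                                        target_layer=layer,
                                        targets_for_gradcam=targets_for_gradcam,
                                        reshape_transform=None)
        list_of_images.append(Image.fromarray(heatmap))

# Sanity check: the target class should appear among the top predictions
print_top_categories(model, img_tensor)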
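The nested loop visits every bottleneck block, so list_of_images ends up with one frame per block: ResNet-50's encoder has four stages with 3, 4, 6 and 3 blocks. Below is a small sketch that produces a matching human-readable label for each frame, which can help when captioning or inspecting the GIF. The depths list mirrors the microsoft/resnet-50 configuration (in the notebook it can be read from model.config.depths), and frame_labels is a hypothetical name.

```python
# Blocks per encoder stage for ResNet-50
depths = [3, 4, 6, 3]

# One label per Grad-CAM frame, in the same order as the nested stage/layer loop
frame_labels = [f"stage {s + 1}, block {b + 1}"
                for s, depth in enumerate(depths)
                for b in range(depth)]

print(len(frame_labels))                        # 16
print(frame_labels[0], "|", frame_labels[-1])   # stage 1, block 1 | stage 4, block 3
```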
Visualization and Animation:
Great! Now we can create a GIF from our list of images.
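import os

# Save every frame into the images folder (created if it does not exist yet)
os.makedirs('images', exist_ok=True)
image_files = []
for i, img in enumerate(list_of_images):
    path = f'images/temp_image_{i}.png'
    img.save(path)
    image_files.append(path)

# Append the saved frames to an animated GIF; duration sets how long each frame is shown
with imageio.get_writer('my_animation.gif', mode='I', duration=0.5) as writer:
    for filename in image_files:
        frame = imageio.imread(filename)
        writer.append_data(frame)

# Show the GIF in the notebook. IPython's Image is aliased so it does not shadow PIL's Image.
from IPython.display import Image as IPyImage, display
display(IPyImage(filename='./my_animation.gif'))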
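Since the entries of list_of_images are already PIL images, the GIF can also be written directly with Pillow, without the temporary PNG files. This is an alternative sketch rather than the article's approach; save_gif is a hypothetical helper name, and Pillow's duration argument is given in milliseconds.

```python
from typing import List

from PIL import Image


def save_gif(frames: List[Image.Image], path: str, ms_per_frame: int = 500) -> None:
    """Write a list of PIL frames to an animated GIF in one call."""
    if not frames:
        raise ValueError("need at least one frame")
    frames[0].save(path,
                   save_all=True,             # write every frame, not just the first
                   append_images=frames[1:],  # the remaining frames, in order
                   duration=ms_per_frame,     # display time per frame, in milliseconds
                   loop=0)                    # 0 = loop forever
```

In the notebook this would be called as save_gif(list_of_images, 'my_animation.gif'). Pillow quantizes each RGB frame to the GIF palette automatically, so some color banding in the heatmaps is expected either way.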
Results:
Insights:
Using the gradients of the target class score with respect to each block’s feature maps, we generated one heatmap for each of the 16 bottleneck blocks in the encoder’s four stages. These 16 heatmaps were saved in a folder named ‘images’, and we then combined them into a GIF that illustrates how the model’s focus within the image shifts as we move deeper through the network.
Complete code is available at https://github.com/propardhu/ResNet_50_Vis