Class Activation Map (CAM) Visualizations in PyTorch.
This project provides a script for class activation map (CAM) visualization, which can be used to explain model predictions and to support model interpretability.
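For reference, the original CAM formulation (Zhou et al., 2016) and its gradient-based generalization Grad-CAM (Selvaraju et al., 2017) are shown below. Here $f_k(x, y)$, written $A^k_{ij}$ in the Grad-CAM form, is the $k$-th feature map of the last convolutional layer. $w_k^c$ is the classifier weight that connects the globally average-pooled $f_k$ to class $c$, $y^c$ is the score for class $c$, and $Z$ is the number of spatial positions. This is the textbook definition. The source does not say whether `getCAM` follows these exact steps.

```latex
M_c(x, y) = \sum_k w_k^c \, f_k(x, y)
\qquad
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}
\qquad
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
```

The resulting map has the spatial resolution of the chosen layer. It is usually upsampled to the input size and overlaid on the image.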
$ pip install torchcam
$ pip install --upgrade git+https://github.com/Tramac/pytorch-cam.git
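from torchcam import open_image, image2batch, int2tensor, getCAM
from torchvision.models import resnet18
# Open the image as a 224x224 RGB image and convert it into a batch tensor
img = open_image('./data/cat.jpg', (224, 224), convert_mode='RGB')
input = image2batch(img)
image_class = 284  # ImageNet-1k class index for 'Siamese cat'
target = int2tensor(image_class)
model = resnet18(pretrained=True)
model.eval()  # inference mode: use the running BatchNorm statistics
# Grad-CAM (the default method); display the heatmap without saving it
cam = getCAM(model, img, input, target, display=True, save=False)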
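The core Grad-CAM step can be sketched in NumPy as follows. It assumes you already have a layer's activations `A` with shape (K, H, W) and the gradients of the class score with respect to them. Each feature map is weighted by its spatially averaged gradient, the weighted maps are summed, and a ReLU is applied. The result is then scaled to [0, 1]. `grad_cam` is a hypothetical helper for illustration and is not part of the torchcam API.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Hypothetical helper: combine feature maps A (K, H, W) and dy_c/dA (K, H, W)
    into a Grad-CAM heatmap of shape (H, W) scaled to [0, 1]."""
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    # alpha_k: global average of the gradients over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum over the K feature maps, then ReLU to keep positive evidence
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    # Scale to [0, 1] for display; an all-zero map is returned unchanged
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return cam


# Example with random data shaped like a ResNet-18 layer4 output (reduced to 4 channels)
rng = np.random.default_rng(0)
A = rng.random((4, 7, 7))
dA = rng.standard_normal((4, 7, 7))
heatmap = grad_cam(A, dA)
print(heatmap.shape)  # (7, 7)
```

For a 224x224 input, ResNet-18's last convolutional stage yields 7x7 maps. The heatmap is therefore upsampled before it is overlaid on the image.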
Besides the default Grad-CAM method, the following additional methods are also available: vanilla_grad, grad_x_input, saliency, integrate_grad, deconv, smooth_grad.
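Four of these methods have standard definitions, which the sketch below illustrates on a toy differentiable function. The function is f(x) = sigmoid(w · x), so its gradient is available in closed form and no autograd is needed. These are the generic textbook formulations and not torchcam's implementations. The weights `W`, the noise scale `sigma`, the sample count, and the step count are illustrative assumptions. `saliency` and `deconv` are not covered here.

```python
import numpy as np

# Toy "model" (assumption for illustration): f(x) = sigmoid(W . x)
W = np.array([2.0, -1.0, 0.5])


def f(x: np.ndarray) -> float:
    return float(1.0 / (1.0 + np.exp(-W @ x)))


def grad_f(x: np.ndarray) -> np.ndarray:
    # d/dx sigmoid(W . x) = s * (1 - s) * W
    s = f(x)
    return s * (1.0 - s) * W


def vanilla_grad(x: np.ndarray) -> np.ndarray:
    # Plain gradient of the output with respect to the input
    return grad_f(x)


def grad_x_input(x: np.ndarray) -> np.ndarray:
    # Element-wise product of the gradient and the input
    return grad_f(x) * x


def smooth_grad(x: np.ndarray, sigma: float = 0.1, n_samples: int = 50, seed: int = 0) -> np.ndarray:
    # Average the gradient over Gaussian-perturbed copies of the input
    rng = np.random.default_rng(seed)
    noisy = x + rng.normal(0.0, sigma, size=(n_samples, x.size))
    return np.mean([grad_f(xi) for xi in noisy], axis=0)


def integrated_grad(x: np.ndarray, baseline=None, steps: int = 100) -> np.ndarray:
    # Midpoint Riemann sum of the gradient along the straight path baseline -> x,
    # scaled by (x - baseline)
    if baseline is None:
        baseline = np.zeros_like(x)
    alphas = (np.arange(steps) + 0.5) / steps
    grads = [grad_f(baseline + a * (x - baseline)) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)


x = np.array([1.0, 0.5, -2.0])
print(vanilla_grad(np.zeros(3)))  # [ 0.5   -0.25   0.125]
print(integrated_grad(x).sum(), f(x) - f(np.zeros(3)))  # approximately equal
```

The last line checks the completeness property of integrated gradients. Summed over features, the attributions approximately equal f(x) minus f(baseline).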
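from torchcam import saliency
# Compute several attribution maps for the same image and target class
results = saliency.get_image_saliency_result(model, img, input, target, methods=['smooth_grad', 'vanilla_grad', 'grad_x_input', 'saliency'])
# Plot the computed maps and display the figure without saving it
figure = saliency.get_image_saliency_plot(results, display=True, save=False)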
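To use your own model, pass the key of the target layer through layer_path:
model = YourModel()
cam = getCAM(model, img, input, target, layer_path=['xxx'])  # replace 'xxx' with the key of the layer where backpropagation ends in your model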
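To pick a value for `layer_path`, it helps to list the layer names of your model. The sketch below uses `named_modules()` to list them and then checks a candidate layer's output shape with a forward hook. `TinyNet` is a hypothetical stand-in for `YourModel`. The source does not say whether `layer_path` expects these dotted names, their split segments (for example `['features', '2']`), or some other key format, so check the project source before relying on either. `get_submodule` requires PyTorch 1.10 or later.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for YourModel."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.fc(torch.flatten(self.pool(x), 1))


model = TinyNet().eval()

# Dotted names of all convolutional layers; the last one is a common CAM target
conv_names = [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)  # ['features.0', 'features.2']

# Confirm the candidate layer's output shape with a temporary forward hook
captured = {}


def record_shape(module, inputs, output):
    captured["shape"] = tuple(output.shape)


handle = model.get_submodule(conv_names[-1]).register_forward_hook(record_shape)
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
handle.remove()
print(captured["shape"])  # (1, 16, 32, 32)
```

Later layers usually give more class-specific maps at a coarser resolution. The last convolutional layer is the conventional compromise.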