Exploring the Effects of Perturbations on Saliency Map Generation
The link between saliency maps and a model's natural robustness is currently unclear. This notebook is a simple exploration of how perturbations might affect saliency maps, using tools provided by the nrtk and xaitk-saliency packages.
Set Up the Environment
Note for Colab users: After setting up the environment, you may need to “Restart Runtime” in order to resolve package version conflicts (see the README for more info).
Note: We are suppressing warnings within this notebook to reduce visual clutter for demonstration purposes. If any issues arise while executing this notebook, we recommend not executing this cell so that any related warnings are shown.
import warnings
warnings.filterwarnings("ignore")
%pip install -qU pip
print("Installing nrtk with required extras...")
%pip install -q "nrtk[pybsm,headless]"
print("Installing notebook-specific packages...")
%pip install -q xaitk-jatic "datasets>=3.4.0" transformers tabulate torch
print("Done!")
from nrtk.utils._extras import print_extras_status # noqa: E402 - intentionally after %pip install
print_extras_status()
Note: you may need to restart the kernel to use updated packages.
Installing nrtk with required extras...
Note: you may need to restart the kernel to use updated packages.
Installing notebook-specific packages...
Note: you may need to restart the kernel to use updated packages.
Done!
Detected status of NRTK extras and their dependencies:
[albumentations]
- nrtk-albumentations ✗ missing
[diffusion]
- torch ✓ 2.10.0
- diffusers ✗ missing
- accelerate ✗ missing
- Pillow ✓ 12.1.1
- transformers ✓ 5.2.0
- protobuf ✗ missing
[graphics]
- opencv-python ✗ missing
[headless]
- opencv-python-headless ✓ 4.13.0.92
[maite]
- maite ✓ 0.9.2
[pillow]
- Pillow ✓ 12.1.1
[pybsm]
- pybsm ✓ 0.14.3
[skimage]
- scikit-image ✗ missing
[tools]
- kwcoco ✗ missing
- Pillow ✓ 12.1.1
- click ✓ 8.3.1
- fastapi ✗ missing
- uvicorn ✗ missing
- pydantic ✗ missing
- pydantic-settings ✗ missing
- python-json-logger ✗ missing
[waterdroplet]
- scipy ✓ 1.17.0
- numba ✓ 0.63.1
For details about installing NRTK extras, please visit:
https://nrtk.readthedocs.io/en/stable/
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import maite.protocols.image_classification as ic
import numpy as np
import torch
from datasets import Dataset, load_dataset
from matplotlib import pyplot as plt
from scipy.stats import entropy
from smqtk_classifier.interfaces.classify_image import ClassifyImage
from tabulate import tabulate
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
)
from xaitk_jatic.interop.image_classification.model import JATICImageClassifier
from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import (
    GenerateImageClassifierBlackboxSaliency,
)
from nrtk.impls.perturb_image.photometric.blur import AverageBlurPerturber
%matplotlib inline
%config InlineBackend.figure_format = "jpeg" # Use JPEG format for inline visualizations
Example Images
We’ll use example images from the CIFAR-10 test dataset, but this could be expanded to many images – even across a dataset.
data = load_dataset("cifar10", split="test")
if TYPE_CHECKING:
    assert isinstance(data, Dataset)
labels = data.features["label"].names
data.set_transform(lambda x: {"image": x["img"], "label": x["label"]})
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
num_samples = 2
if TYPE_CHECKING:
    assert isinstance(data, Dataset)
imgs = np.asarray([np.asarray(data[idx]["image"]) for idx in range(num_samples)])
ground_truth: list[int] = [data[idx]["label"] for idx in range(num_samples)]
for img, gt in zip(imgs, ground_truth, strict=False):
    plt.figure(figsize=(2, 2))
    plt.xticks(())
    plt.yticks(())
    plt.xlabel(f"GT: {labels[gt]}")
    _ = plt.imshow(img)
Defining the “Application”
First we’ll define a couple of dataclasses to keep track of results more easily:
@dataclass
class PerturbationResult:
    """Dataclass for storing perturbed image and associated results."""
    descriptor: str
    img: np.ndarray
    sal_maps: np.ndarray
    pred_class: int
    pred_prob: float
@dataclass
class SaliencyResults:
    """Dataclass for storing saliency map and associated results."""
    ref_img: np.ndarray
    ref_sal_maps: np.ndarray
    gt: int
    pred_class: int
    pred_prob: float
    perturbations: list[PerturbationResult] = field(default_factory=list)
Next, we’ll define a function to compute the specified metrics on our saliency map results. These metrics include the entropy of each saliency map, as well as measures of similarity (a normalized sum of squared differences and a cross-correlation) between the saliency map computed on the original image and the saliency maps computed on the perturbed images.
def _compute_entropy(
    sal_map: np.ndarray,
    clip_min: int | None = None,
    clip_max: int | None = None,
) -> np.number | np.ndarray:
    if clip_min is not None or clip_max is not None:
        s = np.clip(sal_map, clip_min, clip_max)
    else:
        s = (sal_map - sal_map.min()) / (sal_map.max() - sal_map.min())
    return entropy(s.ravel(), base=2)
def _compute_ssd(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    sum_sq_diff = np.sum(np.power(np.subtract(sal_map, ref_sal_map), 2))
    norm = np.sqrt(np.sum(np.power(sal_map, 2)) * np.sum(np.power(ref_sal_map, 2)))
    if not norm:
        return np.inf
    return sum_sq_diff / norm
def _compute_xcorr(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    def _normalize(s: np.ndarray) -> tuple[np.ndarray, bool]:
        s -= s.mean()
        std = s.std()
        if std:
            s /= std
        return s, std == 0
    s1, c1 = _normalize(sal_map.copy())
    s2, c2 = _normalize(ref_sal_map.copy())
    if c1 and not c2:
        return 0.0
    return np.corrcoef(s1.flatten(), s2.flatten())[0, 1]
def _compute_metric(sal_map: np.ndarray, ref_sal_map: np.ndarray, m: str) -> float | np.number | np.ndarray:
    if "entropy" in m:
        # Entropy metrics depend only on the map itself; return the value instead of discarding it.
        return _compute_entropy_setup(sal_map, m)
    if m == "ssd":
        return _compute_ssd(sal_map, ref_sal_map)
    if m == "xcorr":
        return _compute_xcorr(sal_map, ref_sal_map)
    return np.nan
def _compute_entropy_setup(sal_map: np.ndarray, m: str) -> float | np.number | np.ndarray:
    if m == "entropy":
        return _compute_entropy(sal_map)
    if m == "pos saliency entropy":
        return _compute_entropy(sal_map, 0, 1)
    if m == "neg saliency entropy":
        return _compute_entropy(sal_map, -1, 0)
    return np.nan
def _generate_row(
    sal_map: np.ndarray,
    ref_sal_map: np.ndarray,
    row_label: str,
    metrics: tuple[str, ...],
) -> list[Hashable]:
    r: list[Hashable] = [row_label]
    r.extend([_compute_metric(sal_map, ref_sal_map, metric.lower().strip()) for metric in metrics])
    return r
def compute_metrics(results: list[SaliencyResults], metrics: tuple[str, ...]) -> None:
    """Compute metrics for saliency maps."""
    headers = ["Perturbation"]
    headers.extend(list(metrics))
    for res in results:
        rows = []
        rows.append(_generate_row(res.ref_sal_maps[res.gt], res.ref_sal_maps[res.gt], "Ref Image", metrics))
        rows.extend(
            [
                _generate_row(pert.sal_maps[res.gt], res.ref_sal_maps[res.gt], f"{pert.descriptor}", metrics)
                for pert in res.perturbations
            ],
        )
        print(tabulate(rows, headers=headers, tablefmt="plain"))
We’ll also define a function to display all generated saliency maps.
def _plot_img(img: np.ndarray, num_cols: int, descriptor: str = "") -> None:
    plt.subplot(2, num_cols, 1)
    plt.imshow(img, cmap="gray")
    plt.xticks(())
    plt.yticks(())
    plt.xlabel(descriptor)
def _plot_rows(sal_maps: np.ndarray, num_cols: int, plot_idxes: list[int] | None = None) -> None:
    plot_idxes = list(range(len(sal_maps))) if plot_idxes is None else [*set(plot_idxes)]
    n_cols = min(num_cols - 1, len(plot_idxes))
    n_rows = 2
    num_imgs = 0
    for r in range(n_rows):
        col_offset = 2
        if r > 0:
            col_offset = 3
        for c in range(r * n_cols, min(r * n_cols + n_cols, len(plot_idxes))):
            plt.subplot(n_rows, num_cols, c + col_offset)
            im = plt.imshow(sal_maps[plot_idxes[c]], cmap=plt.get_cmap("RdBu"), vmin=-1, vmax=1)
            plt.xticks(())
            plt.yticks(())
            plt.xlabel(f"{labels[plot_idxes[c]]}")
            num_imgs += 1
            if num_imgs == len(plot_idxes):
                fig = plt.gcf()
                cax = fig.add_axes((0.38, 0.60, 0.01, 0.21))  # tweaked for this particular example
                plt.colorbar(im, cax=cax)
def display_results(results: list[SaliencyResults], labels: dict[int, str]) -> None:
    """Display each reference and perturbed image next to its saliency maps."""
    num_classes = len(labels)
    for res in results:
        plt.figure(figsize=(10, 5))
        num_cols = np.ceil(num_classes / 2).astype(int) + 1
        pred = f"{labels[res.pred_class]} ({res.pred_prob:.2f})"
        _plot_img(res.ref_img, num_cols, f"Ref Img\nGT: {labels[res.gt]}\nPred: {pred}")
        _plot_rows(res.ref_sal_maps, num_cols, [res.gt])
        for pert in res.perturbations:
            plt.figure(figsize=(10, 5))
            pred = f"{pert.pred_class} ({pert.pred_prob:.2f})"
            _plot_img(pert.img, num_cols, f"{pert.descriptor}\nPred: {pred}")
            _plot_rows(pert.sal_maps, num_cols, [res.gt])
    plt.show()
Finally, we’ll define the “application”, which perturbs the given input image(s) to varying degrees and generates saliency maps. In this case, we’ll perturb the images with nrtk’s AverageBlurPerturber, which applies an average (box) blur with a given kernel size (ksize). Applying it at several kernel sizes produces increasing amounts of blurring.
def _max_class(probs: dict) -> str:
    v = list(probs.values())
    k = list(probs.keys())
    return k[v.index(max(v))]
def _generate_augmented_maps(
    idx: int,
    ksizes: list[int],
    res: SaliencyResults,
    img: np.ndarray,
    num_images: int,
    image_classifier: ClassifyImage,
    saliency_generator: GenerateImageClassifierBlackboxSaliency,
) -> None:
    for ksize in ksizes:
        print(f"Generating saliency maps for ksize={ksize} (ref image {idx + 1} of {num_images})")
        xform = AverageBlurPerturber(ksize=ksize)
        img_out, _ = xform(image=np.copy(img))
        sal_maps = saliency_generator(img_out, image_classifier)
        probs = next(image_classifier.classify_images(np.expand_dims(img_out, axis=0)))
        pred_class = _max_class(probs)
        pert = PerturbationResult(
            descriptor=f"ksize={ksize}",
            img=img_out,
            sal_maps=sal_maps,
            pred_class=labels[pred_class],
            pred_prob=probs[pred_class],
        )
        res.perturbations.append(pert)
def generate_perturbed_sal_maps(
    images: np.ndarray,
    ground_truth: list[int],
    image_classifier: ClassifyImage,
    saliency_generator: GenerateImageClassifierBlackboxSaliency,
    ksizes: list[int],
    display_labels: dict[int, str],
    display_maps: bool = True,
    metrics: tuple[str, ...] = ("Pos Saliency Entropy", "Neg Saliency Entropy", "Entropy", "SSD", "XCorr"),
) -> list[SaliencyResults]:
    """Generate saliency maps for image."""
    # Get class labels
    labels = image_classifier.get_labels()
    # Generate saliency maps
    results = []
    for idx, img in enumerate(images):
        print(f"Generating saliency maps for reference image (image {idx + 1} of {len(images)})")
        sal_maps = saliency_generator(img, image_classifier)
        probs = next(image_classifier.classify_images(np.expand_dims(img, axis=0)))
        pred_class = _max_class(probs)
        res = SaliencyResults(
            ref_img=np.copy(img),
            ref_sal_maps=sal_maps,
            gt=ground_truth[idx],
            pred_class=labels.index(pred_class),
            pred_prob=probs[pred_class],
        )
        _generate_augmented_maps(idx, ksizes, res, img, len(images), image_classifier, saliency_generator)
        results.append(res)
    for result in results:
        # Plot each image in set with saliency maps
        if display_maps:
            display_results([result], display_labels)
        # Compute metrics
        compute_metrics([result], metrics)
    return results
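The NumPy sketch below shows what the ksize parameter controls. It applies a ksize × ksize average (box) filter. This is an illustrative reimplementation, not nrtk’s code. The `box_blur` helper is hypothetical and only accepts odd kernel sizes. It also replicates edge pixels at the border, which may differ from the border handling AverageBlurPerturber actually uses.

```python
import numpy as np


def box_blur(img: np.ndarray, ksize: int) -> np.ndarray:
    """Hypothetical helper: replace each pixel with the mean of its ksize x ksize neighborhood."""
    if ksize < 1 or ksize % 2 == 0:
        raise ValueError("this sketch only supports positive odd kernel sizes")
    pad = ksize // 2
    # Pad only the two spatial axes; replicate edge pixels so the output keeps the input size.
    pad_width = [(pad, pad), (pad, pad)] + [(0, 0)] * (img.ndim - 2)
    padded = np.pad(img.astype(np.float64), pad_width, mode="edge")
    height, width = img.shape[:2]
    out = np.zeros(img.shape, dtype=np.float64)
    # Accumulate every shifted copy of the image that falls inside the kernel window.
    for dy in range(ksize):
        for dx in range(ksize):
            out += padded[dy : dy + height, dx : dx + width]
    return out / (ksize * ksize)


# Random 32x32 RGB "image" standing in for a CIFAR-10 sample (made-up data, not from the dataset).
rng = np.random.default_rng(0)
noise = rng.uniform(0, 255, size=(32, 32, 3))
for k in (3, 5, 7):
    # Larger kernels average more pixels, so pixel-to-pixel variation shrinks.
    print(f"ksize={k}: std={box_blur(noise, k).std():.1f}")
```

For independent pixel noise, averaging k × k pixels divides the standard deviation by roughly k. This is one way to see why ksize=7 removes far more fine detail than ksize=3 on a 32×32 image.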
Running the “Application”
Classifier
We’ll use a Hugging Face model, wrapped so that it conforms to the MAITE image classification protocol. We'll also use the JATICImageClassifier adapter from xaitk-jatic, which exposes that model to xaitk-saliency as a ClassifyImage.
class HuggingFaceClassifier:
    """MAITE wrapper for HuggingFaceClassifier."""
    def __init__(self, model_name: str, device: str) -> None:
        """Initialize HuggingFaceClassifier."""
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        self.device = device
        self.model.eval()
        self.model.to(device)
        self.metadata = self.model.config.id2label
    def __call__(self, batch: Sequence[ic.InputType]) -> Sequence[ic.TargetType]:
        """Run classifier for batch and return results."""
        # tensor bridging
        input_tensor = torch.as_tensor(batch)
        if input_tensor.ndim != 4:
            raise ValueError(f"Invalid input dimensions. Expected 4, got {input_tensor.ndim}")
        # preprocess
        hf_inputs = self.image_processor(input_tensor, return_tensors="pt")
        # put on device
        hf_inputs = hf_inputs.to(self.device)
        # get predictions
        with torch.no_grad():
            return self.model(**hf_inputs).logits.softmax(1).detach().cpu()
jatic_classifier: ic.Model = HuggingFaceClassifier(
    model_name="aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
ids = [int(k) for k in jatic_classifier.metadata]
classifier = JATICImageClassifier(classifier=jatic_classifier, ids=ids)
Saliency Generator
We’ll use the SlidingWindowStack blackbox saliency generator.
sal_generator = SlidingWindowStack(window_size=(2, 2), stride=(1, 1), threads=4)
sal_generator.fill = (128, 128, 128)
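As we understand it, SlidingWindowStack is an occlusion-style method. It slides a window over the image, paints the covered pixels with the fill color, and scores how each class confidence changes. The sketch below is a simplified, self-contained reimplementation of that idea, meant for intuition only. The names `sliding_window_masks`, `occlusion_saliency`, and the two-class `toy_predict` model are all hypothetical. The way drops are averaged per pixel and scaled to [-1, 1] is also an assumption and need not match xaitk-saliency's scoring.

The sketch also makes the cost concrete. With window_size=(2, 2) and stride=(1, 1), a 32×32 CIFAR-10 image gives (32 - 2)/1 + 1 = 31 window positions per axis. That means 31 × 31 = 961 occluded copies to classify per image.

```python
from collections.abc import Callable

import numpy as np


def sliding_window_masks(
    height: int, width: int, window: tuple[int, int], stride: tuple[int, int]
) -> np.ndarray:
    """Return one boolean mask per window position; True marks occluded pixels."""
    masks = []
    for y in range(0, height - window[0] + 1, stride[0]):
        for x in range(0, width - window[1] + 1, stride[1]):
            mask = np.zeros((height, width), dtype=bool)
            mask[y : y + window[0], x : x + window[1]] = True
            masks.append(mask)
    return np.stack(masks)


def occlusion_saliency(
    img: np.ndarray,
    predict: Callable[[np.ndarray], np.ndarray],
    window: tuple[int, int],
    stride: tuple[int, int],
    fill: tuple[float, ...],
) -> np.ndarray:
    """Simplified occlusion saliency: one [-1, 1] map per class, shape (num_classes, H, W)."""
    ref_probs = predict(img[np.newaxis])[0]  # (num_classes,)
    masks = sliding_window_masks(img.shape[0], img.shape[1], window, stride)  # (N, H, W)
    # Build N copies of the image, each with one window painted with the fill color.
    occluded = np.repeat(img[np.newaxis].astype(np.float64), len(masks), axis=0)
    occluded[masks] = fill
    # Positive drop: hiding the window lowered that class's confidence (the region supports it).
    drops = ref_probs[np.newaxis, :] - predict(occluded)  # (N, num_classes)
    # Average each pixel's drop over all windows that covered it.
    coverage = masks.sum(axis=0)  # (H, W)
    sal = np.einsum("nhw,nc->chw", masks.astype(np.float64), drops) / np.maximum(coverage, 1)
    max_abs = np.abs(sal).max()
    return sal / max_abs if max_abs > 0 else sal


def toy_predict(batch: np.ndarray) -> np.ndarray:
    """Hypothetical 2-class model: class 0 responds to a bright 4x4 top-left corner."""
    score = 10.0 * batch[:, :4, :4, :].mean(axis=(1, 2, 3)) - 5.0
    p0 = 1.0 / (1.0 + np.exp(-score))
    return np.stack([p0, 1.0 - p0], axis=1)


toy_img = np.zeros((8, 8, 3))
toy_img[:4, :4] = 1.0
sal = occlusion_saliency(toy_img, toy_predict, window=(2, 2), stride=(1, 1), fill=(0.5, 0.5, 0.5))
print(sal.shape)
```

In this toy case, hiding any part of the bright corner lowers class 0's confidence. Class 0's map is therefore positive there and zero away from the corner. Class 1's map is its mirror image, because the two probabilities sum to one.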
Results
Note: for clarity, we’ll only perform the saliency analysis with respect to the ground-truth class. The same analysis could also be applied to the predicted class, which may be useful when the ground truth and the prediction differ.
results = generate_perturbed_sal_maps(
    images=imgs,
    ground_truth=ground_truth,
    image_classifier=classifier,
    saliency_generator=sal_generator,
    display_labels={int(k): str(v) for k, v in jatic_classifier.metadata.items()},
    ksizes=[3, 5, 7],
)
Generating saliency maps for reference image (image 1 of 2)
Generating saliency maps for ksize=3 (ref image 1 of 2)
Generating saliency maps for ksize=5 (ref image 1 of 2)
Generating saliency maps for ksize=7 (ref image 1 of 2)
Generating saliency maps for reference image (image 2 of 2)
Generating saliency maps for ksize=3 (ref image 2 of 2)
Generating saliency maps for ksize=5 (ref image 2 of 2)
Generating saliency maps for ksize=7 (ref image 2 of 2)
Perturbation  Pos Saliency Entropy  Neg Saliency Entropy  Entropy       SSD       XCorr
Ref Image                      nan                   nan      nan         0           1
ksize=3                        nan                   nan      nan   2.57343  -0.0921394
ksize=5                        nan                   nan      nan   2.13568   -0.013307
ksize=7                        nan                   nan      nan   3.11349  -0.0699706

Perturbation  Pos Saliency Entropy  Neg Saliency Entropy  Entropy       SSD       XCorr
Ref Image                      nan                   nan      nan         0           1
ksize=3                        nan                   nan      nan   1.11666    0.500884
ksize=5                        nan                   nan      nan  0.845931    0.362914
ksize=7                        nan                   nan      nan   2.96265   0.0502351
Visually, as the quality of the input image degrades (more blur), the quality of the generated saliency maps degrades as well.
In an attempt to quantify these differences, we’ve also computed several metrics:
Entropy
If we compute entropy over both positive and negative saliency values, we don’t see much change across degradations. This is likely because negative and positive saliency “fight” each other as degradation increases (as one increases, the other decreases).
If we compute entropy from only the positive or only the negative saliency values, the values do differ across degradations. Consider the dominant saliency type for the ship (i.e., positive/blue). As degradation gets worse, entropy increases, but only up to a point, after which it decreases. This reduction likely occurs because the degradation leaves the classifier less able to identify the key features that produced the original probability distribution for the reference image. Eventually these features may become so degraded that the classifier predicts with very low confidence. The dominant saliency type for the cat (i.e., negative/red) shows changes similar to those of the ship; however, the positive and negative saliencies are significantly closer. This could be a sign that the classifier relies on less definitive features to identify the cat.
If we look at the opposite saliency type for each reference image, we see a very slight decrease in entropy as degradation increases. This could indicate that the degradation introduces noise that the classifier misidentifies as a contraindicator for the ground-truth class. More likely, it reflects the classifier predicting with less confidence due to the loss of higher-quality features.
Sum of Squared Differences (SSD)
(0 is most similar) The normalized sum of squared differences quantitatively confirms that stronger degradation generally makes saliency maps less similar to the original reference saliency map. For both images, SSD is highest at ksize=7, although the trend is not strictly monotonic (ksize=5 scores slightly lower than ksize=3). However, the metric doesn’t give us much insight into what is actually happening to create these differences.
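The notebook's `_compute_ssd` divides the raw sum of squared differences by sqrt(Σa² · Σb²). A short derivation helps when reading the table. Since Σ(a - b)² = Σa² + Σb² - 2Σab, if the two maps have equal norm this normalized SSD equals 2 - 2cos θ, where θ is the angle between the flattened maps. It therefore runs from 0 (identical) through 2 (orthogonal) to 4 (exact sign inversion). Unequal norms break that bound, so treat it only as a guide. The sketch below restates the metric as a standalone helper (`normalized_ssd`, a hypothetical name) and checks those reference points on made-up random maps.

```python
import numpy as np


def normalized_ssd(a: np.ndarray, b: np.ndarray) -> float:
    """Sum of squared differences divided by sqrt(sum(a**2) * sum(b**2))."""
    norm = np.sqrt(np.sum(a**2) * np.sum(b**2))
    if not norm:
        return float("inf")  # mirrors the notebook: undefined when either map is all zeros
    return float(np.sum((a - b) ** 2) / norm)


rng = np.random.default_rng(0)
ref = rng.uniform(-1, 1, size=(32, 32))  # made-up reference saliency map
other = rng.uniform(-1, 1, size=(32, 32))  # unrelated map of similar scale

print(normalized_ssd(ref, ref))  # identical maps -> 0
print(normalized_ssd(ref, other))  # roughly orthogonal maps -> close to 2
print(normalized_ssd(ref, -ref))  # sign-inverted map -> ~4
print(normalized_ssd(ref, 2 * ref))  # same pattern, twice the magnitude -> ~0.5
```

The last line illustrates a limitation worth keeping in mind: SSD penalizes a change in overall saliency magnitude even when the spatial pattern is unchanged.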
Cross-Correlation (XCorr)
(1 is most similar) Cross-correlation tells us much the same as SSD. However, negative correlation values suggest that the saliency maps begin to become the “opposite” of the original reference saliency maps. This aligns with the pattern we saw in the positive/negative saliency entropy: the opposite saliency appears as the image becomes more degraded and the classifier becomes more confused. This likely doesn’t occur with the ship reference image because it was strongly salient in one direction, whereas the cat reference image contained a more balanced mix of positive and negative saliency.
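Cross-correlation here is the Pearson correlation coefficient between the flattened maps. The notebook standardizes each map before calling `np.corrcoef`, which does not change the coefficient. The sketch below uses a hypothetical `xcorr` helper on made-up maps. It shows why a negative value is read as an “opposite” map. It also shows how the metric differs from SSD: it ignores shifts and positive rescaling of a map.

```python
import numpy as np


def xcorr(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation between two flattened saliency maps."""
    return float(np.corrcoef(a.ravel(), b.ravel())[0, 1])


rng = np.random.default_rng(1)
ref = rng.uniform(-1, 1, size=(32, 32))  # made-up reference saliency map
noise = rng.uniform(-1, 1, size=(32, 32))  # unrelated map

print(xcorr(ref, ref))  # same map -> ~1
print(xcorr(ref, 2 * ref + 0.3))  # rescaled and shifted -> still ~1
print(xcorr(ref, -ref))  # sign-inverted ("opposite") map -> ~-1
print(xcorr(ref, noise))  # unrelated map -> near 0
```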
While aspects of model explainability and model robustness have previously been studied independently, this notebook demonstrates a preliminary exploration of their relationship through quantification of how saliency maps change due to various perturbations. Future work will explore whether quantitative changes in the structure and quality of saliency maps provide a mechanism for understanding model failure modes and edge cases due to various perturbations.