Exploring the Effects of Perturbations on Saliency Map Generation

The link between saliency maps and model natural robustness is currently unclear. This is a simple notebook exploring how perturbations might affect saliency maps, using tools provided by the nrtk and xaitk-saliency packages.

Set Up the Environment

Note for Colab users: After setting up the environment, you may need to “Restart Runtime” in order to resolve package version conflicts (see the README for more info).

from __future__ import annotations

import sys  # noqa: F401

!{sys.executable} -m pip install -qU pip
print("Installing xaitk-jatic...")
!{sys.executable} -m pip install -q xaitk-jatic
print("Installing nrtk[pybsm,headless]...")
!{sys.executable} -m pip install -q nrtk[pybsm,headless]
print("Installing Hugging Face datasets...")
!{sys.executable} -m pip install -q "datasets>=3.4.0"  # quoted so the shell does not treat ">" as a redirect
print("Installing Hugging Face transformers...")
!{sys.executable} -m pip install -q transformers
print("Installing tabulate...")
!{sys.executable} -m pip install -q tabulate
print("Installing torch...")
!{sys.executable} -m pip install -q torch

print("Done!")
Installing xaitk-jatic...
Installing nrtk[pybsm,headless]...
Installing Hugging Face datasets...
Installing Hugging Face transformers...
Installing tabulate...
Installing torch...
Done!
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import maite.protocols.image_classification as ic
import numpy as np
import torch
from datasets import Dataset, load_dataset
from matplotlib import pyplot as plt
from scipy.stats import entropy
from smqtk_classifier.interfaces.classify_image import ClassifyImage
from tabulate import tabulate
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
)
from xaitk_jatic.interop.image_classification.model import JATICImageClassifier
from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import (
    GenerateImageClassifierBlackboxSaliency,
)

from nrtk.impls.perturb_image.pybsm.jitter_otf_perturber import JitterOTFPerturber

%matplotlib inline
%config InlineBackend.figure_format = "jpeg"  # Use JPEG format for inline visualizations

Example Images

We’ll use example images from the CIFAR-10 test dataset, but this could be expanded to many images – even across a dataset.

data = load_dataset("cifar10", split="test")

if TYPE_CHECKING:
    assert isinstance(data, Dataset)

labels = data.features["label"].names
data.set_transform(lambda x: {"image": x["img"], "label": x["label"]})
num_samples = 2
if TYPE_CHECKING:
    assert isinstance(data, Dataset)
imgs = np.asarray([np.asarray(data[idx]["image"]) for idx in range(num_samples)])
ground_truth: list[int] = [data[idx]["label"] for idx in range(num_samples)]

for img, gt in zip(imgs, ground_truth, strict=False):
    plt.figure(figsize=(2, 2))
    plt.xticks(())
    plt.yticks(())
    plt.xlabel(f"GT: {labels[gt]}")
    _ = plt.imshow(img)
[Figures: the two example CIFAR-10 test images, each labeled with its ground-truth class]

Defining the “Application”

First we’ll define a couple of dataclasses to keep track of results more easily:

@dataclass
class PerturbationResult:
    """Dataclass for storing perturbed image and associated results."""

    descriptor: str
    img: np.ndarray
    sal_maps: np.ndarray
    pred_class: str  # display label of the predicted class
    pred_prob: float


@dataclass
class SaliencyResults:
    """Dataclass for storing saliency map and associated results."""

    ref_img: np.ndarray
    ref_sal_maps: np.ndarray
    gt: int
    pred_class: int
    pred_prob: float
    perturbations: list[PerturbationResult] = field(default_factory=list)

Next, we’ll define a function to compute specified metrics upon our saliency map results. These metrics include measures such as the entropy of the resulting saliency map, as well as various measures of correlation between the saliency map computed on the original image and the saliency maps computed on perturbed images.

def _compute_entropy(
    sal_map: np.ndarray,
    clip_min: int | None = None,
    clip_max: int | None = None,
) -> np.number | np.ndarray:
    if clip_min is not None or clip_max is not None:
        s = np.clip(sal_map, clip_min, clip_max)
    else:
        s = (sal_map - sal_map.min()) / (sal_map.max() - sal_map.min())
    return entropy(s.ravel(), base=2)


def _compute_ssd(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    sum_sq_diff = np.sum(np.power(np.subtract(sal_map, ref_sal_map), 2))
    norm = np.sqrt(np.sum(np.power(sal_map, 2)) * np.sum(np.power(ref_sal_map, 2)))
    if not norm:
        return np.inf
    return sum_sq_diff / norm


def _compute_xcorr(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    def _normalize(s: np.ndarray) -> tuple[np.ndarray, bool]:
        s -= s.mean()
        std = s.std()

        if std:
            s /= std

        return s, std == 0

    s1, c1 = _normalize(sal_map.copy())
    s2, c2 = _normalize(ref_sal_map.copy())

    if c1 and not c2:
        return 0.0
    return np.corrcoef(s1.flatten(), s2.flatten())[0, 1]


def _compute_metric(sal_map: np.ndarray, ref_sal_map: np.ndarray, m: str) -> float:
    if "entropy" in m:
        # Return the entropy value; without this return every entropy metric falls through to NaN
        return _compute_entropy_setup(sal_map, m)
    if m == "ssd":
        return _compute_ssd(sal_map, ref_sal_map)
    if m == "xcorr":
        return _compute_xcorr(sal_map, ref_sal_map)
    return np.nan


def _compute_entropy_setup(sal_map: np.ndarray, m: str) -> float | np.number | np.ndarray:
    if m == "entropy":
        return _compute_entropy(sal_map)
    if m == "pos saliency entropy":
        return _compute_entropy(sal_map, 0, 1)
    if m == "neg saliency entropy":
        return _compute_entropy(sal_map, -1, 0)
    return np.nan


def _generate_row(
    sal_map: np.ndarray,
    ref_sal_map: np.ndarray,
    row_label: str,
    metrics: tuple[str, ...],
) -> list[Hashable]:
    r: list[Hashable] = [row_label]

    r.extend([_compute_metric(sal_map, ref_sal_map, metric.lower().strip()) for metric in metrics])

    return r


def compute_metrics(results: list[SaliencyResults], metrics: tuple[str, ...]) -> None:
    """Compute metrics for saliency maps."""
    headers = ["Perturbation"]
    headers.extend(list(metrics))
    for res in results:
        rows = []

        rows.append(_generate_row(res.ref_sal_maps[res.gt], res.ref_sal_maps[res.gt], "Ref Image", metrics))
        rows.extend(
            [
                _generate_row(pert.sal_maps[res.gt], res.ref_sal_maps[res.gt], f"{pert.descriptor}", metrics)
                for pert in res.perturbations
            ],
        )

        print(tabulate(rows, headers=headers, tablefmt="plain"))

We’ll also define a function to display all generated saliency maps.

def _plot_img(img: np.ndarray, num_cols: int, descriptor: str = "") -> None:
    plt.subplot(2, num_cols, 1)
    plt.imshow(img, cmap="gray")
    plt.xticks(())
    plt.yticks(())
    plt.xlabel(descriptor)


def _plot_rows(sal_maps: np.ndarray, num_cols: int, plot_idxes: list[int] | None = None) -> None:
    plot_idxes = list(range(len(sal_maps))) if plot_idxes is None else [*set(plot_idxes)]
    n_cols = min(num_cols - 1, len(plot_idxes))
    n_rows = 2

    num_imgs = 0
    for r in range(n_rows):
        col_offset = 2
        if r > 0:
            col_offset = 3
        for c in range(r * n_cols, min(r * n_cols + n_cols, len(plot_idxes))):
            plt.subplot(n_rows, num_cols, c + col_offset)
            im = plt.imshow(sal_maps[plot_idxes[c]], cmap=plt.get_cmap("RdBu"), vmin=-1, vmax=1)
            plt.xticks(())
            plt.yticks(())
            plt.xlabel(f"{labels[plot_idxes[c]]}")
            num_imgs += 1

            if num_imgs == len(plot_idxes):
                fig = plt.gcf()
                cax = fig.add_axes((0.38, 0.60, 0.01, 0.21))  # tweaked for this particular example
                plt.colorbar(im, cax=cax)


def display_results(results: list[SaliencyResults], labels: dict[int, str]) -> None:
    """Displays saliency results with plots and a table."""
    num_classes = len(labels)

    for res in results:
        plt.figure(figsize=(10, 5))
        num_cols = np.ceil(num_classes / 2).astype(int) + 1
        pred = f"{labels[res.pred_class]} ({res.pred_prob:.2f})"
        _plot_img(res.ref_img, num_cols, f"Ref Img\nGT: {labels[res.gt]}\nPred: {pred}")
        _plot_rows(res.ref_sal_maps, num_cols, [res.gt])

        for pert in res.perturbations:
            plt.figure(figsize=(10, 5))
            pred = f"{pert.pred_class} ({pert.pred_prob:.2f})"
            _plot_img(pert.img, num_cols, f"{pert.descriptor}\nPred: {pred}")
            _plot_rows(pert.sal_maps, num_cols, [res.gt])

    plt.show()

Finally, we’ll define the “application”, which perturbs the given input image(s) to varying degrees and generates saliency maps. In this case, we’ll perturb the images using a pyBSM based perturber. To easily apply this perturbation, we’ll use the JitterOTFPerturber, which simulates varying amounts of sensor jitter on image collection.

def _max_class(probs: dict) -> str:
    v = list(probs.values())
    k = list(probs.keys())
    return k[v.index(max(v))]


def _generate_augmented_maps(
    idx: int,
    additional_params: list[dict[str, Any]],
    res: SaliencyResults,
    img: np.ndarray,
    num_images: int,
    image_classifier: ClassifyImage,
    saliency_generator: GenerateImageClassifierBlackboxSaliency,
) -> None:
    for k in additional_params:
        print(f"Generating saliency maps for s_y={k['s_y']} (ref image {idx + 1} of {num_images})")
        xform = JitterOTFPerturber(**k)
        img_out, _ = xform(np.copy(img))
        sal_maps = saliency_generator(img_out, image_classifier)
        probs = next(image_classifier.classify_images(np.expand_dims(img_out, axis=0)))
        pred_class = _max_class(probs)

        pert = PerturbationResult(
            descriptor=f"ksize={k}",
            img=img_out,
            sal_maps=sal_maps,
            pred_class=labels[pred_class],
            pred_prob=probs[pred_class],
        )

        res.perturbations.append(pert)


def generate_perturbed_sal_maps(
    images: np.ndarray,
    ground_truth: list[int],
    image_classifier: ClassifyImage,
    saliency_generator: GenerateImageClassifierBlackboxSaliency,
    additional_params: list[dict[str, Any]],
    display_labels: dict[int, str],
    display_maps: bool = True,
    metrics: tuple[str, ...] = ("Pos Saliency Entropy", "Neg Saliency Entropy", "Entropy", "SSD", "XCorr"),
) -> list[SaliencyResults]:
    """Generate saliency maps for image."""
    # Get class labels
    labels = image_classifier.get_labels()

    # Generate saliency maps
    results = list()
    for idx, img in enumerate(images):
        print(f"Generating saliency maps for reference image (image {idx + 1} of {len(images)})")
        sal_maps = saliency_generator(img, image_classifier)
        probs = next(image_classifier.classify_images(np.expand_dims(img, axis=0)))
        pred_class = _max_class(probs)
        res = SaliencyResults(
            ref_img=np.copy(img),
            ref_sal_maps=sal_maps,
            gt=ground_truth[idx],
            pred_class=labels.index(pred_class),
            pred_prob=probs[pred_class],
        )

        _generate_augmented_maps(idx, additional_params, res, img, len(images), image_classifier, saliency_generator)

        results.append(res)

    for result in results:
        # Plot each image in set with saliency maps
        if display_maps:
            display_results([result], display_labels)

        # Compute metrics
        compute_metrics([result], metrics)

    return results
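
For intuition about what a jitter-style perturbation does to an image, here is a minimal NumPy-only sketch. It is not pyBSM's or nrtk's implementation. It assumes that jitter along one axis behaves roughly like a one-dimensional Gaussian blur, and the helper name `gaussian_blur_rows` and the parameter `sigma_px` (blur width in pixels) are hypothetical. How a value such as `s_y` maps to a pixel-level blur depends on sensor and scenario parameters that are not modeled here.

```python
import numpy as np


def gaussian_blur_rows(img: np.ndarray, sigma_px: float) -> np.ndarray:
    """Blur a 2D grayscale image along the vertical axis with a Gaussian kernel."""
    if sigma_px <= 0:
        return img.astype(float)

    # Build a normalized Gaussian kernel covering roughly +/- 3 sigma.
    radius = max(1, int(np.ceil(3 * sigma_px)))
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma_px) ** 2)
    kernel /= kernel.sum()

    # Pad by repeating edge rows so the output keeps the input height.
    padded = np.pad(img.astype(float), ((radius, radius), (0, 0)), mode="edge")

    # Convolve each column independently (blur along the vertical direction only).
    columns = [np.convolve(padded[:, c], kernel, mode="valid") for c in range(img.shape[1])]
    return np.stack(columns, axis=1)


# A single bright pixel spreads vertically but not horizontally.
toy = np.zeros((9, 3))
toy[4, 1] = 1.0
blurred = gaussian_blur_rows(toy, sigma_px=1.0)
print(np.round(blurred[:, 1], 3))
```

Increasing `sigma_px` spreads the bright pixel over more rows, which is the kind of progressive degradation the application above sweeps over.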

Running the “Application”

Classifier

We’ll use a Hugging Face model conforming to the maite image classification protocol, along with the relevant xaitk-saliency adapter.

class HuggingFaceClassifier:
    """MAITE wrapper for HuggingFaceClassifier."""

    def __init__(self, model_name: str, device: str) -> None:
        """Initialize HuggingFaceClassifier."""
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        self.device = device

        self.model.eval()
        self.model.to(device)
        self.metadata = self.model.config.id2label

    def __call__(self, batch: Sequence[ic.InputType]) -> Sequence[ic.TargetType]:
        """Run classifier for batch and return results."""
        # tensor bridging
        input_tensor = torch.as_tensor(batch)
        if input_tensor.ndim != 4:
            raise ValueError(f"Invalid input dimensions. Expected 4, got {input_tensor.ndim}")

        # preprocess
        hf_inputs = self.image_processor(input_tensor, return_tensors="pt")

        # put on device
        hf_inputs = hf_inputs.to(self.device)

        # get predictions
        with torch.no_grad():
            return self.model(**hf_inputs).logits.softmax(1).detach().cpu()


jatic_classifier: ic.Model = HuggingFaceClassifier(
    model_name="aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
ids = [int(k) for k in jatic_classifier.metadata]
classifier = JATICImageClassifier(classifier=jatic_classifier, ids=ids)

Saliency Generator

We’ll use the SlidingWindowStack blackbox saliency generator.

sal_generator = SlidingWindowStack(window_size=(2, 2), stride=(1, 1), threads=4)
sal_generator.fill = (128, 128, 128)
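
To illustrate the general idea behind sliding-window blackbox saliency, the sketch below occludes one window at a time and records how much a score drops. It is a simplified stand-in, not xaitk-saliency's implementation: the function `occlusion_saliency`, the toy scoring function, and the averaging over overlapping windows are all assumptions made for this example.

```python
import numpy as np


def occlusion_saliency(img, score_fn, window=(2, 2), stride=(1, 1), fill=0.5):
    """Average score drop caused by occluding each window that covers a pixel."""
    h, w = img.shape
    ref_score = score_fn(img)
    sal = np.zeros((h, w), dtype=float)
    counts = np.zeros((h, w), dtype=float)

    for top in range(0, h - window[0] + 1, stride[0]):
        for left in range(0, w - window[1] + 1, stride[1]):
            rows = slice(top, top + window[0])
            cols = slice(left, left + window[1])

            # Replace the window with the fill value and re-score the image.
            occluded = img.copy()
            occluded[rows, cols] = fill
            drop = ref_score - score_fn(occluded)

            # Positive drop: the occluded region supported the score.
            sal[rows, cols] += drop
            counts[rows, cols] += 1

    return sal / np.maximum(counts, 1)


# Toy "classifier": confidence is simply the brightness of the top-left pixel.
def toy_score(x: np.ndarray) -> float:
    return float(x[0, 0])


toy_img = np.zeros((4, 4))
toy_img[0, 0] = 1.0
sal_map = occlusion_saliency(toy_img, toy_score)
print(sal_map)
```

In this toy case the saliency concentrates on the top-left pixel, the only one the score depends on. A small window with stride 1, as configured above, gives fine-grained maps at the cost of many classifier calls.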

Results

Note: for clarity, we’ll only be performing the saliency analysis with respect to the groundtruth class, but this analysis could also be applied to the predicted class, which may be useful in cases where the groundtruth and predictions may differ.

results = generate_perturbed_sal_maps(
    images=imgs,
    ground_truth=ground_truth,
    image_classifier=classifier,
    saliency_generator=sal_generator,
    display_labels={int(k): str(v) for k, v in jatic_classifier.metadata.items()},
    additional_params=[{"s_y": 1e-4, "s_x": 0.0}, {"s_y": 1.2e-4, "s_x": 0.0}, {"s_y": 1.4e-4, "s_x": 0.0}],
)
Generating saliency maps for reference image (image 1 of 2)
/tmp/ipykernel_1253135/2834420211.py:17: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:254.)
  input_tensor = torch.as_tensor(batch)
Generating saliency maps for s_y=0.0001 (ref image 1 of 2)
Generating saliency maps for s_y=0.00012 (ref image 1 of 2)
Generating saliency maps for s_y=0.00014 (ref image 1 of 2)
Generating saliency maps for reference image (image 2 of 2)
Generating saliency maps for s_y=0.0001 (ref image 2 of 2)
Generating saliency maps for s_y=0.00012 (ref image 2 of 2)
Generating saliency maps for s_y=0.00014 (ref image 2 of 2)
[Figures: first reference image and its three jitter-perturbed versions, each shown with the saliency map for the ground-truth class]
Perturbation                          Pos Saliency Entropy    Neg Saliency Entropy    Entropy      SSD       XCorr
Ref Image                                              nan                     nan        nan  0         1
ksize={'s_y': 0.0001, 's_x': 0.0}                      nan                     nan        nan  1.96528   0.0305445
ksize={'s_y': 0.00012, 's_x': 0.0}                     nan                     nan        nan  2.19786  -0.0597571
ksize={'s_y': 0.00014, 's_x': 0.0}                     nan                     nan        nan  2.35233  -0.124605
[Figures: second reference image and its three jitter-perturbed versions, each shown with the saliency map for the ground-truth class]
Perturbation                          Pos Saliency Entropy    Neg Saliency Entropy    Entropy       SSD     XCorr
Ref Image                                              nan                     nan        nan  0         1
ksize={'s_y': 0.0001, 's_x': 0.0}                      nan                     nan        nan  0.930901  0.437871
ksize={'s_y': 0.00012, 's_x': 0.0}                     nan                     nan        nan  1.06851   0.292291
ksize={'s_y': 0.00014, 's_x': 0.0}                     nan                     nan        nan  1.14059   0.221283

We can visually see that as the quality of the input image degrades (more perturbation), the quality of the generated saliency maps similarly degrades.

In an attempt to quantify these differences, we’ve also computed several metrics:

Entropy

If we compare entropy values using both positive and negative saliency we don’t see much of a change across degradations. This is likely due to negative and positive saliency “fighting” each other as degradation increases (as one increases, the other decreases).

If we consider entropy values computed from only positive or only negative saliency values, we see differences in values. Looking at the dominant saliency type (i.e. positive/blue) for the ship, we can see that as degradation gets worse, entropy increases, but only up to a certain point. The reduction in entropy likely corresponds to the classifier being less able to identify the key features that led to the original probability distribution for the reference image, due to the degradation. Eventually these features may become so degraded that the classifier begins predicting with very low confidence. Looking at the dominant saliency type (i.e. negative/red) for the cat, we see changes similar to those for the other reference image; however, the positive and negative saliencies are significantly closer. This could be a sign of less definitive features being used for identification. Since this does not c

If we look at the opposite saliency type for each reference image, we see a very slight decrease in entropy as degradation increases. This potentially indicates that the degradation introduces noise that the classifier misidentifies as a contraindicator for the ground truth class, but more likely corresponds to the classifier predicting with less confidence due to the loss in higher quality features.
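
To make the positive/negative split concrete, here is a minimal sketch on a toy saliency map with made-up values. It follows the same clipping idea as `_compute_entropy` above, but uses plain NumPy rather than `scipy.stats.entropy`; the helper name `shannon_entropy_bits` is an assumption for this example, and it takes absolute values so that an all-negative clipped map normalizes to a valid distribution.

```python
import numpy as np


def shannon_entropy_bits(weights: np.ndarray) -> float:
    """Shannon entropy (base 2) of same-signed weights, normalized to sum to 1."""
    w = np.abs(weights.ravel()).astype(float)
    total = w.sum()
    if total == 0:
        return float("nan")  # no saliency mass of this sign
    p = w / total
    p = p[p > 0]  # zero-probability entries contribute nothing
    return float((p * np.log2(1.0 / p)).sum())


# Toy saliency map: one strong positive pixel, two equal negative pixels.
toy_map = np.array([[0.8, 0.0], [-0.2, -0.2]])

pos_entropy = shannon_entropy_bits(np.clip(toy_map, 0, 1))   # positive saliency only
neg_entropy = shannon_entropy_bits(np.clip(toy_map, -1, 0))  # negative saliency only
print(pos_entropy, neg_entropy)
```

Here the positive saliency is concentrated in a single pixel (entropy of 0 bits), while the negative saliency is split evenly over two pixels (1 bit). Saliency that spreads out over more pixels yields higher entropy.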

Sum of Squared Differences (SSD)

(0 is most similar) The sum of squared differences lets us quantitatively confirm that as degradation gets worse, saliency maps become increasingly dissimilar to the original reference saliency map. However, the metric doesn’t give us much insight into what is actually happening to create these differences.
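
As a small worked example of the normalized SSD used by `_compute_ssd` above, the sketch below evaluates it on a toy map with made-up values. The function name `normalized_ssd` is hypothetical; the formula is the same squared difference divided by the product of the two maps' norms.

```python
import numpy as np


def normalized_ssd(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    """Sum of squared differences, normalized by the product of the maps' L2 norms."""
    num = np.sum((sal_map - ref_sal_map) ** 2)
    den = np.sqrt(np.sum(sal_map**2) * np.sum(ref_sal_map**2))
    return float("inf") if den == 0 else float(num / den)


ref = np.array([[1.0, 0.0], [0.0, -1.0]])

same_score = normalized_ssd(ref, ref)      # identical maps
flipped_score = normalized_ssd(-ref, ref)  # every saliency value sign-flipped
print(same_score, flipped_score)
```

For two maps with equal norm, this score equals 2 - 2·cos(θ), where θ is the angle between the flattened maps. It therefore ranges from 0 for identical maps to 4 for exactly opposite maps, and values above 2 indicate maps that point in broadly opposing directions.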

Cross-Correlation (XCorr)

(1 is most similar) Cross-correlation gives us similar information to SSD. The introduction of negative correlation values, however, potentially indicates that the saliency maps begin to become the “opposite” of the original reference saliency maps. This aligns with the pattern we saw with positive/negative saliency entropy: we see an introduction of the opposite saliency as the image becomes more degraded and the classifier becomes more confused. This likely doesn’t occur with the ship reference image because it was more strongly salient in one direction, whereas the cat reference image contained a more balanced mix of both positive and negative saliency.
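
The behavior at the two extremes can be checked on a toy map with made-up values. The sketch below computes a Pearson correlation over flattened maps, which is what `np.corrcoef` returns in `_compute_xcorr` above. The function name `pearson_xcorr` and the choice to return 0.0 when either map is constant are assumptions for this example.

```python
import numpy as np


def pearson_xcorr(sal_map: np.ndarray, ref_sal_map: np.ndarray) -> float:
    """Pearson correlation between two flattened saliency maps."""
    a = sal_map.ravel() - sal_map.mean()
    b = ref_sal_map.ravel() - ref_sal_map.mean()
    denom = np.sqrt(np.sum(a**2) * np.sum(b**2))
    if denom == 0:
        return 0.0  # assumption: a constant map is treated as uncorrelated
    return float(np.sum(a * b) / denom)


ref = np.array([[0.9, 0.1], [-0.1, -0.9]])

same_corr = pearson_xcorr(ref, ref)      # identical maps
flipped_corr = pearson_xcorr(-ref, ref)  # positive and negative saliency swapped
print(same_corr, flipped_corr)
```

Identical maps correlate at 1 and sign-flipped maps at -1, so a drift from positive toward negative values means that regions which once supported the class increasingly count against it.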

While aspects of model explainability and model robustness have previously been studied independently, this notebook demonstrates a preliminary exploration of their relationship through quantification of how saliency maps change due to various perturbations. Future work will explore whether quantitative changes in the structure and quality of saliency maps provide a mechanism for understanding model failure modes and edge cases due to various perturbations.