# 2D Landmark → Heatmap (MNIST / MedMNIST) — with MONAI
Interactive demo converting clicked landmarks to Gaussian heatmaps **using MONAI**.

**Modes**
1) Array transform: `GenerateHeatmap`
2) Dict transform: `GenerateHeatmapd` with optional `MetaTensor` reference

In [None]:
# Installation requirements for interactive notebook
# %pip install --upgrade pip
# %pip install torch torchvision monai medmnist matplotlib ipywidgets
#
# For JupyterLab users, also run:
# %pip install jupyterlab-widgets
#
# For interactive matplotlib (optional, used in first implementation):
# %pip install ipympl

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from monai.transforms.post.array import GenerateHeatmap
from monai.transforms.post.dictionary import GenerateHeatmapd
from monai.data import MetaTensor

# medmnist is an optional dependency: record whether it can be imported so
# later cells can fall back to MNIST when it is missing.
try:
    import medmnist
    from medmnist import PathMNIST
except Exception:
    HAS_MEDMNIST = False
    print("medmnist not available. Run `pip install medmnist` to enable PathMNIST.")
else:
    HAS_MEDMNIST = True
    print("medmnist is available.")

In [None]:
# Pick one small 2D grayscale image: PathMNIST when requested and installed,
# otherwise the first MNIST test sample.
use_medmnist = False
if not (use_medmnist and HAS_MEDMNIST):
    mnist = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
    first_sample_image = mnist[0][0]  # image part of the first (image, label) pair
    img = first_sample_image[0].numpy().astype(np.float32)  # first channel only
else:
    ds = PathMNIST(split="test", download=True, as_rgb=True)
    rgb_image = np.asarray(ds[0][0])
    img = rgb_image.mean(axis=2).astype(np.float32)  # average over axis 2 to get one channel

# Rescale so the brightest pixel is 1.0; an all-zero image is left untouched.
img_peak = float(img.max())
if img_peak > 0:
    img = img / img_peak
H, W = img.shape
print(f"Image shape: {H}x{W}")

In [None]:
# Heatmap helper using GenerateHeatmap
# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.
# This wrapper accepts user-friendly (x, y) and reorders to (y, x) exactly once.
# Scalars, 1D sequences of length N, and batched (B, N) arrays are accepted.

sigma = 3.0


def heatmap_with_array_transform(x, y, sigma_override=None):
    """Build Gaussian heatmaps for (x, y) landmarks with the array transform.

    Args:
        x: Column coordinate(s): a scalar, a sequence of N values, or a
            (B, N) array for batched input.
        y: Row coordinate(s) with the same shape as ``x``.
        sigma_override: Gaussian sigma for this call; the module-level
            ``sigma`` is used when this is ``None``.

    Returns:
        The output of ``GenerateHeatmap`` for points of shape (B, N, 2),
        expected to be laid out as (B, N, H, W).

    Raises:
        ValueError: If ``x`` and ``y`` do not have matching shapes.
    """
    s = float(sigma) if sigma_override is None else float(sigma_override)
    tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))
    # atleast_1d lets a single landmark be passed as plain scalars.
    xs = np.atleast_1d(np.asarray(x, dtype=np.float32))
    ys = np.atleast_1d(np.asarray(y, dtype=np.float32))
    # Stack as (y, x) on the last axis. This is the only reordering step:
    # swapping columns again afterwards would hand (x, y) to the transform.
    pts_yx = np.stack([ys, xs], axis=-1)
    if pts_yx.ndim == 2:
        pts_yx = pts_yx[np.newaxis, ...]  # Add batch dimension: (N, 2) -> (1, N, 2)
    return tr(pts_yx)  # (B, N, H, W)


In [None]:
# Reference image for the dict transform: the demo image with a leading
# channel axis, wrapped in a MetaTensor carrying an identity affine.
affine = torch.eye(4)
ref_img = MetaTensor(torch.from_numpy(img)[None, ...], affine=affine)
# Record the 2D size in the metadata alongside the affine.
ref_img.meta["spatial_shape"] = (H, W)

# Dictionary version wrapper also accepts (x, y) and converts to (y, x) exactly once.


def heatmap_with_dict_transform(x, y, sigma_override=None, use_ref=True):
    """Build Gaussian heatmaps for (x, y) landmarks with the dict transform.

    Args:
        x: Column coordinate(s): a scalar, a sequence of N values, or a
            (B, N) array for batched input.
        y: Row coordinate(s) with the same shape as ``x``.
        sigma_override: Gaussian sigma for this call; the module-level
            ``sigma`` is used when this is ``None``.
        use_ref: When true, the transform is configured with
            ``ref_image_keys="ref"`` and no explicit ``spatial_shape``;
            otherwise it is given ``spatial_shape=(H, W)`` and no reference key.

    Returns:
        The ``"heatmap"`` entry produced by ``GenerateHeatmapd``.

    Raises:
        ValueError: If ``x`` and ``y`` do not have matching shapes.
    """
    s = float(sigma) if sigma_override is None else float(sigma_override)
    tr = GenerateHeatmapd(
        keys="points",
        heatmap_keys="heatmap",
        ref_image_keys="ref" if use_ref else None,
        spatial_shape=None if use_ref else (H, W),
        sigma=s,
    )
    # atleast_1d lets a single landmark be passed as plain scalars.
    xs = np.atleast_1d(np.asarray(x, dtype=np.float32))
    ys = np.atleast_1d(np.asarray(y, dtype=np.float32))
    # Stack as (y, x) on the last axis. This is the only reordering step:
    # swapping columns again afterwards would hand (x, y) to the transform.
    pts_yx = np.stack([ys, xs], axis=-1)
    if pts_yx.ndim == 2:
        pts_yx = pts_yx[np.newaxis, ...]  # Add batch dimension: (N, 2) -> (1, N, 2)
    data = {"points": pts_yx, "ref": ref_img}
    out = tr(data)
    return out["heatmap"]


In [None]:
# Random landmarks → heatmaps, with no interactivity.
# Every run of this cell draws fresh random points and rebuilds the heatmaps.
# NOTE: points are sampled as (x, y) for readability; GenerateHeatmap takes (row=y, col=x).

import random

# Parameters
num_points = 3  # number of random landmarks
sigma_demo = 3.0  # Gaussian sigma
combine_mode = "max"  # or 'sum'
batched_input = True  # Set to True to test batched input

# Draw (x, y) pairs inside the image: x from [0, W - 1], y from [0, H - 1]
points_xy = np.array(
    [[random.uniform(0, W - 1), random.uniform(0, H - 1)] for _ in range(num_points)],
    dtype=np.float32,
)  # (N,2)
print("Random points (x, y):")
print(points_xy)

# Reverse the column order once so each row reads (y, x) for the transform
yx_points = points_xy[:, ::-1].copy()
if batched_input:
    yx_points = np.expand_dims(yx_points, axis=0)  # leading batch axis

array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))
heatmaps = array_tr(yx_points)

if batched_input:
    heatmaps = heatmaps.squeeze(0)  # drop the batch axis again before plotting

# Collapse the per-landmark channels into a single map
if combine_mode not in ("max", "sum"):
    raise ValueError("combine_mode must be 'max' or 'sum'")
if combine_mode == "sum":
    combined = heatmaps.sum(axis=0)
    combined_peak = combined.max()
    if combined_peak > 0:
        # Rescale the summed map so its largest value is 1
        combined = combined / combined_peak
else:
    combined = heatmaps.max(axis=0)

# Plot: base image with landmarks (left) next to the combined heatmap overlay (right)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_base, ax_overlay = axes
gray_style = {"cmap": "gray", "vmin": 0.0, "vmax": 1.0, "origin": "upper"}
cross_style = {"markersize": 12, "markeredgewidth": 2}

ax_base.imshow(img, **gray_style)
ax_base.set_title("Base Image")
ax_base.set_axis_off()

# The overlay panel draws the same image first, then the semi-transparent heatmap on top
ax_overlay.imshow(img, **gray_style)
ax_overlay.imshow(combined, alpha=0.6, cmap="hot", origin="upper")
ax_overlay.set_title(f"Combined Heatmap (mode={combine_mode}, sigma={sigma_demo})")
ax_overlay.set_axis_off()

# Mark every landmark on both panels: red crosses on the left, cyan on the right
for x, y in points_xy:
    ax_base.plot(x, y, "r+", **cross_style)
    ax_overlay.plot(x, y, "c+", **cross_style)

plt.tight_layout()
plt.show()

# Individual channels: one panel per landmark heatmap
fig2, axes2 = plt.subplots(1, num_points, figsize=(4 * num_points, 4))
# With a single panel, subplots returns a bare Axes; wrap it so the loop still works
axes2 = [axes2] if num_points == 1 else axes2
channel_marker = {"markersize": 12, "markeredgewidth": 2}
for i, ax in enumerate(axes2):
    landmark_x, landmark_y = points_xy[i, 0], points_xy[i, 1]
    ax.imshow(heatmaps[i], cmap="hot", origin="upper")
    ax.plot(landmark_x, landmark_y, "w+", **channel_marker)
    ax.set_title(f"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})")
    ax.set_axis_off()
plt.tight_layout()
plt.show()
