Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Configurations.jl for kwarg handling #9

Merged
merged 7 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ version = "1.3.1"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[compat]
ColorSchemes = "3"
Configurations = "0.17"
ImageCore = "0.9, 0.10"
ImageTransformations = "0.10"
Interpolations = "0.15"
Expand Down
7 changes: 4 additions & 3 deletions src/VisionHeatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ using ImageTransformations: imresize
using Interpolations: Lanczos
using ImageCore
using XAIBase: Explanation, AbstractXAIMethod, analyze
using Configurations: @option

include("heatmap.jl")
include("overlay.jl")
include("xaibase.jl")
include("config.jl") # HeatmapOptions
include("heatmap.jl") # heatmap
include("overlay.jl") # heatmap_overlay

export heatmap, heatmap_overlay

Expand Down
39 changes: 39 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered

@option struct HeatmapOptions
colorscheme::Union{ColorScheme,Symbol} = DEFAULT_COLORSCHEME
reduce::Symbol = DEFAULT_REDUCE
rangescale::Symbol = DEFAULT_RANGESCALE
permute::Bool = true
process_batch::Bool = false
unpack_singleton::Bool = true
end

get_colorscheme(options::HeatmapOptions) = get_colorscheme(options.colorscheme)
get_colorscheme(c::ColorScheme) = c
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]

#=================#
# XAIBase support #
#=================#

const HEATMAP_PRESETS = Dict{
Symbol,@NamedTuple{colorscheme::Symbol, reduce::Symbol, rangescale::Symbol}
}(
:attribution => (colorscheme=:seismic, reduce=:sum, rangescale=:centered),
:sensitivity => (colorscheme=:grays, reduce=:norm, rangescale=:extrema),
:cam => (colorscheme=:jet, reduce=:sum, rangescale=:extrema),
)
const DEFAULT_HEATMAP_PRESET = (
colorscheme=DEFAULT_COLORSCHEME, reduce=DEFAULT_REDUCE, rangescale=DEFAULT_RANGESCALE
)

# Override HeatmapOptions preset with keyword arguments
function HeatmapOptions(expl::Explanation; kwargs...)
c = get(HEATMAP_PRESETS, expl.heatmap, DEFAULT_HEATMAP_PRESET)
return HeatmapOptions(;
colorscheme=c.colorscheme, reduce=c.reduce, rangescale=c.rangescale, kwargs...
)
end
83 changes: 50 additions & 33 deletions src/heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered

const InputDimensionError = ArgumentError(
"heatmap assumes the WHCN convention for input array dimensions (width, height, color channels, batch dimension).
Please reshape your input to match this format if your model doesn't adhere to this convention.",
)

"""
heatmap(x::AbstractArray)

Expand All @@ -23,54 +28,34 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
- `rangescale::Symbol`: Selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE`.
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true`
will apply the `rangescale` normalization to the entire batch
instead of computing it individually for each sample in the batch.
Defaults to `false`.
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
- `unpack_singleton::Bool`: If false, `heatmap` will always return a vector of images.
When heatmapping a batch with a single sample, setting `unpack_singleton=true`
will unpack the singleton vector and directly return the image. Defaults to `true`.
"""
function heatmap(
val::AbstractArray{T,N};
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
reduce::Symbol=DEFAULT_REDUCE,
rangescale::Symbol=DEFAULT_RANGESCALE,
permute::Bool=true,
unpack_singleton::Bool=true,
process_batch::Bool=false,
) where {T,N}
heatmap(val; kwargs...) = heatmap(val, HeatmapOptions(; kwargs...))
function heatmap(val::AbstractArray{T,N}, options::HeatmapOptions) where {T,N}
N != 4 && throw(InputDimensionError)
colorscheme = get_colorscheme(colorscheme)
if unpack_singleton && size(val, 4) == 1
return single_heatmap(val[:, :, :, 1], colorscheme, reduce, rangescale, permute)
if options.unpack_singleton && size(val, 4) == 1
return single_heatmap(val[:, :, :, 1], options)
end
if process_batch
hs = single_heatmap(val, colorscheme, reduce, rangescale, permute)
if options.process_batch
hs = single_heatmap(val, options)
return [hs[:, :, i] for i in axes(hs, 3)]
end
return [
single_heatmap(v, colorscheme, reduce, rangescale, permute) for
v in eachslice(val; dims=4)
]
return [single_heatmap(v, options) for v in eachslice(val; dims=4)]
end

const InputDimensionError = ArgumentError(
"heatmap assumes the WHCN convention for input array dimensions (width, height, color channels, batch dimension).
Please reshape your input to match this format if your model doesn't adhere to this convention.",
)

get_colorscheme(c::ColorScheme) = c
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]

# Lower level function, mapped along batch dimension
function single_heatmap(
val, colorscheme::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool
)
img = dropdims(reduce_color_channel(val, reduce); dims=3)
permute && (img = flip_wh(img))
return get(colorscheme, img, rangescale)
function single_heatmap(val, options::HeatmapOptions)
img = dropdims(reduce_color_channel(val, options.reduce); dims=3)
options.permute && (img = flip_wh(img))
cs = get_colorscheme(options)
return get(cs, img, options.rangescale)
end

flip_wh(img::AbstractArray{T,2}) where {T} = permutedims(img, (2, 1))
Expand Down Expand Up @@ -98,3 +83,35 @@ function reduce_color_channel(val::AbstractArray, method::Symbol)
),
)
end

#=================#
# XAIBase support #
#=================#

"""
heatmap(expl::Explanation)

Visualize `Explanation` from XAIBase as a vision heatmap.
Assumes WHCN convention (width, height, channels, batch dimension) for `explanation.val`.

This will use the default heatmapping style for the given type of explanation.
Defaults can be overridden via keyword arguments.
"""
heatmap(expl::Explanation; kwargs...) = heatmap(expl.val, HeatmapOptions(expl; kwargs...))

"""
heatmap(input::AbstractArray, analyzer::AbstractXAIMethod)

Compute an `Explanation` for a given `input` using the XAI method `analyzer` and visualize it
as a vision heatmap.

Any additional arguments and keyword arguments are passed to the analyzer.
Refer to the `analyze` documentation for more information on available keyword arguments.

To customize the heatmapping style, first compute an explanation using `analyze`
and then call [`heatmap`](@ref) on the explanation.
"""
function heatmap(input, analyzer::AbstractXAIMethod, analyze_args...; analyze_kwargs...)
expl = analyze(input, analyzer, analyze_args...; analyze_kwargs...)
return heatmap(expl)
end
9 changes: 6 additions & 3 deletions src/overlay.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const DEFAULT_OVERLAY_ALPHA = 0.6
const DEFAULT_RESIZE_METHOD = Lanczos(1)

"""
heatmap_overlay(val, img)
Expand All @@ -19,8 +20,8 @@ function heatmap_overlay(
val::AbstractArray{T,N},
im::AbstractMatrix{<:Colorant};
alpha=DEFAULT_OVERLAY_ALPHA,
resize_method=Lanczos(1),
kwargs...,
resize_method=DEFAULT_RESIZE_METHOD,
heatmap_kwargs...,
) where {T,N}
N != 4 && throw(InputDimensionError)
if size(val, 4) != 1
Expand All @@ -33,7 +34,9 @@ function heatmap_overlay(
if alpha < 0 || alpha > 1
throw(ArgumentError("alpha must be in the range [0, 1]"))
end
hm = heatmap(val; kwargs...)

options = HeatmapOptions(; heatmap_kwargs...)
hm = heatmap(val, options)
hmsize = size(hm)
imsize = size(im)
if hmsize != imsize
Expand Down
67 changes: 0 additions & 67 deletions src/xaibase.jl

This file was deleted.

1 change: 1 addition & 0 deletions test/references/abssum_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
Loading
Loading