Skip to content

Commit

Permalink
Use Configurations.jl for kwarg handling (#9)
Browse files Browse the repository at this point in the history
* Add `HeatmapOptions` struct via Configurations.jl to handle keyword arguments

* Reorganized source code, adding config.jl file, removing xaibase.jl and moving all `heatmap` methods into heatmap.jl
  • Loading branch information
adrhill authored Feb 20, 2024
1 parent 0ade8c7 commit 034bffd
Show file tree
Hide file tree
Showing 98 changed files with 263 additions and 133 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ version = "1.3.1"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[compat]
ColorSchemes = "3"
Configurations = "0.17"
ImageCore = "0.9, 0.10"
ImageTransformations = "0.10"
Interpolations = "0.15"
Expand Down
7 changes: 4 additions & 3 deletions src/VisionHeatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ using ImageTransformations: imresize
using Interpolations: Lanczos
using ImageCore
using XAIBase: Explanation, AbstractXAIMethod, analyze
using Configurations: @option

include("heatmap.jl")
include("overlay.jl")
include("xaibase.jl")
include("config.jl") # HeatmapOptions
include("heatmap.jl") # heatmap
include("overlay.jl") # heatmap_overlay

export heatmap, heatmap_overlay

Expand Down
39 changes: 39 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered

@option struct HeatmapOptions
colorscheme::Union{ColorScheme,Symbol} = DEFAULT_COLORSCHEME
reduce::Symbol = DEFAULT_REDUCE
rangescale::Symbol = DEFAULT_RANGESCALE
permute::Bool = true
process_batch::Bool = false
unpack_singleton::Bool = true
end

get_colorscheme(options::HeatmapOptions) = get_colorscheme(options.colorscheme)
get_colorscheme(c::ColorScheme) = c
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]

#=================#
# XAIBase support #
#=================#

const HEATMAP_PRESETS = Dict{
Symbol,@NamedTuple{colorscheme::Symbol, reduce::Symbol, rangescale::Symbol}
}(
:attribution => (colorscheme=:seismic, reduce=:sum, rangescale=:centered),
:sensitivity => (colorscheme=:grays, reduce=:norm, rangescale=:extrema),
:cam => (colorscheme=:jet, reduce=:sum, rangescale=:extrema),
)
const DEFAULT_HEATMAP_PRESET = (
colorscheme=DEFAULT_COLORSCHEME, reduce=DEFAULT_REDUCE, rangescale=DEFAULT_RANGESCALE
)

# Override HeatmapOptions preset with keyword arguments
function HeatmapOptions(expl::Explanation; kwargs...)
c = get(HEATMAP_PRESETS, expl.heatmap, DEFAULT_HEATMAP_PRESET)
return HeatmapOptions(;
colorscheme=c.colorscheme, reduce=c.reduce, rangescale=c.rangescale, kwargs...
)
end
83 changes: 50 additions & 33 deletions src/heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered

const InputDimensionError = ArgumentError(
"heatmap assumes the WHCN convention for input array dimensions (width, height, color channels, batch dimension).
Please reshape your input to match this format if your model doesn't adhere to this convention.",
)

"""
heatmap(x::AbstractArray)
Expand All @@ -23,54 +28,34 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
- `rangescale::Symbol`: Selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE`.
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true`
will apply the `rangescale` normalization to the entire batch
instead of computing it individually for each sample in the batch.
Defaults to `false`.
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
- `unpack_singleton::Bool`: If false, `heatmap` will always return a vector of images.
When heatmapping a batch with a single sample, setting `unpack_singleton=true`
will unpack the singleton vector and directly return the image. Defaults to `true`.
"""
function heatmap(
val::AbstractArray{T,N};
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
reduce::Symbol=DEFAULT_REDUCE,
rangescale::Symbol=DEFAULT_RANGESCALE,
permute::Bool=true,
unpack_singleton::Bool=true,
process_batch::Bool=false,
) where {T,N}
heatmap(val; kwargs...) = heatmap(val, HeatmapOptions(; kwargs...))
function heatmap(val::AbstractArray{T,N}, options::HeatmapOptions) where {T,N}
N != 4 && throw(InputDimensionError)
colorscheme = get_colorscheme(colorscheme)
if unpack_singleton && size(val, 4) == 1
return single_heatmap(val[:, :, :, 1], colorscheme, reduce, rangescale, permute)
if options.unpack_singleton && size(val, 4) == 1
return single_heatmap(val[:, :, :, 1], options)
end
if process_batch
hs = single_heatmap(val, colorscheme, reduce, rangescale, permute)
if options.process_batch
hs = single_heatmap(val, options)
return [hs[:, :, i] for i in axes(hs, 3)]
end
return [
single_heatmap(v, colorscheme, reduce, rangescale, permute) for
v in eachslice(val; dims=4)
]
return [single_heatmap(v, options) for v in eachslice(val; dims=4)]
end

const InputDimensionError = ArgumentError(
"heatmap assumes the WHCN convention for input array dimensions (width, height, color channels, batch dimension).
Please reshape your input to match this format if your model doesn't adhere to this convention.",
)

get_colorscheme(c::ColorScheme) = c
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]

# Lower level function, mapped along batch dimension
function single_heatmap(
val, colorscheme::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool
)
img = dropdims(reduce_color_channel(val, reduce); dims=3)
permute && (img = flip_wh(img))
return get(colorscheme, img, rangescale)
function single_heatmap(val, options::HeatmapOptions)
img = dropdims(reduce_color_channel(val, options.reduce); dims=3)
options.permute && (img = flip_wh(img))
cs = get_colorscheme(options)
return get(cs, img, options.rangescale)
end

flip_wh(img::AbstractArray{T,2}) where {T} = permutedims(img, (2, 1))
Expand Down Expand Up @@ -98,3 +83,35 @@ function reduce_color_channel(val::AbstractArray, method::Symbol)
),
)
end

#=================#
# XAIBase support #
#=================#

"""
heatmap(expl::Explanation)
Visualize `Explanation` from XAIBase as a vision heatmap.
Assumes WHCN convention (width, height, channels, batch dimension) for `explanation.val`.
This will use the default heatmapping style for the given type of explanation.
Defaults can be overridden via keyword arguments.
"""
heatmap(expl::Explanation; kwargs...) = heatmap(expl.val, HeatmapOptions(expl; kwargs...))

"""
heatmap(input::AbstractArray, analyzer::AbstractXAIMethod)
Compute an `Explanation` for a given `input` using the XAI method `analyzer` and visualize it
as a vision heatmap.
Any additional arguments and keyword arguments are passed to the analyzer.
Refer to the `analyze` documentation for more information on available keyword arguments.
To customize the heatmapping style, first compute an explanation using `analyze`
and then call [`heatmap`](@ref) on the explanation.
"""
function heatmap(input, analyzer::AbstractXAIMethod, analyze_args...; analyze_kwargs...)
expl = analyze(input, analyzer, analyze_args...; analyze_kwargs...)
return heatmap(expl)
end
9 changes: 6 additions & 3 deletions src/overlay.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const DEFAULT_OVERLAY_ALPHA = 0.6
const DEFAULT_RESIZE_METHOD = Lanczos(1)

"""
heatmap_overlay(val, img)
Expand All @@ -19,8 +20,8 @@ function heatmap_overlay(
val::AbstractArray{T,N},
im::AbstractMatrix{<:Colorant};
alpha=DEFAULT_OVERLAY_ALPHA,
resize_method=Lanczos(1),
kwargs...,
resize_method=DEFAULT_RESIZE_METHOD,
heatmap_kwargs...,
) where {T,N}
N != 4 && throw(InputDimensionError)
if size(val, 4) != 1
Expand All @@ -33,7 +34,9 @@ function heatmap_overlay(
if alpha < 0 || alpha > 1
throw(ArgumentError("alpha must be in the range [0, 1]"))
end
hm = heatmap(val; kwargs...)

options = HeatmapOptions(; heatmap_kwargs...)
hm = heatmap(val, options)
hmsize = size(hm)
imsize = size(im)
if hmsize != imsize
Expand Down
67 changes: 0 additions & 67 deletions src/xaibase.jl

This file was deleted.

1 change: 1 addition & 0 deletions test/references/abssum_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/abssum_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/norm_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions test/references/overlay_maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions test/references/overlay_maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_grays.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
1 change: 1 addition & 0 deletions test/references/overlay_norm_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_maxabs_extrema_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_centered_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_centered_seismic.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
3 changes: 3 additions & 0 deletions test/references/overlay_rescale_norm_extrema_jet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
▀▀▀▀▀▀
▀▀▀▀▀▀
▀▀▀▀▀▀
Loading

0 comments on commit 034bffd

Please sign in to comment.