Skip to content

Commit

Permalink
Make XAIBase a direct dependency (#8)
Browse files Browse the repository at this point in the history
* Make XAIBase a direct dependency

* Improve documentation
  • Loading branch information
adrhill authored Feb 19, 2024
1 parent 40146ad commit fa0bcee
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 37 deletions.
10 changes: 1 addition & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
name = "VisionHeatmaps"
uuid = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "1.3.0"
version = "1.3.1-DEV"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[weakdeps]
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[extensions]
VisionHeatmapsXAIBaseExt = "XAIBase"

[compat]
ColorSchemes = "3"
ImageCore = "0.9, 0.10"
ImageTransformations = "0.10"
Interpolations = "0.15"
Requires = "1"
XAIBase = "3"
julia = "1.6"
12 changes: 2 additions & 10 deletions src/VisionHeatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@ using ColorSchemes: ColorScheme, colorschemes, get, seismic
using ImageTransformations: imresize
using Interpolations: Lanczos
using ImageCore
using Requires: @require
using XAIBase: Explanation, AbstractXAIMethod, analyze

include("heatmap.jl")
include("overlay.jl")

if !isdefined(Base, :get_extension)
using Requires
function __init__()
@require XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" include(
"../ext/VisionHeatmapsXAIBaseExt.jl"
)
end
end
include("xaibase.jl")

export heatmap, heatmap_overlay

Expand Down
4 changes: 2 additions & 2 deletions src/heatmap.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
const DEFAULT_COLORSCHEME = seismic
const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered

"""
heatmap(x)
heatmap(x::AbstractArray)
Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array dimensions
(width, height, color channels, batch dimension).
Expand Down
22 changes: 8 additions & 14 deletions ext/VisionHeatmapsXAIBaseExt.jl → src/xaibase.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
module VisionHeatmapsXAIBaseExt

using VisionHeatmaps, XAIBase

struct HeatmapConfig
colorscheme::Symbol
reduce::Symbol
rangescale::Symbol
end

const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered
const DEFAULT_HEATMAP_PRESET = HeatmapConfig(
DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE
)
Expand All @@ -37,12 +30,15 @@ function get_heatmapping_config(expl::Explanation; kwargs...)
end

"""
heatmap(explanation)
heatmap(expl::Explanation)
Visualize `Explanation` from XAIBase as a vision heatmap.
Assumes WHCN convention (width, height, channels, batchsize) for `explanation.val`.
Assumes WHCN convention (width, height, channels, batch dimension) for `explanation.val`.
This will use the default heatmapping style for the given type of explanation.
Defaults can be overridden via keyword arguments.
"""
function VisionHeatmaps.heatmap(expl::Explanation; kwargs...)
function heatmap(expl::Explanation; kwargs...)
c = get_heatmapping_config(expl; kwargs...)
return heatmap(
expl.val;
Expand All @@ -54,7 +50,7 @@ function VisionHeatmaps.heatmap(expl::Explanation; kwargs...)
end

"""
heatmap(input, analyzer)
heatmap(input::AbstractArray, analyzer::AbstractXAIMethod)
Compute an `Explanation` for a given `input` using the XAI method `analyzer` and visualize it
as a vision heatmap.
Expand All @@ -65,9 +61,7 @@ Refer to the `analyze` documentation for more information on available keyword a
To customize the heatmapping style, first compute an explanation using `analyze`
and then call [`heatmap`](@ref) on the explanation.
"""
function VisionHeatmaps.heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
expl = analyze(input, analyzer, args...; kwargs...)
return heatmap(expl)
end

end # module
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ using JuliaFormatter
@info "Testing heatmaps..."
include("test_heatmap.jl")
end
@testset "XAIBase extension" begin
@testset "XAIBase Explanations" begin
@info "Testing heatmaps on XAIBase explanations..."
include("test_xaibase_ext.jl")
include("test_xaibase.jl")
end
end
File renamed without changes.

0 comments on commit fa0bcee

Please sign in to comment.