Skip to content

Commit

Permalink
Remove heatmapping functionality (#16)
Browse files Browse the repository at this point in the history
Move heatmapping on `Explanation` type to VisionHeatmaps.jl and TextHeatmaps.jl via package extensions on XAIBase.

This change is breaking: users are now required to manually load either VisionHeatmaps or TextHeatmaps.
  • Loading branch information
adrhill authored Feb 19, 2024
1 parent 63d7f60 commit bd65d33
Show file tree
Hide file tree
Showing 14 changed files with 31 additions and 252 deletions.
8 changes: 1 addition & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
name = "XAIBase"
uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "2.0.0"

[deps]
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"
version = "3.0.0-DEV"

[compat]
TextHeatmaps = "1.1"
VisionHeatmaps = "1.1"
julia = "1.6"
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ In simpler terms, methods that try to answer the question

Building on top of XAIBase (or providing an interface via [package extensions][docs-extensions])
makes your package compatible with the Julia-XAI ecosystem,
allowing you to automatically compute heatmaps for vision and language models.
allowing you to automatically compute heatmaps for vision and language models
using [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/)
and [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).
It also allows you to use input-augmentations from [ExplainableAI.jl][url-explainableai].

## Interface description
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6 changes: 0 additions & 6 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ The return type of `analyze` is an `Explanation`:
Explanation
```

## Visualizing explanations
`Explanation`s can be visualized using `heatmap`:
```@docs
heatmap
```

## Feature selection
```@docs
AbstractFeatureSelector
Expand Down
25 changes: 19 additions & 6 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,23 @@ function (method::RandomAnalyzer)(input, output_selector::AbstractOutputSelector
end
```

We can directly use XAIBase's `analyze` and `heatmap` functions
to compute and visualize the random explanation:
We can directly use XAIBase's `analyze` function
to compute the random explanation:

```@example implementations
analyzer = RandomAnalyzer(model)
heatmap(input, analyzer)
expl = analyze(input, analyzer)
```

Using either [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/)
or [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/),
which provide package extensions on XAIBase's `Explanation` type,
we can visualize the explanations:

```@example implementations
using VisionHeatmaps # load heatmapping functionality
heatmap(expl.val)
```

As expected, the explanation is just noise.
Expand Down Expand Up @@ -81,15 +92,17 @@ end
that works with batched inputs and only requires a single forward
and backward pass through the model.

Once again, we can directly use XAIBase's `analyze` and `heatmap` functions
Once again, we can directly use XAIBase's `analyze` and VisionHeatmaps' `heatmap` functions
```@example implementations
using VisionHeatmaps
analyzer = MyGradient(model)
expl = analyze(input, analyzer)
heatmap(expl)
heatmap(expl.val)
```

```@example implementations
heatmap(expl, colorscheme=:twilight, reduce=:norm, rangescale=:centered)
heatmap(expl.val, colorscheme=:twilight, reduce=:norm, rangescale=:centered)
```

and make use of all the features provided by the Julia-XAI ecosystem.
Expand Down
4 changes: 3 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ in the [Julia-XAI ecosystem](https://julia-xai.github.io/XAIDocs/).
Building on top of XAIBase
(or providing an interface via [package extensions](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)))
makes your package compatible with the Julia-XAI ecosystem,
allowing you to automatically compute heatmaps for vision and language models.
allowing you to automatically compute heatmaps for vision and language models
using [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/)
and [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).

This only requires you to fulfill the following two requirements:

Expand Down
10 changes: 2 additions & 8 deletions src/XAIBase.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
module XAIBase

using TextHeatmaps
using VisionHeatmaps

include("compat.jl")
include("utils.jl")

Expand All @@ -15,7 +12,8 @@ It is expected that all XAI methods are callable types that return an `Explanati
(method::AbstractXAIMethod)(input, output_selector::AbstractOutputSelector)
```
If this function is implemented, XAIBase will provide the `analyze` and `heatmap` functionality.
If this function is implemented, XAIBase will provide the `analyze` functionality
and `heatmap` functionality by loading either VisionHeatmaps.jl or TextHeatmaps.jl.
"""
abstract type AbstractXAIMethod end

Expand All @@ -31,16 +29,12 @@ include("explanation.jl")
# which in turn calls `(method)(input, output_selector)`.
include("analyze.jl")

# Heatmapping for vision and NLP tasks.
include("heatmaps.jl")

# Utilities for XAI methods that compute Explanations w.r.t. specific features:
include("feature_selection.jl")

export AbstractXAIMethod
export Explanation
export analyze
export heatmap
export AbstractOutputSelector, MaxActivationSelector, IndexSelector
export AbstractFeatureSelector, IndexedFeatures, TopNFeatures
end #module
2 changes: 1 addition & 1 deletion src/analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Apply the analyzer `method` for the given input, returning an [`Explanation`](@r
If `output_selection` is specified, the explanation will be calculated for that output.
Otherwise, the output with the highest activation is automatically chosen.
See also [`Explanation`](@ref) and [`heatmap`](@ref).
See also [`Explanation`](@ref).
## Keyword arguments
- `add_batch_dim`: add batch dimension to the input without allocating. Default is `false`.
Expand Down
135 changes: 0 additions & 135 deletions src/heatmaps.jl

This file was deleted.

5 changes: 0 additions & 5 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,3 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"

[compat]
Aqua = "0.8"
8 changes: 0 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,4 @@ using JuliaFormatter
@info "Testing feature selection..."
include("test_feature_selection.jl")
end
@testset "Vision heatmaps" begin
@info "Testing vision heatmaps..."
include("test_heatmap.jl")
end
@testset "Text heatmaps" begin
@info "Testing text heatmaps..."
include("test_textheatmap.jl")
end
end
23 changes: 1 addition & 22 deletions test/test_api.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Create dummy analyzer to test API and heatmapping
# Create dummy analyzer to test API
struct DummyAnalyzer <: AbstractXAIMethod end
function (method::DummyAnalyzer)(input, output_selector::AbstractOutputSelector)
output = input
Expand Down Expand Up @@ -38,24 +38,3 @@ expl = analyzer(input, output_index)
# Ouput selection + add_batch_dim
expl = analyzer(input_vec, output_index; add_batch_dim=true)
@test expl.val == val[:, 1:1]

# Test direct heatmapping
input = rand(5, 5, 3, 1)

h1 = heatmap(analyze(input, analyzer))
h2 = heatmap(input, analyzer)
@test h1 == h2

h1 = heatmap(analyze(input, analyzer, 5))
h2 = heatmap(input, analyzer, 5)
@test h1 == h2

input = rand(5, 5, 3)

h1 = heatmap(analyze(input, analyzer; add_batch_dim=true))
h2 = heatmap(input, analyzer; add_batch_dim=true)
@test h1 == h2

h1 = heatmap(analyze(input, analyzer, 5; add_batch_dim=true))
h2 = heatmap(input, analyzer, 5; add_batch_dim=true)
@test h1 == h2
30 changes: 0 additions & 30 deletions test/test_heatmap.jl

This file was deleted.

22 changes: 0 additions & 22 deletions test/test_textheatmap.jl

This file was deleted.

0 comments on commit bd65d33

Please sign in to comment.