Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor XAIBase interface #20

Merged
merged 12 commits into from
Jul 27, 2024
Merged
54 changes: 32 additions & 22 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
# XAIBase.jl
## Version `v4.0.0`
* ![BREAKING][badge-breaking] Implementing new analyzers now requires a `call_analyzer` method instead of making the analyzer struct callable. This helps with type stability ([#20])
* ![BREAKING][badge-breaking] Add `input` field to `Explanation` struct
* ![BREAKING][badge-breaking] Remove `analyze` keyword-argument `add_batch_dim`, which made the assumption of array inputs ([#20])
* ![Feature][badge-feature] Remove type annotations that restricted `analyze` to `AbstractArray` inputs ([#20])
* ![Maintenance][badge-maintenance] XAIBase is now fully type stable and tested with JET.jl ([#20])
* ![Maintenance][badge-maintenance] Modularize tests ([#17])

## Version `v3.0.0`
* ![BREAKING][badge-breaking] Remove heatmapping functionality.
Users are now required to manually load either
[VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) or
[TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/). ([#16][pr-16])
[TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/). ([#16])

## Version `v2.0.0`
* ![BREAKING][badge-breaking] Rename `AbstractNeuronSelector` to `AbstractOutputSelector` ([#14][pr-14])
* ![Feature][badge-feature] Export output selectors ([#15][pr-15])
* ![Documentation][badge-docs] Add example implementations of XAI methods ([#13][pr-13])
* ![Documentation][badge-docs] Improved documentation of output and feature selectors ([#15][pr-15])
* ![BREAKING][badge-breaking] Rename `AbstractNeuronSelector` to `AbstractOutputSelector` ([#14])
* ![Feature][badge-feature] Export output selectors ([#15])
* ![Documentation][badge-docs] Add example implementations of XAI methods ([#13])
* ![Documentation][badge-docs] Improved documentation of output and feature selectors ([#15])

## Version `v1.3.0`
* ![Feature][badge-feature] Add feature selectors ([#12][pr-12])
* ![Documentation][badge-docs] Add documentation ([#11][pr-11])
* ![Feature][badge-feature] Add feature selectors ([#12])
* ![Documentation][badge-docs] Add documentation ([#11])

## Version `v1.2.0`
* ![Feature][badge-feature] Add API for direct heatmapping ([#9][pr-9])
* ![Feature][badge-feature] Add API for direct heatmapping ([#9])

## Version `v1.1.1`
* ![Bugfix][badge-bugfix] Fix keyword argument `add_batch_dim` ([#8][pr-8])
* ![Bugfix][badge-bugfix] Fix keyword argument `add_batch_dim` ([#8])

## Version `v1.1.0`
This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of XAIBase ([#4][pr-4])
* ![Feature][badge-feature] Add `heatmap` preset field to `Explanation` struct ([#5][pr-5], [#6][pr-6])
This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of XAIBase ([#4])
* ![Feature][badge-feature] Add `heatmap` preset field to `Explanation` struct ([#5], [#6])
* ![Feature][badge-feature] Add heatmapping preset for CAM methods ([5658de9](https://github.com/Julia-XAI/XAIBase.jl/commit/5658de9))

## Version `v1.0.0`
Expand All @@ -41,17 +49,19 @@ This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of
![Documentation][badge-docs]
-->

[pr-16]: https://github.com/Julia-XAI/XAIBase.jl/pull/16
[pr-15]: https://github.com/Julia-XAI/XAIBase.jl/pull/15
[pr-14]: https://github.com/Julia-XAI/XAIBase.jl/pull/14
[pr-13]: https://github.com/Julia-XAI/XAIBase.jl/pull/13
[pr-12]: https://github.com/Julia-XAI/XAIBase.jl/pull/12
[pr-11]: https://github.com/Julia-XAI/XAIBase.jl/pull/11
[pr-9]: https://github.com/Julia-XAI/XAIBase.jl/pull/9
[pr-8]: https://github.com/Julia-XAI/XAIBase.jl/pull/8
[pr-6]: https://github.com/Julia-XAI/XAIBase.jl/pull/6
[pr-5]: https://github.com/Julia-XAI/XAIBase.jl/pull/5
[pr-4]: https://github.com/Julia-XAI/XAIBase.jl/pull/4
[#20]: https://github.com/Julia-XAI/XAIBase.jl/pull/20
[#17]: https://github.com/Julia-XAI/XAIBase.jl/pull/17
[#16]: https://github.com/Julia-XAI/XAIBase.jl/pull/16
[#15]: https://github.com/Julia-XAI/XAIBase.jl/pull/15
[#14]: https://github.com/Julia-XAI/XAIBase.jl/pull/14
[#13]: https://github.com/Julia-XAI/XAIBase.jl/pull/13
[#12]: https://github.com/Julia-XAI/XAIBase.jl/pull/12
[#11]: https://github.com/Julia-XAI/XAIBase.jl/pull/11
[#9]: https://github.com/Julia-XAI/XAIBase.jl/pull/9
[#8]: https://github.com/Julia-XAI/XAIBase.jl/pull/8
[#6]: https://github.com/Julia-XAI/XAIBase.jl/pull/6
[#5]: https://github.com/Julia-XAI/XAIBase.jl/pull/5
[#4]: https://github.com/Julia-XAI/XAIBase.jl/pull/4

[badge-breaking]: https://img.shields.io/badge/BREAKING-red.svg
[badge-deprecation]: https://img.shields.io/badge/deprecation-orange.svg
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "XAIBase"
uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "3.0.0"
version = "4.0.0-DEV"

[compat]
julia = "1.6"
26 changes: 19 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://Julia-XAI.github.io/XAIBase.jl/dev/)
[![Build Status](https://github.com/Julia-XAI/XAIBase.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Julia-XAI/XAIBase.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/Julia-XAI/XAIBase.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Julia-XAI/XAIBase.jl)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![Aqua][aqua-img]][aqua-url]
[![JET][jet-img]][jet-url]

XAIBase is a light-weight dependency that defines the interface of XAI methods in the [Julia-XAI ecosystem](https://github.com/Julia-XAI),
which focusses on post-hoc, local input space explanations of black-box models.
Expand All @@ -21,13 +22,15 @@ It also allows you to use input-augmentations from [ExplainableAI.jl][url-explai
XAIBase only requires you to fulfill the following two requirements:

1. An XAI algorithm has to be a subtype of [`AbstractXAIMethod`][docs-abstractxaimethod]
2. An XAI algorithm has to implement the following method:
2. An XAI algorithm has to implement a `call_analyzer` method:

```julia
(method::MyMethod)(input, output_selector::AbstractOutputSelector)
import XAIBase: call_analyzer

call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...)
```

* the method has to return an [`Explanation`][docs-explanation]
* `call_analyzer` has to return an [`Explanation`][docs-explanation]
* the input is expected to have a batch dimensions as its last dimension
* when applied to a batch, the method returns a single [`Explanation`][docs-explanation],
which contains the batched output in the `val` field.
Expand All @@ -42,17 +45,20 @@ For more information, take a look at the [documentation][docs].
Julia-XAI methods will usually follow the following template:

```julia
using XAIBase
import XAIBase: call_analyzer

struct MyMethod{M} <: AbstractXAIMethod
model::M
end

function (method::MyMethod)(input, output_selector::AbstractOutputSelector)
function call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...)
output = method.model(input)
output_selection = output_selector(output)

val = ... # your method's implementation
extras = nothing # optionally add additional information using a named tuple
return Explanation(val, output, output_selection, :MyMethod, :attribution, extras)
return Explanation(val, input, output, output_selection, :MyMethod, :attribution, extras)
end
```

Expand All @@ -73,4 +79,10 @@ end
[docs-abstractxaimethod]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.AbstractXAIMethod
[docs-abstractoutputselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.AbstractOutputSelector
[docs-maxactivationselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.MaxActivationSelector
[docs-indexselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.IndexSelector
[docs-indexselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.IndexSelector

[aqua-img]: https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg
[aqua-url]: https://github.com/JuliaTesting/Aqua.jl

[jet-img]: https://img.shields.io/badge/%F0%9F%9B%A9%EF%B8%8F_tested_with-JET.jl-233f9a
[jet-url]: https://github.com/aviatesk/JET.jl
11 changes: 7 additions & 4 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ that returns a random explanation in the shape of the input.

```@example implementations
using XAIBase
import XAIBase: call_analyzer

struct RandomAnalyzer{M} <: AbstractXAIMethod
model::M
end

function (method::RandomAnalyzer)(input, output_selector::AbstractOutputSelector)
function call_analyzer(input, method::RandomAnalyzer, output_selector::AbstractOutputSelector; kwargs...)
output = method.model(input)
output_selection = output_selector(output)

val = rand(size(input)...)
return Explanation(val, output, output_selection, :RandomAnalyzer, :sensitivity, nothing)
return Explanation(val, input, output, output_selection, :RandomAnalyzer, :sensitivity, nothing)
end
```

Expand Down Expand Up @@ -70,19 +71,21 @@ In this second example, we naively reimplement the `Gradient` analyzer from

```@example implementations
using XAIBase
import XAIBase: call_analyzer

using Zygote: gradient

struct MyGradient{M} <: AbstractXAIMethod
model::M
end

function (method::MyGradient)(input, output_selector::AbstractOutputSelector)
function call_analyzer(input, method::MyGradient, output_selector::AbstractOutputSelector; kwargs...)
output = method.model(input)
output_selection = output_selector(output)

grad = gradient((x) -> only(method.model(x)[output_selection]), input)
val = only(grad)
return Explanation(val, output, output_selection, :MyGradient, :sensitivity, nothing)
return Explanation(val, input, output, output_selection, :MyGradient, :sensitivity, nothing)
end
```

Expand Down
15 changes: 10 additions & 5 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ and [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/).
This only requires you to fulfill the following two requirements:

1. An XAI method has to be a subtype of `AbstractXAIMethod`
2. An XAI method has to implement the following method:
2. An XAI algorithm has to implement a `call_analyzer` method:

```julia
(method::MyMethod)(input, output_selector::AbstractOutputSelector)
import XAIBase: call_analyzer

call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...)
```

* The method has to return an [`Explanation`](@ref)
* `call_analyzer` has to return an [`Explanation`](@ref)
* The input is expected to have a batch dimensions as its last dimension
* When applied to a batch, the method returns a single [`Explanation`](@ref),
which contains the batched output in the `val` field.
Expand All @@ -35,17 +37,20 @@ For more information, take a look at [`src/XAIBase.jl`](https://github.com/Julia
Julia-XAI methods will usually follow the following template:

```julia
using XAIBase
import XAIBase: call_analyzer

struct MyMethod{M} <: AbstractXAIMethod
model::M
end

function (method::MyMethod)(input, output_selector::AbstractOutputSelector)
function call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...)
output = method.model(input)
output_selection = output_selector(output)

val = ... # your method's implementation
extras = nothing # optionally add additional information using a named tuple
return Explanation(val, output, output_selection, :MyMethod, :attribution, extras)
return Explanation(val, input, output, output_selection, :MyMethod, :attribution, extras)
end
```

Expand Down
2 changes: 2 additions & 0 deletions src/XAIBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and `heatmap` functionality by loading either VisionHeatmaps.jl or TextHeatmaps.
"""
abstract type AbstractXAIMethod end

include("exceptions.jl")

# Output selectors of type `AbstractOutputSelector` for class-specific explanations.
# These are used to automatically select the maximally activated output.
include("output_selection.jl")
Expand Down
58 changes: 24 additions & 34 deletions src/analyze.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@

const BATCHDIM_MISSING = ArgumentError(
"""The input is a 1D vector and therefore missing the required batch dimension.
Call `analyze` with the keyword argument `add_batch_dim=true`."""
)

"""
analyze(input, method)
analyze(input, method, output_selection)
Expand All @@ -13,45 +7,41 @@ If `output_selection` is specified, the explanation will be calculated for that
Otherwise, the output with the highest activation is automatically chosen.

See also [`Explanation`](@ref).

## Keyword arguments
- `add_batch_dim`: add batch dimension to the input without allocating. Default is `false`.
"""
function analyze(
input::AbstractArray{<:Real},
input, method::AbstractXAIMethod, output_selector::AbstractOutputSelector; kwargs...
)
return call_analyzer(input, method, output_selector; kwargs...)
end
function analyze(
input,
method::AbstractXAIMethod,
output_selection::Union{Integer,Tuple{<:Integer}};
kwargs...,
)
return _analyze(input, method, IndexSelector(output_selection); kwargs...)
return call_analyzer(input, method, IndexSelector(output_selection); kwargs...)
end

function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...)
return _analyze(input, method, MaxActivationSelector(); kwargs...)
function analyze(input, method::AbstractXAIMethod; kwargs...)
return call_analyzer(input, method, MaxActivationSelector(); kwargs...)
end

function (method::AbstractXAIMethod)(
input::AbstractArray{<:Real},
output_selection::Union{Integer,Tuple{<:Integer}};
kwargs...,
)
return _analyze(input, method, IndexSelector(output_selection); kwargs...)
# Direct calls to analyzer
function (method::AbstractXAIMethod)(input; kwargs...)
analyze(input, method, MaxActivationSelector(); kwargs...)
end
function (method::AbstractXAIMethod)(input::AbstractArray{<:Real}; kwargs...)
return _analyze(input, method, MaxActivationSelector(); kwargs...)
function (method::AbstractXAIMethod)(input, output_selector; kwargs...)
analyze(input, method, output_selector; kwargs...)
end

# lower-level call to method
function _analyze(
input::AbstractArray{T,N},
method::AbstractXAIMethod,
sel::AbstractOutputSelector;
add_batch_dim::Bool=false,
kwargs...,
) where {T<:Real,N}
if add_batch_dim
return method(batch_dim_view(input), sel; kwargs...)
end
N < 2 && throw(BATCHDIM_MISSING)
return method(input, sel; kwargs...)
# Throw NotImplementedError as a fallback
function call_analyzer(
input, method::AbstractXAIMethod, output_selector::AbstractOutputSelector; kwargs...
)
return throw(
NotImplementedError(
method,
"call_analyzer(input, method::T, output_selector::AbstractOutputSelector; kwargs...)",
),
)
end
11 changes: 11 additions & 0 deletions src/exceptions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
struct NotImplementedError <: Exception
analyzer::AbstractXAIMethod
method::String
end

function Base.showerror(io::IO, e::NotImplementedError)
T = string(typeof(e.analyzer))
printstyled(io, "NotImplementedError: "; color=:red)
println(io, "The `$T` analyzer doesn't fully implement the XAIBase interface.")
print(io, "Please implement `", e.method, "` for your type `T<:$T`.")
end
9 changes: 6 additions & 3 deletions src/explanation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ Return type of analyzers when calling [`analyze`](@ref).
* `extras`: optional named tuple that can be used by analyzers
to return additional information.
"""
struct Explanation{V,O,S,E<:Union{Nothing,NamedTuple}}
struct Explanation{V,I,O,S,E<:Union{Nothing,NamedTuple}}
val::V
input::I
output::O
output_selection::S
analyzer::Symbol
heatmap::Symbol
extras::E
end
function Explanation(val, output, output_selection, analyzer::Symbol, heatmap::Symbol)
return Explanation(val, output, output_selection, analyzer, heatmap, nothing)
function Explanation(
val, input, output, output_selection, analyzer::Symbol, heatmap::Symbol
)
return Explanation(val, input, output, output_selection, analyzer, heatmap, nothing)
end
8 changes: 6 additions & 2 deletions src/output_selection.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
const NOTE_OUTPUT_SELECTOR = """## Note
const BATCHDIM_MISSING = ArgumentError(
"The input is a 1D vector and therefore missing the required batch dimension."
)

const NOTE_OUTPUT_SELECTOR = "## Note
XAIBase assumes that the batch dimension is the last dimension of the output.
"""
"

"""
Abstract super type of all output selectors in XAIBase.
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Expand Down
Loading
Loading