diff --git a/CHANGELOG.md b/CHANGELOG.md index 3859d13..feccd61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,29 +1,37 @@ # XAIBase.jl +## Version `v4.0.0` +* ![BREAKING][badge-breaking] Implementing new analyzers now requires a `call_analyzer` method instead of making the analyzer struct callable. This helps with type stability ([#20]) +* ![BREAKING][badge-breaking] Add `input` field to `Explanation` struct +* ![BREAKING][badge-breaking] Remove `analyze` keyword-argument `add_batch_dim`, which made the assumption of array inputs ([#20]) +* ![Feature][badge-feature] Remove type annotations that restricted `analyze` to `AbstractArray` inputs ([#20]) +* ![Maintenance][badge-maintenance] XAIBase is now fully type stable and tested with JET.jl ([#20]) +* ![Maintenance][badge-maintenance] Modularize tests ([#17]) + ## Version `v3.0.0` * ![BREAKING][badge-breaking] Remove heatmapping functionality. Users are now required to manually load either [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) or - [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/). ([#16][pr-16]) + [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/). ([#16]) ## Version `v2.0.0` -* ![BREAKING][badge-breaking] Rename `AbstractNeuronSelector` to `AbstractOutputSelector` ([#14][pr-14]) -* ![Feature][badge-feature] Export output selectors ([#15][pr-15]) -* ![Documentation][badge-docs] Add example implementations of XAI methods ([#13][pr-13]) -* ![Documentation][badge-docs] Improved documentation of output and feature selectors ([#15][pr-15]) +* ![BREAKING][badge-breaking] Rename `AbstractNeuronSelector` to `AbstractOutputSelector` ([#14]) +* ![Feature][badge-feature] Export output selectors ([#15]) +* ![Documentation][badge-docs] Add example implementations of XAI methods ([#13]) +* ![Documentation][badge-docs] Improved documentation of output and feature selectors ([#15]) ## Version `v1.3.0` -* ![Feature][badge-feature] Add feature selectors ([#12][pr-12]) -* ![Documentation][badge-docs] Add documentation ([#11][pr-11]) +* ![Feature][badge-feature] Add feature selectors ([#12]) +* ![Documentation][badge-docs] Add documentation ([#11]) ## Version `v1.2.0` -* ![Feature][badge-feature] Add API for direct heatmapping ([#9][pr-9]) +* ![Feature][badge-feature] Add API for direct heatmapping ([#9]) ## Version `v1.1.1` -* ![Bugfix][badge-bugfix] Fix keyword argument `add_batch_dim` ([#8][pr-8]) +* ![Bugfix][badge-bugfix] Fix keyword argument `add_batch_dim` ([#8]) ## Version `v1.1.0` -This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of XAIBase ([#4][pr-4]) -* ![Feature][badge-feature] Add `heatmap` preset field to `Explanation` struct ([#5][pr-5], [#6][pr-6]) +This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of XAIBase ([#4]) +* ![Feature][badge-feature] Add `heatmap` preset field to `Explanation` struct ([#5], [#6]) * ![Feature][badge-feature] Add heatmapping preset for CAM methods ([5658de9](https://github.com/Julia-XAI/XAIBase.jl/commit/5658de9)) ## Version `v1.0.0` @@ -41,17 +49,19 @@ This release makes VisionHeatmaps.jl and TextHeatmaps.jl strong dependencies of ![Documentation][badge-docs] --> -[pr-16]: https://github.com/Julia-XAI/XAIBase.jl/pull/16 -[pr-15]: https://github.com/Julia-XAI/XAIBase.jl/pull/15 -[pr-14]: https://github.com/Julia-XAI/XAIBase.jl/pull/14 -[pr-13]: https://github.com/Julia-XAI/XAIBase.jl/pull/13 -[pr-12]: https://github.com/Julia-XAI/XAIBase.jl/pull/12 -[pr-11]: https://github.com/Julia-XAI/XAIBase.jl/pull/11 -[pr-9]: https://github.com/Julia-XAI/XAIBase.jl/pull/9 -[pr-8]: https://github.com/Julia-XAI/XAIBase.jl/pull/8 -[pr-6]: https://github.com/Julia-XAI/XAIBase.jl/pull/6 -[pr-5]: https://github.com/Julia-XAI/XAIBase.jl/pull/5 -[pr-4]: https://github.com/Julia-XAI/XAIBase.jl/pull/4 +[#20]: https://github.com/Julia-XAI/XAIBase.jl/pull/20 +[#17]: https://github.com/Julia-XAI/XAIBase.jl/pull/17 +[#16]: https://github.com/Julia-XAI/XAIBase.jl/pull/16 +[#15]: https://github.com/Julia-XAI/XAIBase.jl/pull/15 +[#14]: https://github.com/Julia-XAI/XAIBase.jl/pull/14 +[#13]: https://github.com/Julia-XAI/XAIBase.jl/pull/13 +[#12]: https://github.com/Julia-XAI/XAIBase.jl/pull/12 +[#11]: https://github.com/Julia-XAI/XAIBase.jl/pull/11 +[#9]: https://github.com/Julia-XAI/XAIBase.jl/pull/9 +[#8]: https://github.com/Julia-XAI/XAIBase.jl/pull/8 +[#6]: https://github.com/Julia-XAI/XAIBase.jl/pull/6 +[#5]: https://github.com/Julia-XAI/XAIBase.jl/pull/5 +[#4]: https://github.com/Julia-XAI/XAIBase.jl/pull/4 [badge-breaking]: https://img.shields.io/badge/BREAKING-red.svg [badge-deprecation]: https://img.shields.io/badge/deprecation-orange.svg diff --git a/Project.toml b/Project.toml index 39bcef5..c82db28 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "XAIBase" uuid = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" authors = ["Adrian Hill "] -version = "3.0.0" +version = "4.0.0-DEV" [compat] julia = "1.6" diff --git a/README.md b/README.md index efaf495..a6bc083 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://Julia-XAI.github.io/XAIBase.jl/dev/) [![Build Status](https://github.com/Julia-XAI/XAIBase.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Julia-XAI/XAIBase.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/Julia-XAI/XAIBase.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Julia-XAI/XAIBase.jl) -[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) +[![Aqua][aqua-img]][aqua-url] +[![JET][jet-img]][jet-url] XAIBase is a light-weight dependency that defines the interface of XAI methods in the [Julia-XAI ecosystem](https://github.com/Julia-XAI), which focusses on post-hoc, local input space explanations of black-box models. @@ -21,13 +22,15 @@ It also allows you to use input-augmentations from [ExplainableAI.jl][url-explai XAIBase only requires you to fulfill the following two requirements: 1. An XAI algorithm has to be a subtype of [`AbstractXAIMethod`][docs-abstractxaimethod] -2. An XAI algorithm has to implement the following method: +2. An XAI algorithm has to implement a `call_analyzer` method: ```julia -(method::MyMethod)(input, output_selector::AbstractOutputSelector) +import XAIBase: call_analyzer + +call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...) ``` -* the method has to return an [`Explanation`][docs-explanation] +* `call_analyzer` has to return an [`Explanation`][docs-explanation] * the input is expected to have a batch dimensions as its last dimension * when applied to a batch, the method returns a single [`Explanation`][docs-explanation], which contains the batched output in the `val` field. @@ -42,17 +45,20 @@ For more information, take a look at the [documentation][docs]. Julia-XAI methods will usually follow the following template: ```julia +using XAIBase +import XAIBase: call_analyzer + struct MyMethod{M} <: AbstractXAIMethod model::M end -function (method::MyMethod)(input, output_selector::AbstractOutputSelector) +function call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...) output = method.model(input) output_selection = output_selector(output) val = ... # your method's implementation extras = nothing # optionally add additional information using a named tuple - return Explanation(val, output, output_selection, :MyMethod, :attribution, extras) + return Explanation(val, input, output, output_selection, :MyMethod, :attribution, extras) end ``` @@ -73,4 +79,10 @@ end [docs-abstractxaimethod]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.AbstractXAIMethod [docs-abstractoutputselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.AbstractOutputSelector [docs-maxactivationselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.MaxActivationSelector -[docs-indexselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.IndexSelector \ No newline at end of file +[docs-indexselector]: https://julia-xai.github.io/XAIDocs/XAIBase/stable/api/#XAIBase.IndexSelector + +[aqua-img]: https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg +[aqua-url]: https://github.com/JuliaTesting/Aqua.jl + +[jet-img]: https://img.shields.io/badge/%F0%9F%9B%A9%EF%B8%8F_tested_with-JET.jl-233f9a +[jet-url]: https://github.com/aviatesk/JET.jl diff --git a/docs/src/examples.md b/docs/src/examples.md index f39172a..c71bae2 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -29,17 +29,18 @@ that returns a random explanation in the shape of the input. ```@example implementations using XAIBase +import XAIBase: call_analyzer struct RandomAnalyzer{M} <: AbstractXAIMethod model::M end -function (method::RandomAnalyzer)(input, output_selector::AbstractOutputSelector) +function call_analyzer(input, method::RandomAnalyzer, output_selector::AbstractOutputSelector; kwargs...) output = method.model(input) output_selection = output_selector(output) val = rand(size(input)...) - return Explanation(val, output, output_selection, :RandomAnalyzer, :sensitivity, nothing) + return Explanation(val, input, output, output_selection, :RandomAnalyzer, :sensitivity, nothing) end ``` @@ -70,19 +71,21 @@ In this second example, we naively reimplement the `Gradient` analyzer from ```@example implementations using XAIBase +import XAIBase: call_analyzer + using Zygote: gradient struct MyGradient{M} <: AbstractXAIMethod model::M end -function (method::MyGradient)(input, output_selector::AbstractOutputSelector) +function call_analyzer(input, method::MyGradient, output_selector::AbstractOutputSelector; kwargs...) output = method.model(input) output_selection = output_selector(output) grad = gradient((x) -> only(method.model(x)[output_selection]), input) val = only(grad) - return Explanation(val, output, output_selection, :MyGradient, :sensitivity, nothing) + return Explanation(val, input, output, output_selection, :MyGradient, :sensitivity, nothing) end ``` diff --git a/docs/src/index.md b/docs/src/index.md index b494bf2..a9edc6e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -13,13 +13,15 @@ and [TextHeatmaps.jl](https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/). This only requires you to fulfill the following two requirements: 1. An XAI method has to be a subtype of `AbstractXAIMethod` -2. An XAI method has to implement the following method: +2. An XAI algorithm has to implement a `call_analyzer` method: ```julia -(method::MyMethod)(input, output_selector::AbstractOutputSelector) +import XAIBase: call_analyzer + +call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...) ``` -* The method has to return an [`Explanation`](@ref) +* `call_analyzer` has to return an [`Explanation`](@ref) * The input is expected to have a batch dimensions as its last dimension * When applied to a batch, the method returns a single [`Explanation`](@ref), which contains the batched output in the `val` field. @@ -35,17 +37,20 @@ For more information, take a look at [`src/XAIBase.jl`](https://github.com/Julia Julia-XAI methods will usually follow the following template: ```julia +using XAIBase +import XAIBase: call_analyzer + struct MyMethod{M} <: AbstractXAIMethod model::M end -function (method::MyMethod)(input, output_selector::AbstractOutputSelector) +function call_analyzer(input, method::MyMethod, output_selector::AbstractOutputSelector; kwargs...) output = method.model(input) output_selection = output_selector(output) val = ... # your method's implementation extras = nothing # optionally add additional information using a named tuple - return Explanation(val, output, output_selection, :MyMethod, :attribution, extras) + return Explanation(val, input, output, output_selection, :MyMethod, :attribution, extras) end ``` diff --git a/src/XAIBase.jl b/src/XAIBase.jl index fc9c7b9..1b5f299 100644 --- a/src/XAIBase.jl +++ b/src/XAIBase.jl @@ -17,6 +17,8 @@ and `heatmap` functionality by loading either VisionHeatmaps.jl or TextHeatmaps. """ abstract type AbstractXAIMethod end +include("exceptions.jl") + # Output selectors of type `AbstractOutputSelector` for class-specific explanations. # These are used to automatically select the maximally activated output. include("output_selection.jl") diff --git a/src/analyze.jl b/src/analyze.jl index bc1a248..9d80543 100644 --- a/src/analyze.jl +++ b/src/analyze.jl @@ -1,9 +1,3 @@ - -const BATCHDIM_MISSING = ArgumentError( - """The input is a 1D vector and therefore missing the required batch dimension. - Call `analyze` with the keyword argument `add_batch_dim=true`.""" -) - """ analyze(input, method) analyze(input, method, output_selection) @@ -13,45 +7,41 @@ If `output_selection` is specified, the explanation will be calculated for that Otherwise, the output with the highest activation is automatically chosen. See also [`Explanation`](@ref). - -## Keyword arguments -- `add_batch_dim`: add batch dimension to the input without allocating. Default is `false`. """ function analyze( - input::AbstractArray{<:Real}, + input, method::AbstractXAIMethod, output_selector::AbstractOutputSelector; kwargs... +) + return call_analyzer(input, method, output_selector; kwargs...) +end +function analyze( + input, method::AbstractXAIMethod, output_selection::Union{Integer,Tuple{<:Integer}}; kwargs..., ) - return _analyze(input, method, IndexSelector(output_selection); kwargs...) + return call_analyzer(input, method, IndexSelector(output_selection); kwargs...) end -function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...) - return _analyze(input, method, MaxActivationSelector(); kwargs...) +function analyze(input, method::AbstractXAIMethod; kwargs...) + return call_analyzer(input, method, MaxActivationSelector(); kwargs...) end -function (method::AbstractXAIMethod)( - input::AbstractArray{<:Real}, - output_selection::Union{Integer,Tuple{<:Integer}}; - kwargs..., -) - return _analyze(input, method, IndexSelector(output_selection); kwargs...) +# Direct calls to analyzer +function (method::AbstractXAIMethod)(input; kwargs...) + analyze(input, method, MaxActivationSelector(); kwargs...) end -function (method::AbstractXAIMethod)(input::AbstractArray{<:Real}; kwargs...) - return _analyze(input, method, MaxActivationSelector(); kwargs...) +function (method::AbstractXAIMethod)(input, output_selector; kwargs...) + analyze(input, method, output_selector; kwargs...) end -# lower-level call to method -function _analyze( - input::AbstractArray{T,N}, - method::AbstractXAIMethod, - sel::AbstractOutputSelector; - add_batch_dim::Bool=false, - kwargs..., -) where {T<:Real,N} - if add_batch_dim - return method(batch_dim_view(input), sel; kwargs...) - end - N < 2 && throw(BATCHDIM_MISSING) - return method(input, sel; kwargs...) +# Throw NotImplementedError as a fallback +function call_analyzer( + input, method::AbstractXAIMethod, output_selector::AbstractOutputSelector; kwargs... +) + return throw( + NotImplementedError( + method, + "call_analyzer(input, method::T, output_selector::AbstractOutputSelector; kwargs...)", + ), + ) end diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 0000000..8b6656d --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,11 @@ +struct NotImplementedError <: Exception + analyzer::AbstractXAIMethod + method::String +end + +function Base.showerror(io::IO, e::NotImplementedError) + T = string(typeof(e.analyzer)) + printstyled(io, "NotImplementedError: "; color=:red) + println(io, "The `$T` analyzer doesn't fully implement the XAIBase interface.") + print(io, "Please implement `", e.method, "` for your type `T<:$T`.") +end diff --git a/src/explanation.jl b/src/explanation.jl index c0792ba..f04056f 100644 --- a/src/explanation.jl +++ b/src/explanation.jl @@ -13,14 +13,17 @@ Return type of analyzers when calling [`analyze`](@ref). * `extras`: optional named tuple that can be used by analyzers to return additional information. """ -struct Explanation{V,O,S,E<:Union{Nothing,NamedTuple}} +struct Explanation{V,I,O,S,E<:Union{Nothing,NamedTuple}} val::V + input::I output::O output_selection::S analyzer::Symbol heatmap::Symbol extras::E end -function Explanation(val, output, output_selection, analyzer::Symbol, heatmap::Symbol) - return Explanation(val, output, output_selection, analyzer, heatmap, nothing) +function Explanation( + val, input, output, output_selection, analyzer::Symbol, heatmap::Symbol +) + return Explanation(val, input, output, output_selection, analyzer, heatmap, nothing) end diff --git a/src/output_selection.jl b/src/output_selection.jl index afd29f9..25b6e17 100644 --- a/src/output_selection.jl +++ b/src/output_selection.jl @@ -1,6 +1,10 @@ -const NOTE_OUTPUT_SELECTOR = """## Note +const BATCHDIM_MISSING = ArgumentError( + "The input is a 1D vector and therefore missing the required batch dimension." +) + +const NOTE_OUTPUT_SELECTOR = "## Note XAIBase assumes that the batch dimension is the last dimension of the output. -""" +" """ Abstract super type of all output selectors in XAIBase. diff --git a/test/Project.toml b/test/Project.toml index e56440b..5f87f2d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" diff --git a/test/runtests.jl b/test/runtests.jl index 164220b..b045f53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,18 +3,26 @@ using Test using JuliaFormatter using Aqua +using JuliaFormatter +using Aqua +using JET +using ReferenceTests @testset "XAIBase.jl" begin if VERSION >= v"1.10" @info "Testing formalities..." @testset "Code formatting" begin - @info "- Testing code formatting with JuliaFormatter..." + @info "- running JuliaFormatter code formatting tests..." @test JuliaFormatter.format(XAIBase; verbose=false, overwrite=false) end @testset "Aqua.jl" begin - @info "- Running Aqua.jl tests. These might print warnings from dependencies..." + @info "- running Aqua.jl tests. These might print warnings from dependencies..." Aqua.test_all(XAIBase; ambiguities=false) end + @testset "JET.jl" begin + @info "- running JET.jl type stability tests..." + JET.test_package(XAIBase; target_defined_modules=true) + end end @testset "API" begin diff --git a/test/test_api.jl b/test/test_api.jl index dd64aa8..9c9d9ee 100644 --- a/test/test_api.jl +++ b/test/test_api.jl @@ -1,15 +1,20 @@ using XAIBase using Test +using XAIBase: AbstractXAIMethod, NotImplementedError +import XAIBase: call_analyzer + # Create dummy analyzer to test API struct DummyAnalyzer <: AbstractXAIMethod end -function (method::DummyAnalyzer)(input, output_selector::AbstractOutputSelector) +function call_analyzer( + input, ::DummyAnalyzer, output_selector::AbstractOutputSelector; kwargs... +) output = input output_selection = output_selector(output) batchsize = size(input)[end] v = reshape(output[output_selection], :, batchsize) val = input .* v - return Explanation(val, output, output_selection, :Dummy, :attribution) + return Explanation(val, input, output, output_selection, :Dummy, :attribution) end analyzer = DummyAnalyzer() @@ -23,11 +28,6 @@ expl = analyze(input, analyzer) expl = analyzer(input) @test expl.val == val -# Max activation + add_batch_dim -input_vec = [1, 2, 3] -expl = analyzer(input_vec; add_batch_dim=true) -@test expl.val == val[:, 1:1] - # Ouput selection output_index = 2 val = [2 30; 4 25; 6 20] @@ -38,6 +38,29 @@ expl = analyze(input, analyzer, output_index) expl = analyzer(input, output_index) @test expl.val == val -# Ouput selection + add_batch_dim -expl = analyzer(input_vec, output_index; add_batch_dim=true) -@test expl.val == val[:, 1:1] +# Dummy analyzer to test exceptions +struct EmptyAnalyzer <: AbstractXAIMethod end + +analyzer = EmptyAnalyzer() +@test_throws NotImplementedError analyze(input, analyzer, output_index) + +# Dummy analyzer to test "unusual" inputs +struct AnyInputAnalyzer <: AbstractXAIMethod end +function call_analyzer( + input, ::AnyInputAnalyzer, output_selector::AbstractOutputSelector; kwargs... +) + output = 42 + output_selection = 42 + val = 42 + return Explanation(val, input, output, output_selection, :AnyInput, :attribution) +end + +analyzer = AnyInputAnalyzer() + +input1 = (foo=1, bar=2) # NamedTuple +expl1 = analyze(input1, analyzer) +@test expl1.input isa NamedTuple + +input2 = "Hello world" # String +expl2 = analyze(input2, analyzer) +@test expl2.input isa String diff --git a/test/test_output_selection.jl b/test/test_output_selection.jl index 5a67092..6cab70e 100644 --- a/test/test_output_selection.jl +++ b/test/test_output_selection.jl @@ -2,6 +2,7 @@ using XAIBase using Test using XAIBase: MaxActivationSelector, IndexSelector +using Test using Random ns_max = @inferred MaxActivationSelector()