From 7827415eeea96e0b17a7450e66f41c2377176435 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Wed, 18 Oct 2023 16:58:52 +0200 Subject: [PATCH] Access color schemes through symbols --- src/VisionHeatmaps.jl | 2 +- src/heatmap.jl | 8 ++++++-- test/test_heatmap.jl | 4 ++++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/VisionHeatmaps.jl b/src/VisionHeatmaps.jl index 002b749..ddc158a 100644 --- a/src/VisionHeatmaps.jl +++ b/src/VisionHeatmaps.jl @@ -1,6 +1,6 @@ module VisionHeatmaps -using ColorSchemes: ColorScheme, get, seismic +using ColorSchemes: ColorScheme, colorschemes, get, seismic using ImageCore include("heatmap.jl") diff --git a/src/heatmap.jl b/src/heatmap.jl index 207c434..ab647f2 100644 --- a/src/heatmap.jl +++ b/src/heatmap.jl @@ -9,7 +9,7 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di (width, height, color channels, batch dimension). ## Keyword arguments -- `colorscheme::ColorScheme`: Color scheme from ColorSchemes.jl. +- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl. Defaults to `seismic`. - `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme. The following methods can be selected, which are then applied over the color channels @@ -32,7 +32,7 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di """ function heatmap( val::AbstractArray{T,N}; - colorscheme::ColorScheme=DEFAULT_COLORSCHEME, + colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, reduce::Symbol=DEFAULT_REDUCE, rangescale::Symbol=DEFAULT_RANGESCALE, permute::Bool=true, @@ -40,6 +40,7 @@ function heatmap( process_batch::Bool=false, ) where {T,N} N != 4 && throw(InputDimensionError) + colorscheme = get_colorscheme(colorscheme) if unpack_singleton && size(val, 4) == 1 return single_heatmap(val[:, :, :, 1], colorscheme, reduce, rangescale, permute) end @@ -58,6 +59,9 @@ const InputDimensionError = ArgumentError( Please reshape your input to match this format if your model doesn't adhere to this convention.", ) +get_colorscheme(c::ColorScheme) = c +get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s] + # Lower level function, mapped along batch dimension function single_heatmap( val, colorscheme::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool diff --git a/test/test_heatmap.jl b/test/test_heatmap.jl index 15ce649..9d55500 100644 --- a/test/test_heatmap.jl +++ b/test/test_heatmap.jl @@ -30,6 +30,10 @@ end @testset "ColorSchemes" begin h = heatmap(A; colorscheme=ColorSchemes.inferno) @test_reference "references/inferno.txt" h + + # Test colorscheme symbols + h = heatmap(A; colorscheme=:inferno) + @test_reference "references/inferno.txt" h end @testset "Error handling" begin