Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add smoothing of contour plots using Gaussian filtering #268

Merged
merged 7 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/src/plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ The plot style of the mean can be customized using a `Dict`. For `mean = true`,
`Dict("linestyle" => :dash, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "global mode")`

* `marginalmode::Union{Dict, Bool} = true`: indicate the marginal mode(s), i.e. the center of the highest histogram bin(s) (currently only for samples). The style can be passed as a `Dict`. If `marginalmode = true`, the default style is:
`Dict("linestyle" => :dot, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "local mode")`
`Dict("linestyle" => :dot, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "marginal mode")`

* (only for samples) `filter::Bool = false`: if `true`, `BAT.drop_low_weight_samples()` is applied before plotting

Expand All @@ -72,6 +72,7 @@ plot(
std = false,
globalmode = false,
marginalmode = true,
smoothing = 0,
diagonal = Dict(),
upper = Dict(),
right = Dict(),
Expand Down Expand Up @@ -111,7 +112,9 @@ The plot style of the mean can be customized using a `Dict`. For `mean = true`,
`Dict("linestyle" => :dash, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "global mode")`

* `marginalmode::Union{Dict, Bool} = true`: indicate the marginalmode(s), i.e. the center of the highest histogram bin(s) (currently only for samples). The style can be passed as a `Dict`. If `marginalmode = true`, the default style is:
`Dict("linestyle" => :dot, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "local mode")`
`Dict("linestyle" => :dot, "linewidth" => 1, "linecolor" => :black, "alpha" => 1, "label" => "marginal mode")`

* `smoothing = 0`: When plotting contours, a Gaussian filtering can be applied for smoothing the contour lines. The keyword `smoothing` accepts positive real number (or a tuple of two positive real numbers), specifying the standard deviation of the Gaussian kernel (for each dimension) of the filtering.

* `diagonal = Dict()`: Used only for the seriestype `:marginal`. The dictionary can contain the seriestypes and plot options for 2D distributions explained above to modify the 2D plot of the marginal plot. Nested Dictionaries are possible to modify the styles of the estimators as described above

Expand Down
8 changes: 4 additions & 4 deletions examples/dev-internal/plotting_examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ using Plots

plot(samples, :a) #default seriestype = :smallest_intervals (alias :HDR)
#or: plot(samples, 2)
# The default seriestype for plotting samples is `:smallest_intervals` (alias `:HDR`), highlighting the smallest intervals (the highest density region) containing 68.3, 95.5 and 99.7 perecent of the posterior probability. By default, the local mode(s) of the histogram is(are) indicated as dotted black line(s).
# The default seriestype for plotting samples is `:smallest_intervals` (alias `:HDR`), highlighting the smallest intervals (the highest density region) containing 68.3, 95.5 and 99.7 perecent of the posterior probability. By default, the marginal mode(s) of the histogram is(are) indicated as dotted black line(s).

# ### Default 1D plot of prior:
# Priors can be plotted either by their index or by using the parameter name:
Expand Down Expand Up @@ -118,7 +118,7 @@ plot(samples, :a, marginalmode=false,
# ### Default 2D plot of samples:
pyplot()
plot(samples, (:a, :(b[1])), mean=true, std=true) #default seriestype = :smallest_intervals (alias :HDR)
# The default seriestype for plotting samples is a 3-color heatmap showing the smallest intervals (highest density regions) containing 68.3%, 95.5% and 99.7% of the posterior probability. By default, the local mode
# The default seriestype for plotting samples is a 3-color heatmap showing the smallest intervals (highest density regions) containing 68.3%, 95.5% and 99.7% of the posterior probability. By default, the marginal mode
# of the histogram is indicated by a black square.

# ### Default 2D plot of priors:
Expand All @@ -136,9 +136,9 @@ plot(samples, (:a,:(b[2])), seriestype = :histogram)
# (currently only correctly supported with `pyplot()` backend)
plot(samples, (:a,:(b[2])), seriestype=:smallest_intervals_contour, bins=40)

# ### smallest intervals as filled contours:
# ### smallest intervals as filled contours with smoothing:
# (currently only correctly supported with `pyplot()` backend)
plot(samples, (:a,:(b[2])), seriestype=:smallest_intervals_contourf, bins=40)
plot(samples, (:a,:(b[2])), seriestype=:smallest_intervals_contourf, bins=40, smoothing=1)

# ### Customizing smallest interval plots:
# The probability intervals to be highlighted can be specified using the `intervals` keyword.
Expand Down
2 changes: 1 addition & 1 deletion src/algotypes/mode_estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end

*Experimental feature, not part of stable public API.*

Estimates a local mode of `samples` by finding the maximum of marginalized posterior for each dimension.
Estimates a marginal mode of `samples` by finding the maximum of marginalized posterior for each dimension.

Returns a NamedTuple of the shape

Expand Down
8 changes: 7 additions & 1 deletion src/plotting/recipes_MarginalDist_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
vsel::NTuple{2,Union{Symbol, Expr, Integer}};
intervals = default_credibilities,
colors = default_colors,
smoothing = 0,
diagonal = Dict(),
upper = Dict(),
right = Dict(),
Expand Down Expand Up @@ -43,7 +44,12 @@
lev = calculate_levels(hist, intervals)
x, y = get_bin_centers(marg)
m = hist.weights


if smoothing != 0
ker = gaussian_kernel(smoothing)
m = convolution(m, ker, padding=:same)
end

# quick fix: needed when plotting contour on top of histogram
# otherwise scaling of histogram colorbar would change scaling
lev = lev/10000
Expand Down
6 changes: 3 additions & 3 deletions src/plotting/recipes_samples_1D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@
end
end

# local mode(s)
# marginal mode(s)
if marginalmode_options != ()
marginalmode_values = find_marginalmodes(marg)

for (i, l) in enumerate(marginalmode_values)
@series begin
seriestype := :line
if length(marginalmode_values)==1
label := get(marginalmode_options, "label", "local mode")
label := get(marginalmode_options, "label", "marginal mode")
elseif i ==1
label := get(marginalmode_options, "label", "local modes")
label := get(marginalmode_options, "label", "marginal modes")
else
label :=""
end
Expand Down
6 changes: 4 additions & 2 deletions src/plotting/recipes_samples_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
std = false,
globalmode = false,
marginalmode = true,
smoothing = 0,
diagonal = Dict(),
upper = Dict(),
right = Dict(),
Expand Down Expand Up @@ -92,6 +93,7 @@
diagonal --> diagonal
upper --> upper
right --> right
smoothing --> smoothing

marg, (xindx, yindx)
end
Expand Down Expand Up @@ -168,9 +170,9 @@
@series begin
seriestype := :scatter
if i==1 && length(marginalmode_values)==1
label := get(marginalmode_options, "label", "local mode")
label := get(marginalmode_options, "label", "marginal mode")
elseif i ==1
label := get(marginalmode_options, "label", "local modes")
label := get(marginalmode_options, "label", "marginal modes")
else
label :=""
end
Expand Down
61 changes: 61 additions & 0 deletions src/utils/convolution_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

# simple 2d convolution with padding
function convolution(input, filter; padding=:same)
input_r, input_c = size(input)
filter_r, filter_c = size(filter)

if padding == :same
pad_r = (filter_r - 1) ÷ 2
pad_c = (filter_c - 1) ÷ 2

input_padded = zeros(input_r+(2*pad_r), input_c+(2*pad_c))
for i in 1:input_r, j in 1:input_c
input_padded[i+pad_r, j+pad_c] = input[i, j]
end
input = input_padded
input_r, input_c = size(input)
end

result = zeros(input_r-filter_r+1, input_c-filter_c+1)
result_r, result_c = size(result)

for i in 1:result_r
for j in 1:result_c
for k in 1:filter_r
for l in 1:filter_c
result[i,j] += input[i+k-1,j+l-1]*filter[k,l]
end
end
end
end

return result
end


# gaussian kernel with same σ in both dimensions
function gaussian_kernel(σ::Real; l::Int = 4*ceil(Int,σ)+1)
isodd(l) || throw(ArgumentError("length must be odd"))
w = l>>1
g = σ == 0 ? [exp(0/(2*oftype(σ, 1)^2))] : [exp(-x^2/(2*σ^2)) for x=-w:w]
k = g/sum(g)
return (k * k')
end

# gaussian kernel with different σs in both dimensions
function gaussian_kernel(
σs::Tuple{Real, Real};
l::Tuple{Int, Int} = (4*ceil(Int,σs[1])+1, 4*ceil(Int,σs[2])+1)
)
all(isodd.(l)) || throw(ArgumentError("length must be odd"))
w1 = l[1]>>1
g1 = σs[1] == 0 ? [exp(0/(2*oftype(σs[1], 1)^2))] : [exp(-x^2/(2*σs[1]^2)) for x=-w1:w1]
k1 = g1/sum(g1)

w2 = l[2]>>1
g2 = σs[2] == 0 ? [exp(0/(2*oftype(σs[2], 1)^2))] : [exp(-x^2/(2*σs[2]^2)) for x=-w2:w2]
k2 = g2/sum(g2)

return (k1 * k2')
end
1 change: 1 addition & 0 deletions src/utils/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

include("util_functions.jl")
include("array_utils.jl")
include("convolution_utils.jl")
include("coord_utils.jl")
include("valueshapes_utils.jl")
2 changes: 1 addition & 1 deletion test/plotting/test_localmodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Plots
using StatsBase


@testset "bin centers & local modes" begin
@testset "bin centers & marginal modes" begin

data1 = [1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9, 10]
hist_1d = fit(Histogram, data1, nbins = 10, closed = :left)
Expand Down