Skip to content

Commit

Permalink
Merge pull request #81 from JuliaDynamics/bbnue_filter_lags
Browse files Browse the repository at this point in the history
Add possibility to exclude certain lags when selecting variables for `bbnue`
  • Loading branch information
kahaaga authored Dec 14, 2021
2 parents 101130e + eae0a86 commit e3c4f2a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TransferEntropy"
uuid = "ea221983-52f3-5440-99c7-13ea201cd633"
repo = "https://github.com/kahaaga/TransferEntropy.jl.git"
version = "1.4.0"
version = "1.5.0"

[deps]
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Expand Down
40 changes: 34 additions & 6 deletions src/transferentropy/autoutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ using DelayEmbeddings, Statistics
construct_candidate_variables(
source::Vector{AbstractVector},
target::Vector{AbstractVector},
cond::Vector{AbstractVector};
[cond::Vector{AbstractVector}];
k::Int = 1, include_instantaneous = true,
τexclude::Union{Int, Nothing} = nothing,
maxlag::Union{Int, Float64} = 0.05
) → ([τs_source, τs_target, τs_cond, ks_targetfuture], [js_source, js_target, js_cond, js_targetfuture])
Expand All @@ -18,9 +19,13 @@ the variables.
If `maxlag` is an integer, `maxlag` is taken as the maximum allowed embedding lag. If `maxlag` is a float,
then the maximum embedding lag is taken as `maximum([length.(source); length.(target); length.(cond)])*maxlag`.
If `τexclude` is an integer, all variables whose embedding lag has absolute value equal to `exclude` will be
excluded.
"""
function construct_candidate_variables(source, target, cond;
k::Int = 1,
k::Int = 1,
τexclude::Union{Int, Nothing} = nothing,
include_instantaneous = true,
method_delay = "ac_min",
maxlag::Union{Int, Float64} = 0.05)
Expand All @@ -39,9 +44,10 @@ function construct_candidate_variables(source, target, cond;
τsmax_source = [estimate_delay(s, method_delay, τs) for s in source]
τsmax_target = [estimate_delay(t, method_delay, τs) for t in target]
τsmax_cond = [estimate_delay(c, method_delay, τs) for c in cond]

# Generate candidate set
startlag = include_instantaneous ? 0 : -1

τs_source = [[startlag:-1:-τ...,] for τ in τsmax_source]
τs_target = [[startlag:-1:-τ...,] for τ in τsmax_target]
τs_cond = [[startlag:-1:-τ...,] for τ in τsmax_cond]
Expand All @@ -50,15 +56,31 @@ function construct_candidate_variables(source, target, cond;
js_targetfuture = [i for i in length(τs_source)+1:length(τs_source)+length(τs_target)]
τs = [τs_source..., τs_target..., τs_cond...]
js = [[i for x in 1:length(τs[i])] for i = 1:length(τs)]


# Variable filtering, if desired
if τexclude isa Int
τs = [filtered_τs(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
js = [filtered_js(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
end
return [τs..., ks_targetfuture], [js..., js_targetfuture]
end

# Usaully, we use all lags from startlag:-\tau_max to construct variables. In some situations,
# we may want to exclude som of those variables.
function filtered_τs(τs::AbstractVector{Int}, js::AbstractVector{Int}, τexclude::Int)
for τ in τs if abs(τ) != abs.(τexclude)]
end

function filtered_js(τs::AbstractVector{Int}, js::AbstractVector{Int}, τexclude::Int)
[j for (τ, j) in zip(τs, js) if abs(τ) != abs.(τexclude)]
end

# source & target variant
function construct_candidate_variables(source, target;
k::Int = 1,
τexclude::Union{Int, Nothing} = nothing,
include_instantaneous = true,
method_delay = "mi_min",
method_delay = "ac_min",
maxlag::Union{Int, Float64} = 0.05)

# Ensure all time series are of the same length.
Expand All @@ -74,7 +96,7 @@ function construct_candidate_variables(source, target;
# Find the maximum allowed embedding lag for each of the candidates.
τsmax_source = [estimate_delay(s, method_delay, τs) for s in source]
τsmax_target = [estimate_delay(t, method_delay, τs) for t in target]

# Generate candidate set
startlag = include_instantaneous ? 0 : -1
τs_source = [[startlag:-1:-τ...,] for τ in τsmax_source]
Expand All @@ -84,6 +106,12 @@ function construct_candidate_variables(source, target;
js_targetfuture = [i for i in length(τs_source)+1:length(τs_source)+length(τs_target)]
τs = [τs_source..., τs_target...,]
js = [[i for x in 1:length(τs[i])] for i = 1:length(τs)]

# Variable filtering, if desired
if τexclude isa Int
τs = [filtered_τs(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
js = [filtered_js(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
end

return [τs..., ks_targetfuture], [js..., js_targetfuture]
end
Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ end

@testset "Transfer entropy" begin
s, t, c = rand(100), rand(100), rand(100)

println("Starting transfer entropy tests...")
@testset "Generalized Renyi transfer entropy" begin
# Straight-forward estimators
Expand Down Expand Up @@ -120,6 +121,32 @@ end
end

@testset "Automated estimators" begin
@testset "Variable exclusion" begin
# Use periodic signals, so we also can test variable selection methods,
# which for sensible testing, need to have their autocorrelation minima
# > 1.
s = sin.(1:100) .+ rand(100)
t = sin.(1:100) .+ rand(100)
c = sin.(1:100) .+ rand(100)

τexclude = 1
vars = TransferEntropy.construct_candidate_variables([s], [t], τexclude = nothing)
vars_ex = TransferEntropy.construct_candidate_variables([s], [t], τexclude = τexclude)
fvars = Iterators.flatten(vars[1][1:end-1]) |> collect
fvars_ex = Iterators.flatten(vars_ex[1][1:end-1]) |> collect
@test τexclude abs.(fvars)
@test τexclude abs.(fvars_ex)
@test length(fvars) > length(fvars_ex)

vars = TransferEntropy.construct_candidate_variables([s], [t], [c], τexclude = nothing)
vars_ex = TransferEntropy.construct_candidate_variables([s], [t], [c], τexclude = τexclude)
fvars = Iterators.flatten(vars[1][1:end-1]) |> collect
fvars_ex = Iterators.flatten(vars_ex[1][1:end-1]) |> collect
@test τexclude abs.(fvars)
@test τexclude abs.(fvars_ex)
@test length(fvars) > length(fvars_ex)
end

est = VisitationFrequency(RectangularBinning(3))
te_st, params_st = bbnue(s, t, est)
te_stc, params_stc = bbnue(s, t, c, est)
Expand Down

0 comments on commit e3c4f2a

Please sign in to comment.