Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to exclude certain lags when selecting variables for bbnue #81

Merged
merged 7 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TransferEntropy"
uuid = "ea221983-52f3-5440-99c7-13ea201cd633"
repo = "https://github.com/kahaaga/TransferEntropy.jl.git"
version = "1.4.0"
version = "1.5.0"

[deps]
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Expand Down
40 changes: 34 additions & 6 deletions src/transferentropy/autoutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ using DelayEmbeddings, Statistics
construct_candidate_variables(
source::Vector{AbstractVector},
target::Vector{AbstractVector},
cond::Vector{AbstractVector};
[cond::Vector{AbstractVector}];
k::Int = 1, include_instantaneous = true,
τexclude::Union{Int, Nothing} = nothing,
maxlag::Union{Int, Float64} = 0.05
) → ([τs_source, τs_target, τs_cond, ks_targetfuture], [js_source, js_target, js_cond, js_targetfuture])

Expand All @@ -18,9 +19,13 @@ the variables.

If `maxlag` is an integer, `maxlag` is taken as the maximum allowed embedding lag. If `maxlag` is a float,
then the maximum embedding lag is taken as `maximum([length.(source); length.(target); length.(cond)])*maxlag`.

If `τexclude` is an integer, all variables whose embedding lag has absolute value equal to
`abs(τexclude)` will be excluded. If `τexclude` is `nothing` (the default), no variables are excluded.
"""
function construct_candidate_variables(source, target, cond;
k::Int = 1,
k::Int = 1,
τexclude::Union{Int, Nothing} = nothing,
include_instantaneous = true,
method_delay = "ac_min",
maxlag::Union{Int, Float64} = 0.05)
Expand All @@ -39,9 +44,10 @@ function construct_candidate_variables(source, target, cond;
τsmax_source = [estimate_delay(s, method_delay, τs) for s in source]
τsmax_target = [estimate_delay(t, method_delay, τs) for t in target]
τsmax_cond = [estimate_delay(c, method_delay, τs) for c in cond]

# Generate candidate set
startlag = include_instantaneous ? 0 : -1

τs_source = [[startlag:-1:-τ...,] for τ in τsmax_source]
τs_target = [[startlag:-1:-τ...,] for τ in τsmax_target]
τs_cond = [[startlag:-1:-τ...,] for τ in τsmax_cond]
Expand All @@ -50,15 +56,31 @@ function construct_candidate_variables(source, target, cond;
js_targetfuture = [i for i in length(τs_source)+1:length(τs_source)+length(τs_target)]
τs = [τs_source..., τs_target..., τs_cond...]
js = [[i for x in 1:length(τs[i])] for i = 1:length(τs)]


# Variable filtering, if desired
if τexclude isa Int
τs = [filtered_τs(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
js = [filtered_js(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
end
return [τs..., ks_targetfuture], [js..., js_targetfuture]
end

# Usually, we use all lags from startlag:-\tau_max to construct variables. In some situations,
# we may want to exclude some of those variables.
"""
    filtered_τs(τs, js, τexclude) → Vector{Int}

Return the lags in `τs` whose absolute value differs from `abs(τexclude)`, keeping
their original order.

# Arguments
- `τs`: embedding lags of one candidate variable.
- `js`: variable indices belonging to `τs`. Not read by this function; it is accepted
  so that the call signature is the same as that of `filtered_js`.
- `τexclude`: the lag to drop. Only its absolute value matters, so `1` and `-1`
  exclude the same lags.
"""
function filtered_τs(τs::AbstractVector{Int}, js::AbstractVector{Int}, τexclude::Int)
    # `τexclude` is a scalar, so plain `abs` suffices; compute it once up front.
    excluded = abs(τexclude)
    return [τ for τ in τs if abs(τ) != excluded]
end

"""
    filtered_js(τs, js, τexclude) → Vector{Int}

Return the entries of `js` whose paired lag in `τs` has an absolute value different
from `abs(τexclude)`, keeping their original order.

# Arguments
- `τs`: embedding lags of one candidate variable.
- `js`: variable indices, paired element-wise with `τs`.
- `τexclude`: the lag to drop. Only its absolute value matters.

Lags and indices are paired with `zip`, so if the two vectors differ in length,
only the first `min(length(τs), length(js))` pairs are considered.
"""
function filtered_js(τs::AbstractVector{Int}, js::AbstractVector{Int}, τexclude::Int)
    # `τexclude` is a scalar, so plain `abs` suffices; compute it once up front.
    excluded = abs(τexclude)
    return [j for (τ, j) in zip(τs, js) if abs(τ) != excluded]
end

# source & target variant
function construct_candidate_variables(source, target;
k::Int = 1,
τexclude::Union{Int, Nothing} = nothing,
include_instantaneous = true,
method_delay = "mi_min",
method_delay = "ac_min",
maxlag::Union{Int, Float64} = 0.05)

# Ensure all time series are of the same length.
Expand All @@ -74,7 +96,7 @@ function construct_candidate_variables(source, target;
# Find the maximum allowed embedding lag for each of the candidates.
τsmax_source = [estimate_delay(s, method_delay, τs) for s in source]
τsmax_target = [estimate_delay(t, method_delay, τs) for t in target]

# Generate candidate set
startlag = include_instantaneous ? 0 : -1
τs_source = [[startlag:-1:-τ...,] for τ in τsmax_source]
Expand All @@ -84,6 +106,12 @@ function construct_candidate_variables(source, target;
js_targetfuture = [i for i in length(τs_source)+1:length(τs_source)+length(τs_target)]
τs = [τs_source..., τs_target...,]
js = [[i for x in 1:length(τs[i])] for i = 1:length(τs)]

# Variable filtering, if desired
if τexclude isa Int
τs = [filtered_τs(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
js = [filtered_js(τsᵢ, jsᵢ, τexclude) for (τsᵢ, jsᵢ) in zip(τs, js)]
end

return [τs..., ks_targetfuture], [js..., js_targetfuture]
end
Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ end

@testset "Transfer entropy" begin
s, t, c = rand(100), rand(100), rand(100)

println("Starting transfer entropy tests...")
@testset "Generalized Renyi transfer entropy" begin
# Straight-forward estimators
Expand Down Expand Up @@ -120,6 +121,32 @@ end
end

@testset "Automated estimators" begin
@testset "Variable exclusion" begin
# Checks that the `τexclude` keyword of `construct_candidate_variables` removes
# candidate variables whose lag has absolute value `τexclude`, for both the
# (source, target) and the (source, target, cond) methods.
#
# Use periodic signals, so we also can test variable selection methods,
# which for sensible testing, need to have their autocorrelation minima
# > 1.
# NOTE(review): no RNG seed is set in this block, so the noise term differs
# between runs — confirm the assertions below hold for every draw.
# NOTE(review): `s`, `t` and `c` are also assigned in the enclosing
# "Transfer entropy" testset and are used after this block; these assignments
# presumably rebind them for the code that follows — confirm that is intended.
s = sin.(1:100) .+ rand(100)
t = sin.(1:100) .+ rand(100)
c = sin.(1:100) .+ rand(100)

τexclude = 1
# Source & target variant: build the candidate set without and with exclusion.
vars = TransferEntropy.construct_candidate_variables([s], [t], τexclude = nothing)
vars_ex = TransferEntropy.construct_candidate_variables([s], [t], τexclude = τexclude)
# `vars[1]` holds the lags; its last entry is the target-future lags, which are
# dropped by `[1:end-1]` before flattening everything into one vector.
fvars = Iterators.flatten(vars[1][1:end-1]) |> collect
fvars_ex = Iterators.flatten(vars_ex[1][1:end-1]) |> collect
# Lag 1 must be present without exclusion and absent with it, so the filtered
# candidate set must be strictly smaller.
@test τexclude ∈ abs.(fvars)
@test τexclude ∉ abs.(fvars_ex)
@test length(fvars) > length(fvars_ex)

# Source, target & conditional variant: same checks with `c` as conditional.
vars = TransferEntropy.construct_candidate_variables([s], [t], [c], τexclude = nothing)
vars_ex = TransferEntropy.construct_candidate_variables([s], [t], [c], τexclude = τexclude)
fvars = Iterators.flatten(vars[1][1:end-1]) |> collect
fvars_ex = Iterators.flatten(vars_ex[1][1:end-1]) |> collect
@test τexclude ∈ abs.(fvars)
@test τexclude ∉ abs.(fvars_ex)
@test length(fvars) > length(fvars_ex)
end

est = VisitationFrequency(RectangularBinning(3))
te_st, params_st = bbnue(s, t, est)
te_stc, params_stc = bbnue(s, t, c, est)
Expand Down