Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
kongdd committed May 7, 2024
1 parent 6520b74 commit cc44c10
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/Smooth/Whittaker/WHIT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function WHIT(y::AbstractVector, w::AbstractVector, x::AbstractVector;
r = @. (y - z) / (1 - h)
cve = sqrt(sum(r .* r .* w) / sum(w))
end
z, h, cve
z, cve
end

export WHIT, speye, diff, ddmat
Expand Down
2 changes: 1 addition & 1 deletion src/Smooth/Whittaker/lambda_vcurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function lambda_cv(y::AbstractVector{T}, w::AbstractVector{T2};
for i in 1:n
lambda = 10^lg_lambdas[i]
# z, cvs[i] = whit2(y, w, lambda)
cvs[i] = whit2!(y, w, lambda, interm; include_cve=true)
cvs[i] = whit2!(y, w, lambda, interm; include_cve=true)[2]
# fits[i] = fidelity(y, z, w)
# pens[i] = roughness(z, d)
end
Expand Down
1 change: 0 additions & 1 deletion src/Smooth/Whittaker/whit3.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
export whit3
include("whit3_hat.jl")

function whit3(y::AbstractVector{T1}, w::AbstractVector{T2};
lambda::Real, include_cve=true) where {T1<:Real,T2<:Real}
Expand Down
20 changes: 16 additions & 4 deletions test/test-smooth_whit.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test
using UnPack

@testset "whittaker smoother" begin
y = [5.0, 8, 9, 10, 12, 10, 15, 10, 9, 19, 19, 17, 13, 14, 18, 19, 18, 12, 18,
Expand All @@ -11,13 +12,13 @@ using Test
lambda = 2.0
z = ones(m)
interm = interm_whit{FT}(; n=length(y))
cve = whit2!(y, w, lambda, interm; include_cve=true)
z, cve = whit2!(y, w, lambda, interm; include_cve=true)
z, cve2 = whit2(y, w; lambda)

@test cve cve2

lamb_cv = lambda_cv(y, w, is_plot=true)
lamb_vcurve = lambda_vcurve(y, w, is_plot=true)

z1, cve_cv = whit2(y, w, lambda=lamb_cv)
z2, cve_vcurve = whit2(y, w, lambda=lamb_vcurve)
@test cve_cv < cve
Expand All @@ -34,8 +35,19 @@ end
z, cve2 = whit2(y, w, lambda=2)
@test cve1 cve2
@test_nowarn r = smooth_whit(y, w)
end

# whit3
z1, cve1 = whit3(y, w; lambda=2)
z2, cve2 = WHIT(y, w; lambda=2, d=3)
@test cve1 cve2
@test maximum(z1 - z2) <= 1e-10

# whit2
z1, cve1 = whit2(y, w; lambda=2)
z2, cve2 = WHIT(y, w; lambda=2, d=2)
@test cve1 cve2
@test maximum(z1 - z2) <= 1e-10
end

# using BenchmarkTools
# @benchmark
Expand Down

0 comments on commit cc44c10

Please sign in to comment.