From 2bac11d5de8e5985ec951611bff266414070a7f5 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Fri, 29 Dec 2023 15:00:08 -0700 Subject: [PATCH 1/3] allow transform to avoid Z-score transforming when sigma=0 --- src/transformations.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformations.jl b/src/transformations.jl index a4214b5db..b7e57e0b2 100644 --- a/src/transformations.jl +++ b/src/transformations.jl @@ -123,10 +123,22 @@ function fit(::Type{ZScoreTransform}, X::AbstractMatrix{<:Real}; else throw(DomainError(dims, "fit only accept dims to be 1 or 2.")) end - return ZScoreTransform(l, dims, (center ? vec(m) : similar(m, 0)), - (scale ? vec(s) : similar(s, 0))) + if scale + s_vec = vec(s) + # avoid z-score transforming when sigma=0 + if any(s_vec .== 0.0) + zero_variance_indices = s_vec .== 0.0 + s_vec[zero_variance_indices] .= 1.0 + end + else + s_vec = similar(s, 0) + end + + return ZScoreTransform(l, dims, (center ? vec(m) : similar(m, 0)), s_vec) end + + function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real}; dims::Integer=1, center::Bool=true, scale::Bool=true) if dims != 1 From 25ffe3860562986e663ddf3378c4a432060ea78d Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Fri, 29 Dec 2023 15:02:12 -0700 Subject: [PATCH 2/3] remove extra newlines --- src/transformations.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformations.jl b/src/transformations.jl index b7e57e0b2..06f394235 100644 --- a/src/transformations.jl +++ b/src/transformations.jl @@ -137,8 +137,6 @@ function fit(::Type{ZScoreTransform}, X::AbstractMatrix{<:Real}; return ZScoreTransform(l, dims, (center ? vec(m) : similar(m, 0)), s_vec) end - - function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real}; dims::Integer=1, center::Bool=true, scale::Bool=true) if dims != 1 From 675055dd03a337aef8fdf95188b689cd22843ac4 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Fri, 29 Dec 2023 15:27:43 -0700 Subject: [PATCH 3/3] add tests for zero variance --- test/transformations.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/transformations.jl b/test/transformations.jl index 7d8e2b0a9..f8f7926ba 100644 --- a/test/transformations.jl +++ b/test/transformations.jl @@ -70,6 +70,20 @@ using Test @test reconstruct!(t, Y) === Y @test Y ≈ X_ + # zero standard deviations + X = ones(5, 8) + t = fit(ZScoreTransform, X, dims=2) + Y = transform(t, X) + @test length(t.mean) == 5 + @test length(t.scale) == 5 + @test Y ≈ zeros(5, 8) + @test reconstruct(t, Y) ≈ X + @test Y ≈ standardize(ZScoreTransform, X, dims=2) + @test transform!(t, X) === X + @test isequal(X, Y) + @test reconstruct!(t, Y) === Y + @test Y ≈ ones(5, 8) + X = copy(X_) t = fit(UnitRangeTransform, X, dims=1, unit=false) Y = transform(t, X)