From bbfb4ba1721ba8f89d77c78d6ce4e0466e822210 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 23:32:59 +0800 Subject: [PATCH] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 15 +++++++++------ test/tangent_types/abstract_zero.jl | 10 ++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index eea2d90e0..afa5b130f 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -116,12 +116,13 @@ zero_tangent(x::Number) = zero(x) return :($MutableTangent{$primal}($backing_expr)) end -function zero_tangent(x::Array{P, N}) where {P, N} - (isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x) - +function zero_tangent(x::Array{P,N}) where {P,N} + (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) && + return map(zero_tangent, x) + # Now we need to handle nonfully assigned arrays # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 - y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...) + y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) @inbounds for n in eachindex(y) if isassigned(x, n) y[n] = zero_tangent(x[n]) @@ -131,6 +132,8 @@ function zero_tangent(x::Array{P, N}) where {P, N} end guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N} +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 - # TODO: we might be able to do better than this. even without. \ No newline at end of file +# TODO: we might be able to do better than this. even without. \ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 0c86c3432..c9442707d 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -176,7 +176,7 @@ end @testset "undef elements" begin x = Vector{Vector{Float64}}(undef, 3) - x[2] = [1.0,2.0] + x[2] = [1.0, 2.0] dx = zero_tangent(x) @test dx isa Vector{Vector{Float64}} @test length(dx) == 3 @@ -184,7 +184,6 @@ end @test dx[2] == [0.0, 0.0] @test !isassigned(dx, 3) - a = Vector{MutDemo}(undef, 3) a[2] = MutDemo(1.5) da = zero_tangent(a) @@ -192,10 +191,9 @@ end @test iszero(da[2]) @test !isassigned(da, 3) - db = zero_tangent(Vector{MutDemo}(undef, 3)) - @test all(ii->!isassigned(db,ii), eachindex(db)) - @test length(db)==3 + @test all(ii -> !isassigned(db, ii), eachindex(db)) + @test length(db) == 3 @test db isa Vector - end + end end