Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
oxinabox and github-actions[bot] authored Oct 4, 2023
1 parent b9e3376 commit bbfb4ba
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
15 changes: 9 additions & 6 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,13 @@ zero_tangent(x::Number) = zero(x)
return :($MutableTangent{$primal}($backing_expr))
end

function zero_tangent(x::Array{P, N}) where {P, N}
(isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x)

function zero_tangent(x::Array{P,N}) where {P,N}
(isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) &&
return map(zero_tangent, x)

# Now we need to handle nonfully assigned arrays
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265
y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...)
y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...)
@inbounds for n in eachindex(y)
if isassigned(x, n)
y[n] = zero_tangent(x[n])
Expand All @@ -131,6 +132,8 @@ function zero_tangent(x::Array{P, N}) where {P, N}
end

guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N}
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N}
return Array{guess_zero_tangent_type(T),N}
end
guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634
# TODO: we might be able to do better than this. even without.
# TODO: we might be able to do better than this. even without.
10 changes: 4 additions & 6 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,24 @@ end

@testset "undef elements" begin
x = Vector{Vector{Float64}}(undef, 3)
x[2] = [1.0,2.0]
x[2] = [1.0, 2.0]
dx = zero_tangent(x)
@test dx isa Vector{Vector{Float64}}
@test length(dx) == 3
@test !isassigned(dx, 1)
@test dx[2] == [0.0, 0.0]
@test !isassigned(dx, 3)


a = Vector{MutDemo}(undef, 3)
a[2] = MutDemo(1.5)
da = zero_tangent(a)
@test !isassigned(da, 1)
@test iszero(da[2])
@test !isassigned(da, 3)


db = zero_tangent(Vector{MutDemo}(undef, 3))
@test all(ii->!isassigned(db,ii), eachindex(db))
@test length(db)==3
@test all(ii -> !isassigned(db, ii), eachindex(db))
@test length(db) == 3
@test db isa Vector
end
end
end

0 comments on commit bbfb4ba

Please sign in to comment.