Skip to content

Commit

Permalink
handle abstract fields right in mutable tangents outside of zero tangent
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Dec 28, 2023
1 parent 24185d8 commit d116906
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
10 changes: 6 additions & 4 deletions src/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ It itself is also mutable.
struct MutableTangent{P,F} <: StructuralTangent{P}
backing::F

function MutableTangent{P}(fieldvals) where P
backing = map(Ref, fieldvals)
return new{P, typeof(backing)}(backing)
end
function MutableTangent{P}(
any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names}
) where {names, P}
Expand All @@ -91,8 +87,14 @@ struct MutableTangent{P,F} <: StructuralTangent{P}
end
return new{P, typeof(backing)}(backing)
end

function MutableTangent{P}(fvals) where P
any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P)))
return MutableTangent{P}(any_mask, fvals)
end
end


####################################################################
# StructuralTangent Common

Expand Down
41 changes: 36 additions & 5 deletions test/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ struct Foo
y::Float64
end

mutable struct MFoo
x::Float64
y
end

# For testing Primal + Tangent performance
struct Bar
x::Float64
Expand Down Expand Up @@ -452,14 +457,40 @@ end
end

@testset "== and hash" begin
@test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0)
@test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0)
@test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0)
@test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0)

nt = (; x=1.0)
@test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0)

@test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0))
@test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0))
end

@testset "Mutation" begin
v = MutableTangent{MFoo}(x=1.5, y=2.4)
v.x = 1.6
@test v == MutableTangent{MFoo}(x=1.6, y=2.4)
v.y = [1.0, 2.0] # change type, because primal can change type
@test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0])
end
end

@testset "map" begin
@testset "Tangent" begin
∂foo = Tangent{Foo}(x=1.5, y=2.4)
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8)

∂foo = Tangent{Foo}(x=1.5)
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0)
end
@testset "MutableTangent" begin
∂foo = MutableTangent{MFoo}(x=1.5, y=2.4)
∂foo2 = map(v->2*v, ∂foo)
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8)
# Check can still be mutated to new typ
∂foo2.y=[1.0, 2.0]
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0])
end
end

0 comments on commit d116906

Please sign in to comment.