From 89c8d4311b8a866ac90edcc4fbae3b594c8febb5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 1 Nov 2023 23:02:09 -0400 Subject: [PATCH] in-place destructure --- src/Optimisers.jl | 2 +- src/destructure.jl | 73 +++++++++++++++++++++++++++++++++++++++++---- test/destructure.jl | 42 ++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 1451bc89..42d5b3d6 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -9,7 +9,7 @@ export AbstractRule include("adjust.jl") include("destructure.jl") -export destructure +export destructure, destructure! include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, diff --git a/src/destructure.jl b/src/destructure.jl index 3b21d918..9cf1e6d1 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -9,6 +9,8 @@ Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model to a vector, and returns also a function which reverses this transformation. Differentiable. +See also [`destructure!`](@ref). + # Example ```jldoctest julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im]))) @@ -31,6 +33,36 @@ function destructure(x) flat, Restructure(x, off, len) end +""" + destructure!(model) -> vector, reconstructor + +This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model. +Requires that all trainable parameters in the model be mutable arrays! + +# Example +```jldoctest +julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos)) + +julia> v, re! = destructure!(m) +([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4)) + +julia> m === re!([3, 5, 7, 9]) # mutates the original m, and returns it +true + +julia> m +(x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos)) +``` +""" +function destructure!(x) + flat, off, len = _flatten(x) + flat, Restructure!(x, off, len) +end + +# function destructure!(flat::AbstractVector, x) +# flat, off, len = _flatten!(flat, x) +# flat, Restructure!(x, off, len) +# end + """ Restructure(Model, ..., length) @@ -55,12 +87,20 @@ struct Restructure{T,S} model::T offsets::S length::Int + mutate::Bool end -(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length) +Restructure(model, offsets, length) = Restructure(model, offsets, length, false) +Restructure!(model, offsets, length) = Restructure(model, offsets, length, true) + +(re::Restructure)(flat::AbstractVector) = re.mutate ? _rebuild!(re.model, re.offsets, flat, re.length) : _rebuild(re.model, re.offsets, flat, re.length) (re::Restructure)(x, flat::AbstractVector) = re(flat)(x) -Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") Base.length(re::Restructure) = re.length +function Base.show(io::IO, re::Restructure{T}) where T + print(io, "Restructure", re.mutate ? "!" : "") + print(io, "(", T.name.name, ", ..., ", re.length, ")") +end + # This flattens a model, and returns a web of offsets for later use: function _flatten(x) isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case @@ -75,6 +115,17 @@ function _flatten(x) isempty(arrays) && return Bool[], off, 0 reduce(vcat, arrays), off, len[] end +# function _flatten!(flat, x) +# isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case +# len = Ref(0) +# off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y +# o = len[] +# copyto!(flat, o, _vec(y)) +# len[] = o + length(y) +# o +# end +# flat, off, len[] +# end struct _TrainableStructWalk <: AbstractWalk end @@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai _getat(y, o, flat) end end +# (mutating version, same arguments & same return) +function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...) + len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))")) + fmap(x, off; exclude = isnumeric, walk, kw...) do y, o + copyto!(y, _getat(y, o, flat, view)) + end + x +end -_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) -_getat(y::AbstractArray, o::Int, flat::AbstractVector) = - ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes +_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1]) +_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) = + ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y))) # ProjectTo is just correcting eltypes struct _Trainable_biwalk <: AbstractWalk end @@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...) _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT) _rebuild(x, off, flat, len; kw...), _rebuild_back end +function ChainRulesCore.rrule(::typeof(_rebuild!), x, off, flat, len; kw...) + _rebuild!_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT) + _rebuild!(x, off, flat, len; kw...), _rebuild!_back +end _zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad! ChainRulesCore.@non_differentiable _zero(x) diff --git a/test/destructure.jl b/test/destructure.jl index 90f28fb4..8b712d07 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -24,8 +24,10 @@ m9 = (a = m1, b = mat, c = [mat, m1]) @test destructure(m9)[1] == 1:7 @test destructure(m1)[2](7:9) == [7,8,9] + @test m1 == 1:3 # not mutated @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9]) @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9]) + @test m3.z == 4:6 # not mutated m4′ = destructure(m4)[2](4:9) @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9]) @test m4′.x === m4′.y @@ -60,11 +62,31 @@ m9 = (a = m1, b = mat, c = [mat, m1]) @test_throws Exception destructure(m7)[2]([10,20,30,40]) end +@testset "destructure!" begin + m3′ = deepcopy(m3) + @test destructure!(m3′)[1] == 1:6 + @test destructure!(m3′)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9]) + @test m3′ == (x = [4,5,6], y = sin, z = [7,8,9]) + + m7′ = deepcopy(m7) + @test destructure!(m7′)[1] == 1:3 + destructure!(m7′)[2]([10,20,30]) + @test m7′.a == (sin, [10,20,30]) + @test m7′.b == (cos, [4,5,6]) + @test m7′.c == (tan, [7,8,9]) + + # errors + @test_throws Exception destructure!(m7)[2]([10,20]) + @test_throws Exception destructure!(m7)[2]([10,20,30,40]) +end + @testset "gradient of flatten" begin @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] + @test gradient(m -> destructure!(m)[1][1], m1)[1] == [1,0,0] @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test gradient(m -> destructure!(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) g5 = gradient(m -> destructure(m)[1][3], m5)[1] @@ -206,6 +228,26 @@ end end end +@testset "gradient of rebuild!" begin + re1 = destructure!(deepcopy(m1))[2] + @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] + + re2 = destructure!(deepcopy(m2))[2] + @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] + + re3 = destructure!(deepcopy(m3))[2] + @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] + @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + + re4 = destructure!(deepcopy(m4))[2] + @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] + @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] + @test gradient(rand(6)) do x + m = re4(x) + m.x[1] + 2*m.y[2] + 3*m.z[3] + end[1] == [1,2,0, 0,0,3] +end + @testset "Flux issue 1826" begin v, re = destructure((x=[1,2.0], y=[3,4,5.0])) @test gradient(zero(v)) do w