Multi arg fwd gradient (#1952)

* Multi arg fwd gradient

* multi arg deriv

* fix

* fix

* Update Enzyme.jl

* cleanup

* fix

* Update Enzyme.jl

* Update Enzyme.jl
wsmoses authored Oct 10, 2024
1 parent 4d4c546 commit ad86689
Showing 3 changed files with 869 additions and 568 deletions.
322 changes: 222 additions & 100 deletions src/Enzyme.jl
@@ -1794,16 +1794,30 @@ end
@inline tupleconcat(x, y) = (x..., y...)
@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...)

-function create_shadows(::Nothing, x)
-    return (onehot(x),)
-end
-
-function create_shadows(::Val{1}, x)
-    return (onehot(x),)
-end
-
-function create_shadows(::Val{chunk}, x) where {chunk}
-    return (chunkedonehot(x, Val(chunk)),)
-end
+@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N}
+    args = Union{Symbol,Expr}[:x]
+    tys = Type[X]
+    for i in 1:N
+        push!(args, :(vargs[$i]))
+        push!(tys, vargs[i])
+    end
+
+    exprs = Union{Symbol,Expr}[]
+    for (arg, ty) in zip(args, tys)
+        if ty <: Enzyme.Const
+            push!(exprs, :(nothing))
+        elseif ty <: AbstractFloat
+            push!(exprs, :(nothing))
+        elseif ChunkTy == Nothing || ChunkTy == Val{1}
+            push!(exprs, :(onehot($arg)))
+        else
+            push!(exprs, :(chunkedonehot($arg, chunk)))
+        end
+    end
+    return quote
+        Base.@_inline_meta
+        ($(exprs...),)
+    end
+end
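In effect, the generated `create_shadows` returns one shadow slot per argument: `Const` and scalar (`AbstractFloat`) arguments yield `nothing`, while array arguments yield a one-hot basis, chunked when a batch size above 1 is requested. A minimal sketch of the unchunked case, using `onehot_demo` as a hypothetical stand-in for Enzyme's internal `onehot` helper:

```julia
# Hypothetical stand-in for Enzyme's internal onehot helper:
# one basis vector per input entry.
onehot_demo(x::AbstractVector{T}) where {T} =
    ntuple(i -> T[j == i ? one(T) : zero(T) for j in eachindex(x)], length(x))

x = [2.0, 3.0]
y = [2.7, 3.1]

# Roughly what create_shadows(nothing, x, y) evaluates to for two vectors:
(onehot_demo(x), onehot_demo(y))

# With y wrapped in Const (or a scalar), its slot becomes nothing instead:
(onehot_demo(x), nothing)
```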

struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N}
@@ -1890,7 +1904,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
(derivs = ([3.0, 2.0],), val = 6.0)
```
-For functions which return an AbstractArray or scalar, this function will return an AbstracttArray
+For functions which return an AbstractArray or scalar, this function will return an AbstractArray
whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made
about the type of the AbstractArray returned by this function (which may or may not be the same
as the input AbstractArray if provided).
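To make the documented shape concrete: a function from a length-3 vector to a length-2 vector yields a derivative of shape `(2, 3)`. A small sketch assuming Enzyme is loaded (this `f` is a fresh example, separate from the doctests):

```julia
using Enzyme

# f maps a 3-vector to a 2-vector, so the derivative should have
# shape (size(output)..., size(input)...) == (2, 3).
f(x) = [x[1] * x[2], x[2] + x[3]]

J, = gradient(Forward, f, [2.0, 3.0, 4.0])
@assert size(J) == (2, 3)
@assert J[1, 1] ≈ 3.0  # ∂(x₁x₂)/∂x₁ = x₂ at x = [2.0, 3.0, 4.0]
```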
@@ -1905,119 +1919,227 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
# output
([3.0 2.0 0.0; 0.0 1.0 1.0],)
```
+This function supports multiple arguments and computes the gradient with respect to each one.
+```jldoctest gradfwd2
+mul(x, y) = x[1]*y[2] + x[2]*y[1]
+gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1])
+# output
+([3.1, 2.7], [3.0, 2.0])
+```
+Arguments may also be marked as `Const` if their derivatives are not needed; `nothing` is then returned in the corresponding entry of the derivative tuple.
+```jldoctest gradfwd2
+gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
+# output
+([3.1, 2.7], nothing)
+```
"""
-@inline function gradient(
+@generated function gradient(
    fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity},
-    f,
-    x;
+    f::F,
+    x::ty_0,
+    args::Vararg{Any,N};
    chunk::CS = nothing,
-    shadows = create_shadows(chunk, x),
-) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS}
-    if length(shadows[1]) == 0
-        return if ReturnPrimal
-            (; derivs = (x,), val = f(x.val))
-        else
-            (x,)
-        end
-    end
-    if chunk == Val(0)
-        throw(ErrorException("Cannot differentiate with a batch size of 0"))
-    end
-
-    gradtup = if chunk == nothing
-        resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1]))
-
-        res = values(resp[1])
-        dres = if x isa AbstractFloat
-            res[1]
-        else
-            res
-        end
-        if ReturnPrimal
-            ((dres,), resp[2])
-        else
-            (dres,)
-        end
-    elseif chunk == Val(1)
-        if ReturnPrimal
-            rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1]))
-            dres1 = rp[1]
-            fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#
-
-            res = ntuple(length(shadows[1]) - 1) do i
-                autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1]
-            end
-            gres = if x isa AbstractFloat
-                dres1[1]
-            else
-                (dres1, res...)
-            end
-            ((gres,), rp[2])
-        else
-            res = ntuple(length(shadows[1])) do i
-                autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1]
-            end
-            (if x isa AbstractFloat
-                res[1]
-            else
-                res
-            end,)
-        end
-    else
-        if ReturnPrimal
-            rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1]))
-            dres1 = values(rp[1])
-            gres = if x isa AbstractFloat
-                dres1[1]
-            else
-                fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#
-                tmp = ntuple(length(shadows[1]) - 1) do i
-                    values(
-                        autodiff(
-                            fm2,
-                            f,
-                            BatchDuplicated,
-                            BatchDuplicated(x, shadows[1][i+1]),
-                        )[1],
-                    )
-                end
-                tupleconcat(dres1, tmp...)
-            end
-            ((gres,), rp[2])
-        else
-            tmp = ntuple(length(shadows[1])) do i
-                values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1])
-            end
-            res = tupleconcat(tmp...)
-            (if x isa AbstractFloat
-                res[1]
-            else
-                res
-            end,)
-        end
-    end
-
-    cols = if ReturnPrimal
-        gradtup[1][1]
-    else
-        gradtup[1]
-    end
-    res = if x isa AbstractFloat
-        cols
-    elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray
-        inshape = size(x)
-        outshape = size(cols[1])
-        # st : outshape x total inputs
-        tupstack(cols, outshape, inshape)
-    elseif x isa AbstractArray
-        TupleArray(cols, size(x))
-    else
-        cols
-    end
-    if ReturnPrimal
-        (; derivs = (res,), val = gradtup[2])
-    else
-        (res,)
-    end
+    shadows::ST = create_shadows(chunk, x, args...),
+) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N}
+
+    syms = Union{Symbol,Expr}[:x]
+    shads = Union{Symbol,Expr}[:(shadows[1])]
+    tys = Type[ty_0]
+    for i in 1:N
+        push!(syms, :(args[$i]))
+        push!(tys, args[i])
+        push!(shads, :(shadows[1+$i]))
+    end
+    fval = if F <: Annotation
+        :(f.val)
+    else
+        :f
+    end
+
+    vals = Union{Symbol,Expr}[]
+    consts = Union{Symbol,Expr}[]
+    for (arg, ty) in zip(syms, tys)
+        if ty <: Const
+            push!(vals, :($arg.val))
+            push!(consts, arg)
+        else
+            push!(vals, arg)
+            push!(consts, :(Const($arg)))
+        end
+    end
+
+    if CS == Val{0}
+        return quote
+            Base.@_inline_meta
+            throw(ErrorException("Cannot differentiate with a batch size of 0"))
+        end
+    end
+
+    exprs = Union{Symbol,Expr}[]
+    primal = nothing
+    derivatives = Union{Symbol,Expr}[]
+
+    primmode = :(fm)
+    for (i, (arg, ty)) in enumerate(zip(syms, tys))
+        if ty <: Const
+            push!(derivatives, :(nothing))
+            continue
+        end
+
+        argnum = length(ST.parameters[i].parameters)
+
+        argderivative = if ty <: AbstractFloat
+            dargs = Union{Symbol,Expr}[]
+            for (j, arg2) in enumerate(syms)
+                if i == j
+                    push!(dargs, :(Duplicated($arg, one($arg))))
+                else
+                    push!(dargs, consts[j])
+                end
+            end
+
+            resp = Symbol("resp_$i")
+            push!(exprs, quote
+                $resp = autodiff($primmode, f, Duplicated, $(dargs...))
+            end)
+            if ReturnPrimal && primal == nothing
+                primal = :($resp[2])
+                primmode = NoPrimal(fm())
+            end
+
+            :($resp[1])
+        elseif argnum == 0
+            vals[i]
+        elseif CS == Nothing
+            dargs = Union{Symbol,Expr}[]
+            for (j, arg2) in enumerate(syms)
+                if i == j
+                    push!(dargs, :(BatchDuplicated($arg, $(shads[i]))))
+                else
+                    push!(dargs, consts[j])
+                end
+            end
+
+            df = :f
+            if F <: Enzyme.Duplicated
+                zeros = Expr[]
+                for i in 1:argnum
+                    push!(zeros, :(f.dval))
+                end
+                df = :(BatchDuplicated(f.val, ($(zeros...),) ))
+            end
+
+            resp = Symbol("resp_$i")
+            push!(exprs, quote
+                $resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...))
+            end)
+            if ReturnPrimal && primal == nothing
+                primal = :($resp[2])
+                primmode = NoPrimal(fm())
+            end
+
+            :(values($resp[1]))
+        elseif CS == Val{1}
+            subderivatives = Union{Symbol,Expr}[]
+            for an in 1:argnum
+                dargs = Union{Symbol,Expr}[]
+                for (j, arg2) in enumerate(syms)
+                    if i == j
+                        push!(dargs, :(Duplicated($arg, $(shads[i])[$an])))
+                    else
+                        push!(dargs, consts[j])
+                    end
+                end
+
+                resp = Symbol("resp_$i"*"_"*string(an))
+                push!(exprs, quote
+                    $resp = autodiff($primmode, f, Duplicated, $(dargs...))
+                end)
+                if ReturnPrimal && primal == nothing
+                    primal = :($resp[2])
+                    primmode = NoPrimal(fm())
+                end
+
+                push!(subderivatives, :(values($resp[1])))
+            end
+            :(($(subderivatives...),))
+        else
+            subderivatives = Union{Symbol,Expr}[]
+            for an in 1:argnum
+                dargs = Union{Symbol,Expr}[]
+                for (j, arg2) in enumerate(syms)
+                    if i == j
+                        push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an])))
+                    else
+                        push!(dargs, consts[j])
+                    end
+                end
+
+                resp = Symbol("resp_$i"*"_"*string(an))
+                push!(exprs, quote
+                    $resp = autodiff($primmode, f, BatchDuplicated, $(dargs...))
+                end)
+                if ReturnPrimal && primal == nothing
+                    primal = :($resp[2])
+                    primmode = NoPrimal(fm())
+                end
+
+                push!(subderivatives, :(values($resp[1])))
+            end
+            :(tupleconcat($(subderivatives...)))
+        end
+
+        deriv = if ty <: AbstractFloat
+            argderivative
+        else
+            tmp = Symbol("tmp_$i")
+            push!(exprs, :($tmp = $argderivative))
+            if ty <: AbstractArray
+                if argnum > 0
+                    quote
+                        if $tmp[1] isa AbstractArray
+                            inshape = size($(vals[1]))
+                            outshape = size($tmp[1])
+                            # st : outshape x total inputs
+                            tupstack($tmp, outshape, inshape)
+                        else
+                            TupleArray($tmp, size($arg))
+                        end
+                    end
+                else
+                    :(TupleArray($tmp, size($arg)))
+                end
+            else
+                tmp
+            end
+        end
+        push!(derivatives, deriv)
+    end
+
+    # We weirdly asked for no derivatives
+    if ReturnPrimal && primal == nothing
+        primal = :($fval($(vals...)))
+    end
+
+    result = if ReturnPrimal
+        :((; derivs = ($(derivatives...),), val = $primal))
+    else
+        :(($(derivatives...),))
+    end
+
+    return quote
+        Base.@_inline_meta
+        $(exprs...)
+        $result
+    end
end
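Taken together, multi-argument support composes with the primal-returning mode and with `Const`. A usage sketch, with expected values worked out by hand from `mul` and consistent with the doctests above:

```julia
using Enzyme

mul(x, y) = x[1] * y[2] + x[2] * y[1]

x = [2.0, 3.0]
y = [2.7, 3.1]

# Derivatives with respect to both arguments, plus the primal value.
res = gradient(ForwardWithPrimal, mul, x, y)
@assert res.val ≈ 2.0 * 3.1 + 3.0 * 2.7   # 14.3
@assert res.derivs[1] ≈ [3.1, 2.7]        # ∂mul/∂x = [y[2], y[1]]
@assert res.derivs[2] ≈ [3.0, 2.0]        # ∂mul/∂y = [x[2], x[1]]

# Marking y as Const skips its derivative entirely.
dx, dy = gradient(Forward, mul, x, Const(y))
@assert dx ≈ [3.1, 2.7] && dy === nothing
```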

