Commit fa4a2f5
Fix some of the tests
avik-pal committed Mar 29, 2024
1 parent 7dd7160 commit fa4a2f5
Showing 4 changed files with 30 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/BatchedRoutines.jl
@@ -14,14 +14,14 @@ import PrecompileTools: @recompile_invalidations
     using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
                          mul!, pinv
     using LuxDeviceUtils: LuxDeviceUtils, get_device
-    using SciMLOperators: AbstractSciMLOperator
+    using SciMLOperators: SciMLOperators, AbstractSciMLOperator
 end
 
 function __init__()
     @static if isdefined(Base.Experimental, :register_error_hint)
         Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
             if any(Base.Fix2(isa, UniformBlockDiagonalOperator), exc.args)
-                print(io, "\nHINT: ")
+                printstyled(io, "\nHINT: "; bold=true)
                 printstyled(
                     io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \
                     operations. If you want this supported open an issue at \
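The change here is cosmetic: the "HINT:" prefix is now rendered bold via `printstyled` instead of plain `print`. For readers unfamiliar with the mechanism, a minimal sketch of how `Base.Experimental.register_error_hint` works, using a hypothetical `Dummy` type in place of `UniformBlockDiagonalOperator`:

# Minimal sketch of the error-hint mechanism; `Dummy` is a stand-in type,
# not part of BatchedRoutines.
struct Dummy end

if isdefined(Base.Experimental, :register_error_hint)
    Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
        # Fire only when a `Dummy` appears among the offending arguments
        if any(Base.Fix2(isa, Dummy), exc.args)
            printstyled(io, "\nHINT: "; bold=true)  # the bold prefix this commit adds
            print(io, "`Dummy` doesn't support this operation.")
        end
    end
end

Dummy() + 1  # the resulting MethodError report now ends with a bold "HINT: ..." line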
6 changes: 3 additions & 3 deletions src/chainrules.jl
@@ -10,7 +10,7 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) wher
         gradient_ad = AutoZygote()
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x), x, Δᵢ)
-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
     return J, ∇batched_jacobian
@@ -28,13 +28,13 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x), x, Δᵢ)
 
-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
 
         _map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             (x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_),
             x, Δᵢ, p)
 
-        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(_eachrow(Δ))), size(p))
+        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p))
 
         return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
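The only functional change in both hunks is `_eachrow` → `eachrow`: the pullback accumulates `∂x` (and `∂p`) by iterating the rows of the incoming cotangent `Δ` with the Base iterator instead of an internal `_eachrow` helper. A hedged sketch of the call path these rrules serve — the `batched_jacobian` signature matches the tests in this commit, everything else is assumed:

# Assumed setup: Zygote as the reverse-mode driver, ADTypes for the AD backend tag.
using Zygote
using ADTypes: AutoForwardDiff
using BatchedRoutines: batched_jacobian

f(x) = x .^ 2   # elementwise map; columns of x are independent batches
x = rand(3, 4)  # 3 features × 4 batches

# Differentiating through `batched_jacobian` invokes the custom rrule above,
# whose pullback sums Jacobian-vector products over `eachrow(Δ)`.
loss(x) = sum(abs2, batched_jacobian(AutoForwardDiff(), f, x))
∂x, = Zygote.gradient(loss, x)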
23 changes: 23 additions & 0 deletions src/operator.jl
@@ -6,6 +6,10 @@ function UniformBlockDiagonalOperator(X::AbstractMatrix)
     return UniformBlockDiagonalOperator(reshape(X, size(X, 1), size(X, 2), 1))
 end
 
+# SciMLOperators Interface
+## Even though it is easily convertible, it is helpful to get warnings
+SciMLOperators.isconvertible(::UniformBlockDiagonalOperator) = false
+
 # BatchedRoutines API
 getdata(op::UniformBlockDiagonalOperator) = op.data
 nbatches(op::UniformBlockDiagonalOperator) = size(op.data, 3)
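`isconvertible` is the SciMLOperators trait marking whether an operator can be cheaply materialized via `convert(AbstractMatrix, op)`. Returning `false` keeps generic fallbacks from silently densifying the block-diagonal operator, so misuse surfaces loudly (and the error hint registered in src/BatchedRoutines.jl gets a chance to fire). A quick illustration of the trait, assuming the 3-D-array constructor implied by `nbatches`:

using BatchedRoutines: UniformBlockDiagonalOperator
using SciMLOperators

op = UniformBlockDiagonalOperator(rand(2, 2, 3))  # three 2×2 diagonal blocks
SciMLOperators.isconvertible(op)                  # false after this commit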
@@ -172,6 +176,25 @@ end
     return UniformBlockDiagonalOperator(copy(getdata(op)))
 end
 
+## Define some of the common operations like `sum` directly since SciMLOperators
+## doesn't provide a very nice implementation
+@inline function Base.sum(op::UniformBlockDiagonalOperator; kwargs...)
+    return sum(identity, op; kwargs...)
+end
+
+@inline function Base.sum(f::F, op::UniformBlockDiagonalOperator; dims=Colon()) where {F}
+    return mapreduce(f, +, op; dims)
+end
+
+## Common Operations
+function Base.:+(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) + getdata(op2))
+end
+
+function Base.:-(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) - getdata(op2))
+end
+
 # Adapt
 @inline function Adapt.adapt_structure(to, op::UniformBlockDiagonalOperator)
     return UniformBlockDiagonalOperator(Adapt.adapt(to, getdata(op)))
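With these methods, reductions and arithmetic stay on the wrapped 3-D `data` array instead of round-tripping through a dense matrix. A small usage sketch, under the same assumptions as above:

op1 = UniformBlockDiagonalOperator(ones(2, 2, 3))
op2 = UniformBlockDiagonalOperator(2 .* ones(2, 2, 3))

op2 - op1             # operator wrapping ones(2, 2, 3)
sum(op1)              # 12.0 — delegates to sum(identity, op1), a mapreduce over entries
sum(abs2, op2 - op1)  # 12.0 — the exact pattern the fixed tests below rely on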
4 changes: 2 additions & 2 deletions test/integration_tests.jl
@@ -72,7 +72,7 @@ end
     loss_function = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, nothing, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x, ps)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function(model, x, target_jac, ps, st) isa Number
@@ -94,7 +94,7 @@ end
     loss_function2 = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, ps, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function2(model, x, target_jac, ps, st) isa Number
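Both hunks make the same fix: `jac_full` is a `UniformBlockDiagonalOperator`, not an `AbstractArray`, so the broadcasted `.-` has no applicable method (exactly the situation the new error hint warns about), while plain `-` now dispatches to the `Base.:-` method added in src/operator.jl. Schematically, continuing the operator assumptions above:

jac_full   = UniformBlockDiagonalOperator(rand(4, 4, 2))
target_jac = UniformBlockDiagonalOperator(zeros(4, 4, 2))

sum(abs2, jac_full - target_jac)    # works: Base.:- subtracts the wrapped data
# sum(abs2, jac_full .- target_jac) # MethodError, plus the registered HINT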
