diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl
index 099ec6e..efb9cec 100644
--- a/src/BatchedRoutines.jl
+++ b/src/BatchedRoutines.jl
@@ -14,14 +14,14 @@ import PrecompileTools: @recompile_invalidations
     using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
                          mul!, pinv
     using LuxDeviceUtils: LuxDeviceUtils, get_device
-    using SciMLOperators: AbstractSciMLOperator
+    using SciMLOperators: SciMLOperators, AbstractSciMLOperator
 end
 
 function __init__()
     @static if isdefined(Base.Experimental, :register_error_hint)
         Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
             if any(Base.Fix2(isa, UniformBlockDiagonalOperator), exc.args)
-                print(io, "\nHINT: ")
+                printstyled(io, "\nHINT: "; bold=true)
                 printstyled(
                     io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \
                          operations. If you want this supported open an issue at \
diff --git a/src/chainrules.jl b/src/chainrules.jl
index 70baf2e..ecf6d32 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -10,7 +10,7 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) wher
         gradient_ad = AutoZygote()
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x), x, Δᵢ)
-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
     return J, ∇batched_jacobian
@@ -28,13 +28,13 @@
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x), x, Δᵢ)
-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
 
         _map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             (x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_),
             x, Δᵢ, p)
-        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(_eachrow(Δ))), size(p))
+        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p))
 
         return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
diff --git a/src/operator.jl b/src/operator.jl
index 62e5c32..4895fd9 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -6,6 +6,10 @@ function UniformBlockDiagonalOperator(X::AbstractMatrix)
     return UniformBlockDiagonalOperator(reshape(X, size(X, 1), size(X, 2), 1))
 end
 
+# SciMLOperators Interface
+## Even though it is easily convertible, it is helpful to get warnings
+SciMLOperators.isconvertible(::UniformBlockDiagonalOperator) = false
+
 # BatchedRoutines API
 getdata(op::UniformBlockDiagonalOperator) = op.data
 nbatches(op::UniformBlockDiagonalOperator) = size(op.data, 3)
@@ -172,6 +176,25 @@ end
     return UniformBlockDiagonalOperator(copy(getdata(op)))
 end
 
+## Define some of the common operations like `sum` directly since SciMLOperators doesn't
+## provide a very nice implementation of them
+@inline function Base.sum(op::UniformBlockDiagonalOperator; kwargs...)
+    return sum(identity, op; kwargs...)
+end
+
+@inline function Base.sum(f::F, op::UniformBlockDiagonalOperator; dims=Colon()) where {F}
+    return mapreduce(f, +, op; dims)
+end
+
+## Common Operations
+function Base.:+(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) + getdata(op2))
+end
+
+function Base.:-(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) - getdata(op2))
+end
+
 # Adapt
 @inline function Adapt.adapt_structure(to, op::UniformBlockDiagonalOperator)
     return UniformBlockDiagonalOperator(Adapt.adapt(to, getdata(op)))
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
index 4517c8e..c9a0ec0 100644
--- a/test/integration_tests.jl
+++ b/test/integration_tests.jl
@@ -72,7 +72,7 @@ end
     loss_function = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, nothing, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x, ps)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function(model, x, target_jac, ps, st) isa Number
@@ -94,7 +94,7 @@ end
     loss_function2 = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, ps, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function2(model, x, target_jac, ps, st) isa Number
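
For context, a rough usage sketch (not part of the patch) of what the `Base.sum`, `+`, and `-` methods added in src/operator.jl enable. The array sizes are made up, and `sum(abs2, ...)` assumes the existing `mapreduce` support for `UniformBlockDiagonalOperator` elsewhere in src/operator.jl:

```julia
using BatchedRoutines: UniformBlockDiagonalOperator

# Three batches of 2x2 blocks each (sizes are illustrative)
op1 = UniformBlockDiagonalOperator(rand(2, 2, 3))
op2 = UniformBlockDiagonalOperator(rand(2, 2, 3))

# `+` / `-` act batch-wise on the underlying 3D data and return operators
diff_op = op1 - op2

# `sum` is defined directly on the operator, so reductions work without
# converting to an AbstractArray first
loss = sum(abs2, diff_op)
```

This mirrors the change in test/integration_tests.jl, where the broadcasted `.-` is replaced by the operator-level `-`.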