From 68b226acb6c7604fa30e8bc946230c6f3f0c0865 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 29 Mar 2024 11:29:58 -0400 Subject: [PATCH] Add rules for common sum --- src/chainrules.jl | 29 +++++++++++++++++++++++++++++ src/operator.jl | 1 + 2 files changed, 30 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index ecf6d32..15ac2c6 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -116,3 +116,32 @@ function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::S ∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalOperator(Δ)) return op.data, ∇getproperty end + +# mapreduce fallback rules for UniformBlockDiagonalOperator +@inline _unsum(x, dy, dims) = broadcast(last ∘ tuple, x, dy) +@inline _unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy)) + +function CRC.rrule(::typeof(sum), ::typeof(abs2), op::UniformBlockDiagonalOperator{T}; + dims=:) where {T <: Union{Real, Complex}} + y = sum(abs2, op; dims) + ∇sum_abs2 = @closure Δ -> begin + ∂op = if dims isa Colon + UniformBlockDiagonalOperator(2 .* real.(Δ) .* getdata(op)) + else + UniformBlockDiagonalOperator(2 .* real.(getdata(Δ)) .* getdata(op)) + end + return NoTangent(), NoTangent(), ∂op + end + return y, ∇sum_abs2 +end + +function CRC.rrule(::typeof(sum), ::typeof(identity), op::UniformBlockDiagonalOperator{T}; + dims=:) where {T <: Union{Real, Complex}} + y = sum(abs2, op; dims) + project = CRC.ProjectTo(getdata(op)) + ∇sum_abs2 = @closure Δ -> begin + ∂op = project(_unsum(getdata(op), getdata(Δ), dims)) + return NoTangent(), NoTangent(), UniformBlockDiagonalOperator(∂op) + end + return y, ∇sum_abs2 +end diff --git a/src/operator.jl b/src/operator.jl index 4895fd9..b632045 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -12,6 +12,7 @@ SciMLOperators.isconvertible(::UniformBlockDiagonalOperator) = false # BatchedRoutines API getdata(op::UniformBlockDiagonalOperator) = op.data +getdata(x) = x nbatches(op::UniformBlockDiagonalOperator) = size(op.data, 3) batchview(op::UniformBlockDiagonalOperator) = batchview(op.data) batchview(op::UniformBlockDiagonalOperator, i::Int) = batchview(op.data, i)