From 22621f0bce2b60e6e6f02086d6d23847fc909051 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 18:22:01 +0100 Subject: [PATCH 1/2] array multiplication --- src/rulesets/Base/arraymath.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7e0329e8b..c5640d81d 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -28,12 +28,12 @@ function rrule( return ( NoTangent(), InplaceableThunk( - @thunk(Ȳ * B'), - X̄ -> mul!(X̄, Ȳ, B', true, true) + @thunk(project(A, Ȳ * B')), + X̄ -> project(A, mul!(X̄, Ȳ, B', true, true)) ), InplaceableThunk( - @thunk(A' * Ȳ), - X̄ -> mul!(X̄, A', Ȳ, true, true) + @thunk(project(B, A' * Ȳ)), + X̄ -> project(B, mul!(X̄, A', Ȳ, true, true)) ) ) end From 98a2bdff331c6a48f5877e22e503d32fd425aa9d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 28 Jun 2021 14:00:20 +0100 Subject: [PATCH 2/2] change to preproject --- src/rulesets/Base/arraymath.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index c5640d81d..892cbddbf 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -24,18 +24,21 @@ function rrule( A::AbstractVecOrMat{<:CommutativeMulNumber}, B::AbstractVecOrMat{<:CommutativeMulNumber}, ) + Ainfo = preproject(A) + Binfo = preproject(B) + valAtype = Val(typeof(A)) + valBtype = Val(typeof(B)) + _val(::Val{T}) where T = T function times_pullback(Ȳ) - return ( - NoTangent(), - InplaceableThunk( - @thunk(project(A, Ȳ * B')), - X̄ -> project(A, mul!(X̄, Ȳ, B', true, true)) - ), - InplaceableThunk( - @thunk(project(B, A' * Ȳ)), - X̄ -> project(B, mul!(X̄, A', Ȳ, true, true)) - ) + dA = InplaceableThunk( + @thunk(project(_val(valAtype), Ȳ * B'; info=Ainfo)), + X̄ -> mul!(X̄, Ȳ, B', true, true) + ) + dB = InplaceableThunk( + @thunk(project(_val(valBtype), A' * Ȳ; info=Binfo)), + X̄ -> mul!(X̄, A', Ȳ, true, true) ) + return NoTangent(), dA, dB end return A * B, times_pullback end