From 4facdbb1024277045815abde200947317bd9350d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Mar 2024 20:19:12 -0400 Subject: [PATCH] Extend inplace mul --- src/api.jl | 2 +- src/matrix.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/api.jl b/src/api.jl index 1c12309..348c413 100644 --- a/src/api.jl +++ b/src/api.jl @@ -55,7 +55,7 @@ TODO: Needs Documentation (take from NNlib.jl) """ batched_mul(A, B) = _batched_mul(A, B) -batched_mul!(C, A, B) = _batched_mul!(C, A, B) +batched_mul!(C, A, B, α = true, β = false) = _batched_mul!(C, A, B, α, β) """ batched_transpose(X::AbstractArray{T, 3}) where {T} diff --git a/src/matrix.jl b/src/matrix.jl index 4f7cc1e..c7e4ee2 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -282,6 +282,14 @@ function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T return dropdims(C.data; dims=1) end +function LinearAlgebra.mul!(A::AbstractMatrix, B::AbstractMatrix, + C::UniformBlockDiagonalMatrix, α::Number=true, β::Number=false) + A_ = reshape(A, 1, :, nbatches(A)) + B_ = reshape(B, 1, :, nbatches(B)) + batched_mul!(A_, B_, C.data, α, β) + return A +end + # LinearAlgebra abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end