Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Extend inplace mul
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 21, 2024
1 parent 8037aa7 commit 4facdbb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TODO: Needs Documentation (take from NNlib.jl)
"""
batched_mul(A, B) = _batched_mul(A, B)

batched_mul!(C, A, B) = _batched_mul!(C, A, B)
batched_mul!(C, A, B, α = true, β = false) = _batched_mul!(C, A, B, α, β)

"""
batched_transpose(X::AbstractArray{T, 3}) where {T}
Expand Down
8 changes: 8 additions & 0 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T
return dropdims(C.data; dims=1)
end

function LinearAlgebra.mul!(A::AbstractMatrix, B::AbstractMatrix,
C::UniformBlockDiagonalMatrix, α::Number=true, β::Number=false)
A_ = reshape(A, 1, :, nbatches(A))
B_ = reshape(B, 1, :, nbatches(B))
batched_mul!(A_, B_, C.data, α, β)
return A
end

# LinearAlgebra
abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end

Expand Down

0 comments on commit 4facdbb

Please sign in to comment.