Skip to content

Commit

Permalink
Make matrix multiplication work for more types
Browse files Browse the repository at this point in the history
Currently it is assumed that the type of a sum of x::T and y::T
is T but this may not be the case
  • Loading branch information
blegat committed Aug 25, 2016
1 parent 52346ec commit 05e2791
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
30 changes: 16 additions & 14 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
arithtype(T) = T
arithtype(::Type{Bool}) = Int

matprod(x, y) = x*y + x*y

# multiply by diagonal matrix as vector
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
m, n = size(A)
Expand Down Expand Up @@ -76,11 +78,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(

# Matrix-vector multiplication
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
Expand All @@ -99,22 +101,22 @@ end
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Expand All @@ -125,7 +127,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
# Matrix-matrix multiplication

function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
Expand All @@ -142,14 +144,14 @@ end
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
Expand All @@ -166,7 +168,7 @@ end
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
Expand All @@ -175,7 +177,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Expand All @@ -184,14 +186,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, arithtype(T), arithtype(S))
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!(similar(B, promote_op(matprod, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
Expand Down Expand Up @@ -424,7 +426,7 @@ end
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
C = similar(B, promote_op(matprod, arithtype(T), arithtype(S)), mA, nB)
generic_matmatmul!(C, tA, tB, A, B)
end

Expand Down
28 changes: 28 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,31 @@ let
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end

# #18218
import Base.*, Base.+, Base.zero
let
immutable TypeA
x::Int
end
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
immutable TypeB
x::Int
end
Base.convert(::Type{TypeB}, x::Int) = TypeB(x)
immutable TypeC
x::Int
end
immutable TypeD
x::Int
end
zero(d::TypeD) = TypeD(0)
zero(::Type{TypeD}) = TypeD(0)
(*)(a::Union{TypeA,TypeB}, b::Union{TypeA,TypeB}) = TypeC(a.x*b.x)
(+)(a::Union{TypeC,TypeD}, b::Union{TypeC,TypeD}) = TypeD(a.x+b.x)
A = TypeA[1 2; 3 4]
b = TypeB[1, 2]
d = A * b
@show typeof(d) == Vector{TypeD}
@show d == TypeD[5, 11]
end

0 comments on commit 05e2791

Please sign in to comment.