Skip to content

Commit

Permalink
Merge pull request #349 from biaslab/dev-fix-348
Browse files Browse the repository at this point in the history
Fix ambiguity error for * between Diagonal and StandardBasisVector
  • Loading branch information
bvdmitri authored Sep 11, 2023
2 parents a8e6f94 + 135124d commit 3ad37f7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 40 deletions.
36 changes: 23 additions & 13 deletions src/helpers/algebra/standard_basis_vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,16 @@ function LinearAlgebra.dot(e1::StandardBasisVector{T1}, e2::StandardBasisVector{
return ifelse(getind(e1) === getind(e2), convert(T, e1.scale * e2.scale), zero(T))::T
end

function LinearAlgebra.dot(e1::StandardBasisVector, A::AbstractMatrix, e2::StandardBasisVector)
@inline function __dot3_basis_vector_mat(e1, A, e2)
@assert size(A) == (length(e1), length(e2))
return e1.scale * A[getind(e1), getind(e2)] * e2.scale
end

# Julia does not understand union here and throws an ambiguity error
LinearAlgebra.dot(e1::StandardBasisVector, A::AbstractMatrix, e2::StandardBasisVector) = __dot3_basis_vector_mat(e1, A, e2)
LinearAlgebra.dot(e1::StandardBasisVector, A::Diagonal, e2::StandardBasisVector) = __dot3_basis_vector_mat(e1, A, e2)
LinearAlgebra.dot(e1::StandardBasisVector, A::Adjoint{T, <:AbstractMatrix{T}}, e2::StandardBasisVector) where {T} = __dot3_basis_vector_mat(e1, A, e2)

# vector - vector
function Base.:*(v::AbstractVector{T1}, a::Adjoint{T2, StandardBasisVector{T2}}) where {T1 <: Real, T2 <: Real}
parent = a'
Expand Down Expand Up @@ -163,27 +168,26 @@ function Base.:*(v::Adjoint{T1, <:AbstractVector{T1}}, e::StandardBasisVector{T2
end

# vector matrix
function Base.:*(A::AbstractMatrix, e::StandardBasisVector)
@assert size(A, 2) === length(e)
v = A[:, getind(e)]
v = mul_inplace!(e.scale, v)
return v
end

function Base.:*(A::Adjoint{T, <:AbstractMatrix{T}}, e::StandardBasisVector) where {T <: Real}
@inline function __mul_mat_basis_vector(A, e)
@assert size(A, 2) === length(e)
v = A[:, getind(e)]
v = mul_inplace!(e.scale, v)
return v
end

function Base.:*(A::AbstractMatrix{T1}, a::Adjoint{T2, StandardBasisVector{T2}}) where {T2 <: Real, T1 <: Real}
# Julia does not understand `Union` here and throws an ambiguity error
Base.:*(A::AbstractMatrix, e::StandardBasisVector) = __mul_mat_basis_vector(A, e)
Base.:*(A::Diagonal, e::StandardBasisVector) = __mul_mat_basis_vector(A, e)
Base.:*(A::Adjoint{T, <:AbstractMatrix{T}}, e::StandardBasisVector) where {T <: Real} = __mul_mat_basis_vector(A, e)

@inline function __mul_mat_adjoint_basis_vector(A, e)
sA = size(A)
@assert sA[2] === 1
p = a'
p = e'
N = length(p)
I = getind(p)
T = promote_type(T1, T2)
T = promote_type(eltype(A), eltype(e))
s = p.scale
result = zeros(T, sA[1], N)
@inbounds @simd for k in 1:sA[1]
Expand All @@ -192,12 +196,15 @@ function Base.:*(A::AbstractMatrix{T1}, a::Adjoint{T2, StandardBasisVector{T2}})
return result
end

function Base.:*(e::StandardBasisVector{T1}, A::AbstractMatrix{T2}) where {T1 <: Real, T2 <: Real}
Base.:*(A::AbstractMatrix, e::Adjoint{T2, StandardBasisVector{T2}}) where {T2} = __mul_mat_adjoint_basis_vector(A, e)
Base.:*(A::Diagonal, e::Adjoint{T2, StandardBasisVector{T2}}) where {T2} = __mul_mat_adjoint_basis_vector(A, e)

@inline function __mul_basis_vector_mat(e, A)
sA = size(A)
@assert sA[1] === 1
N = length(e)
I = getind(e)
T = promote_type(T1, T2)
T = promote_type(eltype(e), eltype(A))
s = e.scale
result = zeros(T, N, sA[2])
@inbounds @simd for k in 1:sA[2]
Expand All @@ -206,6 +213,9 @@ function Base.:*(e::StandardBasisVector{T1}, A::AbstractMatrix{T2}) where {T1 <:
return result
end

Base.:*(e::StandardBasisVector, A::AbstractMatrix) = __mul_basis_vector_mat(e, A)
Base.:*(e::StandardBasisVector, A::Diagonal) = __mul_basis_vector_mat(e, A)

function Base.:*(e::StandardBasisVector, A::Adjoint{T, <:AbstractMatrix{T}}) where {T <: Real}
@assert size(A, 2) === length(e)
v = A[:, getind(e)]
Expand Down
56 changes: 29 additions & 27 deletions test/algebra/test_standard_basis_vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,35 @@ using LinearAlgebra
@test e * m == e_c * m
@test e' * m == e_c' * m

@test (A * e) == (A * e_c)
@test (A' * e) == (A' * e_c)
@test (e * e') == (e_c * e_c')
@test (e' * e) == (e_c' * e_c)
@test (v' * e) == (v' * e_c)
@test (e' * v) == (e_c' * v)
@test (e' * v) == (e_c' * v)
@test (a' * e) == (a' * e_c)
@test (a * e') == (a * e_c')

t = rand(rng, T)

@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)

@test dot(e, A, e) === dot(e_c, A, e_c)
@test dot(e, e) === dot(e_c, e_c)
@test dot(e, e_c) === dot(e_c, e_c)
@test dot(e_c, e) === dot(e_c, e_c)
@test dot(v, e) === dot(v, e_c)
@test dot(e, v) === dot(e_c, v)
@test dot(v, e') === dot(v, e_c')
@test dot(e', v) === dot(e_c', v)
@test dot(v', e) === dot(v', e_c)
@test dot(e, v') === dot(e_c, v')
@test dot(v', e') === dot(v', e_c')
@test dot(e', v') === dot(e_c', v')
for A in (A, Diagonal(diag(A)), A')
@test (A * e) == (A * e_c)
@test (A' * e) == (A' * e_c)
@test (e * e') == (e_c * e_c')
@test (e' * e) == (e_c' * e_c)
@test (v' * e) == (v' * e_c)
@test (e' * v) == (e_c' * v)
@test (e' * v) == (e_c' * v)
@test (a' * e) == (a' * e_c)
@test (a * e') == (a * e_c')

t = rand(rng, T)

@test ReactiveMP.v_a_vT(e, t) ReactiveMP.v_a_vT(e_c, t)
@test ReactiveMP.v_a_vT(e, t, e) ReactiveMP.v_a_vT(e_c, t, e_c)

@test dot(e, A, e) === dot(e_c, A, e_c)
@test dot(e, e) === dot(e_c, e_c)
@test dot(e, e_c) === dot(e_c, e_c)
@test dot(e_c, e) === dot(e_c, e_c)
@test dot(v, e) === dot(v, e_c)
@test dot(e, v) === dot(e_c, v)
@test dot(v, e') === dot(v, e_c')
@test dot(e', v) === dot(e_c', v)
@test dot(v', e) === dot(v', e_c)
@test dot(e, v') === dot(e_c, v')
@test dot(v', e') === dot(v', e_c')
@test dot(e', v') === dot(e_c', v')
end
end
end
end
Expand Down

2 comments on commit 3ad37f7

@bvdmitri
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91222

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.9.3 -m "<description of version>" 3ad37f768f4329098e47988cb2ee2b26bbfa75f3
git push origin v3.9.3

Please sign in to comment.