Skip to content

Commit

Permalink
Fix promotion and adjoint for complex expressions (#3150)
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 15, 2022
1 parent 73d05bf commit d986530
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 22 deletions.
5 changes: 3 additions & 2 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1055,9 +1055,10 @@ Base.ndims(::AbstractJuMPScalar) = 0

# These are required to create symmetric containers of AbstractJuMPScalars.
LinearAlgebra.symmetric_type(::Type{T}) where {T<:AbstractJuMPScalar} = T
LinearAlgebra.hermitian_type(::Type{T}) where {T<:AbstractJuMPScalar} = T
LinearAlgebra.symmetric(scalar::AbstractJuMPScalar, ::Symbol) = scalar
# This is required for linear algebra operations involving transposes.
LinearAlgebra.adjoint(scalar::AbstractJuMPScalar) = scalar
LinearAlgebra.hermitian(scalar::AbstractJuMPScalar, ::Symbol) = adjoint(scalar)
LinearAlgebra.adjoint(scalar::AbstractJuMPScalar) = conj(scalar)

"""
owner_model(s::AbstractJuMPScalar)
Expand Down
23 changes: 14 additions & 9 deletions src/aff_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,7 @@ coefficient(::GenericAffExpr{C,V}, ::V, ::V) where {C,V} = zero(C)
Remove terms in the affine expression with `0` coefficients.
"""
function drop_zeros!(expr::GenericAffExpr)
for (key, coef) in expr.terms
if iszero(coef)
delete!(expr.terms, key)
end
end
_drop_zeros!(expr.terms)
return
end

Expand Down Expand Up @@ -496,16 +492,25 @@ end

Base.hash(aff::GenericAffExpr, h::UInt) = hash(aff.constant, hash(aff.terms, h))

function SparseArrays.dropzeros(aff::GenericAffExpr)
result = copy(aff)
for (coef, var) in linear_terms(aff)
function _drop_zeros!(terms::OrderedDict)
for (var, coef) in terms
if iszero(coef)
delete!(result.terms, var)
delete!(terms, var)
elseif coef isa Complex && iszero(imag(coef))
terms[var] = real(coef)
end
end
return
end

function SparseArrays.dropzeros(aff::GenericAffExpr)
result = copy(aff)
_drop_zeros!(result.terms)
if iszero(result.constant)
# This is to work around isequal(0.0, -0.0) == false.
result.constant = zero(typeof(result.constant))
elseif result.constant isa Complex && iszero(imag(result.constant))
result.constant = real(result.constant)
end
return result
end
Expand Down
8 changes: 7 additions & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,16 @@ function Base.promote_rule(
end
function Base.promote_rule(
::Type{GenericAffExpr{S,V}},
R::Type{<:Real},
R::Type{<:Number},
) where {S,V}
return GenericAffExpr{promote_type(S, R),V}
end
function Base.promote_rule(
::Type{<:GenericAffExpr{S,V}},
::Type{<:GenericAffExpr{T,V}},
) where {S,T,V}
return GenericAffExpr{promote_type(S, T),V}
end
function Base.promote_rule(
::Type{<:GenericAffExpr{S,V}},
::Type{<:GenericQuadExpr{T,V}},
Expand Down
12 changes: 2 additions & 10 deletions src/quad_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ Remove terms in the quadratic expression with `0` coefficients.
"""
function drop_zeros!(expr::GenericQuadExpr)
drop_zeros!(expr.aff)
for (key, coef) in expr.terms
if iszero(coef)
delete!(expr.terms, key)
end
end
_drop_zeros!(expr.terms)
return
end

Expand Down Expand Up @@ -497,11 +493,7 @@ Base.hash(quad::GenericQuadExpr, h::UInt) = hash(quad.aff, hash(quad.terms, h))

function SparseArrays.dropzeros(quad::GenericQuadExpr)
quad_terms = copy(quad.terms)
for (key, value) in quad.terms
if iszero(value)
delete!(quad_terms, key)
end
end
_drop_zeros!(quad_terms)
return GenericQuadExpr(dropzeros(quad.aff), quad_terms)
end

Expand Down
43 changes: 43 additions & 0 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ module TestComplexNumberSupport

using JuMP
using Test

import LinearAlgebra
import MutableArithmetics
import SparseArrays

const MA = MutableArithmetics

function runtests()
Expand Down Expand Up @@ -201,6 +205,45 @@ function test_complex_abs2()
@test abs2(x * im + 2) == x^2 + 4
end

function test_hermitian()
model = Model()
@variable(model, x)
A = [3 1im; -1im 2x]
@test A isa Matrix{GenericAffExpr{ComplexF64,VariableRef}}
A = [3x^2 1im; -1im 2x]
@test A isa Matrix{GenericQuadExpr{ComplexF64,VariableRef}}
A = [3x 1im; -1im 2x^2]
@test A isa Matrix{GenericQuadExpr{ComplexF64,VariableRef}}
A = [3x 1im; -1im 2x]
@test A isa Matrix{GenericAffExpr{ComplexF64,VariableRef}}
@test isequal_canonical(A', A)
H = LinearAlgebra.Hermitian(A)
T = GenericAffExpr{ComplexF64,VariableRef}
@test H isa LinearAlgebra.Hermitian{T,Matrix{T}}
@test isequal_canonical(A[1, 2], LinearAlgebra.adjoint(A[2, 1]))
@test isequal_canonical(H[1, 2], LinearAlgebra.adjoint(H[2, 1]))
for i in 1:2, j in 1:2
@test isequal_canonical(A[i, j], H[i, j])
end
return
end

function test_complex_sparse_arrays_dropzeros()
model = Model()
@variable(model, x)
a = 2.0 + 1.0im
for rhs in (0.0 + 0.0im, 0.0 - 0.0im, -0.0 + 0.0im, -0.0 + -0.0im)
# We need to explicitly set the .constant field to avoid a conversion to
# 0.0 + 0.0im
expr = a * x
expr.constant = rhs
@test isequal(SparseArrays.dropzeros(expr), a * x)
expr.constant = 1.0 + rhs
@test isequal(SparseArrays.dropzeros(expr), a * x + 1.0)
end
return
end

end

TestComplexNumberSupport.runtests()

0 comments on commit d986530

Please sign in to comment.