Skip to content

Commit

Permalink
Replace Val-types by singleton types in lu and qr (JuliaLang#40623)
Browse files Browse the repository at this point in the history
Co-authored-by: Andreas Noack <andreas@noack.dk>
  • Loading branch information
2 people authored and johanmon committed Jul 5, 2021
1 parent 94c3efa commit 8559e47
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 67 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ Standard library changes
* The shape of an `UpperHessenberg` matrix is preserved under certain arithmetic operations, e.g. when multiplying or dividing by an `UpperTriangular` matrix. ([#40039])
* `cis(A)` now supports matrix arguments ([#40194]).
* `dot` now supports `UniformScaling` with `AbstractMatrix` ([#40250]).
* `qr[!]` and `lu[!]` now support `LinearAlgebra.PivotingStrategy` (singleton type) values
as their optional `pivot` argument: defaults are `qr(A, NoPivot())` (vs.
`qr(A, ColumnNorm())` for pivoting) and `lu(A, RowMaximum())` (vs. `lu(A, NoPivot())`
without pivoting); the former `Val{true/false}`-based calls are deprecated. ([#40623])
* `det(M::AbstractMatrix{BigInt})` now calls `det_bareiss(M)`, which uses the [Bareiss](https://en.wikipedia.org/wiki/Bareiss_algorithm) algorithm to calculate precise values.([#40868]).

#### Markdown
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ export
BunchKaufman,
Cholesky,
CholeskyPivoted,
ColumnNorm,
Eigen,
GeneralizedEigen,
GeneralizedSVD,
GeneralizedSchur,
Hessenberg,
LU,
LDLt,
NoPivot,
QR,
QRPivoted,
LQ,
Schur,
SVD,
Hermitian,
RowMaximum,
Symmetric,
LowerTriangular,
UpperTriangular,
Expand Down Expand Up @@ -164,6 +167,10 @@ abstract type Algorithm end
struct DivideAndConquer <: Algorithm end
struct QRIteration <: Algorithm end

abstract type PivotingStrategy end
struct NoPivot <: PivotingStrategy end
struct RowMaximum <: PivotingStrategy end
struct ColumnNorm <: PivotingStrategy end

# Check that stride of matrix/vector is 1
# Writing like this to avoid splatting penalty when called with multiple arguments,
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ function factorize(A::StridedMatrix{T}) where T
end
return lu(A)
end
qr(A, Val(true))
qr(A, ColumnNorm())
end
factorize(A::Adjoint) = adjoint(factorize(parent(A)))
factorize(A::Transpose) = transpose(factorize(parent(A)))
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ size(F::Adjoint{<:Any,<:Factorization}) = reverse(size(parent(F)))
size(F::Transpose{<:Any,<:Factorization}) = reverse(size(parent(F)))

checkpositivedefinite(info) = info == 0 || throw(PosDefException(info))
checknonsingular(info, pivoted::Val{true}) = info == 0 || throw(SingularException(info))
checknonsingular(info, pivoted::Val{false}) = info == 0 || throw(ZeroPivotException(info))
checknonsingular(info) = checknonsingular(info, Val{true}())
checknonsingular(info, ::RowMaximum) = info == 0 || throw(SingularException(info))
checknonsingular(info, ::NoPivot) = info == 0 || throw(ZeroPivotException(info))
checknonsingular(info) = checknonsingular(info, RowMaximum())

"""
issuccess(F::Factorization)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
end
return lu(A) \ B
end
return qr(A,Val(true)) \ B
return qr(A, ColumnNorm()) \ B
end

(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b
Expand Down
52 changes: 32 additions & 20 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,26 @@ adjoint(F::LU) = Adjoint(F)
transpose(F::LU) = Transpose(F)

# StridedMatrix
function lu!(A::StridedMatrix{T}, pivot::Union{Val{false}, Val{true}} = Val(true);
check::Bool = true) where T<:BlasFloat
if pivot === Val(false)
return generic_lufact!(A, pivot; check = check)
end
lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMaximum(); check=check)
function lu!(A::StridedMatrix{T}, ::RowMaximum; check::Bool = true) where {T<:BlasFloat}
lpt = LAPACK.getrf!(A)
check && checknonsingular(lpt[3])
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
end
function lu!(A::HermOrSym, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true)
function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
return generic_lufact!(A, pivot; check = check)
end
function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
copytri!(A.data, A.uplo, isa(A, Hermitian))
lu!(A.data, pivot; check = check)
end
# for backward compatibility
# TODO: remove towards Julia v2
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{true}; check::Bool = true) lu!(A, RowMaximum(); check=check)
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{false}; check::Bool = true) lu!(A, NoPivot(); check=check)

"""
lu!(A, pivot=Val(true); check = true) -> LU
lu!(A, pivot = RowMaximum(); check = true) -> LU
`lu!` is the same as [`lu`](@ref), but saves space by overwriting the
input `A`, instead of creating a copy. An [`InexactError`](@ref)
Expand Down Expand Up @@ -127,19 +131,22 @@ Stacktrace:
[...]
```
"""
lu!(A::StridedMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true) =
lu!(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
generic_lufact!(A, pivot; check = check)
function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
check::Bool = true) where {T,Pivot}
function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum();
check::Bool = true) where {T}
# Extract values
m, n = size(A)
minmn = min(m,n)

# Initialize variables
info = 0
ipiv = Vector{BlasInt}(undef, minmn)
@inbounds begin
for k = 1:minmn
# find index max
kp = k
if Pivot && k < m
if pivot === RowMaximum() && k < m
amax = abs(A[k, k])
for i = k+1:m
absi = abs(A[i,k])
Expand Down Expand Up @@ -175,7 +182,7 @@ function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
end
end
end
check && checknonsingular(info, Val{Pivot}())
check && checknonsingular(info, pivot)
return LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end

Expand All @@ -200,7 +207,7 @@ end

# for all other types we must promote to a type which is stable under division
"""
lu(A, pivot=Val(true); check = true) -> F::LU
lu(A, pivot = RowMaximum(); check = true) -> F::LU
Compute the LU factorization of `A`.
Expand All @@ -211,7 +218,7 @@ validity (via [`issuccess`](@ref)) lies with the user.
In most cases, if `A` is a subtype `S` of `AbstractMatrix{T}` with an element
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`. If
pivoting is chosen (default) the element type should also support [`abs`](@ref) and
[`<`](@ref).
[`<`](@ref). Pivoting can be turned off by passing `pivot = NoPivot()`.
The individual components of the factorization `F` can be accessed via [`getproperty`](@ref):
Expand Down Expand Up @@ -267,11 +274,14 @@ julia> l == F.L && u == F.U && p == F.p
true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}=Val(true);
check::Bool = true) where T
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
S = lutype(T)
lu!(copy_oftype(A, S), pivot; check = check)
end
# TODO: remove for Julia v2.0
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMaximum(); check=check)
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)


lu(S::LU) = S
function lu(x::Number; check::Bool=true)
Expand Down Expand Up @@ -481,9 +491,11 @@ inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))
# Tridiagonal

# See dgttrf.f
function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true);
check::Bool = true) where {T,V}
function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T,V}
# Extract values
n = size(A, 1)

# Initialize variables
info = 0
ipiv = Vector{BlasInt}(undef, n)
dl = A.dl
Expand All @@ -500,7 +512,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
end
for i = 1:n-2
# pivot or not?
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
# No interchange
if d[i] != 0
fact = dl[i]/d[i]
Expand All @@ -523,7 +535,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
end
if n > 1
i = n-1
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
if d[i] != 0
fact = dl[i]/d[i]
dl[i] = fact
Expand Down
29 changes: 18 additions & 11 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,17 +246,17 @@ function qrfactPivotedUnblocked!(A::AbstractMatrix)
end

# LAPACK version
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{false} = Val(false); blocksize=36) =
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; blocksize=36) =
QRCompactWY(LAPACK.geqrt!(A, min(min(size(A)...), blocksize))...)
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{true}) = QRPivoted(LAPACK.geqp3!(A)...)
qr!(A::StridedMatrix{<:BlasFloat}, ::ColumnNorm) = QRPivoted(LAPACK.geqp3!(A)...)

# Generic fallbacks

"""
qr!(A, pivot=Val(false); blocksize)
qr!(A, pivot = NoPivot(); blocksize)
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of
[`StridedMatrix`](@ref), but saves space by overwriting the input `A`, instead of creating a copy.
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of [`StridedMatrix`](@ref),
but saves space by overwriting the input `A`, instead of creating a copy.
An [`InexactError`](@ref) exception is thrown if the factorization produces a number not
representable by the element type of `A`, e.g. for integer types.
Expand Down Expand Up @@ -292,14 +292,17 @@ Stacktrace:
[...]
```
"""
qr!(A::AbstractMatrix, ::Val{false}) = qrfactUnblocked!(A)
qr!(A::AbstractMatrix, ::Val{true}) = qrfactPivotedUnblocked!(A)
qr!(A::AbstractMatrix) = qr!(A, Val(false))
qr!(A::AbstractMatrix, ::NoPivot) = qrfactUnblocked!(A)
qr!(A::AbstractMatrix, ::ColumnNorm) = qrfactPivotedUnblocked!(A)
qr!(A::AbstractMatrix) = qr!(A, NoPivot())
# TODO: Remove in Julia v2.0
@deprecate qr!(A::AbstractMatrix, ::Val{true}) qr!(A, ColumnNorm())
@deprecate qr!(A::AbstractMatrix, ::Val{false}) qr!(A, NoPivot())

_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))

"""
qr(A, pivot=Val(false); blocksize) -> F
qr(A, pivot = NoPivot(); blocksize) -> F
Compute the QR factorization of the matrix `A`: an orthogonal (or unitary if `A` is
complex-valued) matrix `Q`, and an upper triangular matrix `R` such that
Expand All @@ -310,7 +313,7 @@ A = Q R
The returned object `F` stores the factorization in a packed format:
- if `pivot == Val(true)` then `F` is a [`QRPivoted`](@ref) object,
- if `pivot == ColumnNorm()` then `F` is a [`QRPivoted`](@ref) object,
- otherwise if the element type of `A` is a BLAS type ([`Float32`](@ref), [`Float64`](@ref),
`ComplexF32` or `ComplexF64`), then `F` is a [`QRCompactWY`](@ref) object,
Expand Down Expand Up @@ -340,7 +343,7 @@ and `F.Q*A` are supported. A `Q` matrix can be converted into a regular matrix w
orthogonal matrix.
The block size for QR decomposition can be specified by keyword argument
`blocksize :: Integer` when `pivot == Val(false)` and `A isa StridedMatrix{<:BlasFloat}`.
`blocksize :: Integer` when `pivot == NoPivot()` and `A isa StridedMatrix{<:BlasFloat}`.
It is ignored when `blocksize > minimum(size(A))`. See [`QRCompactWY`](@ref).
!!! compat "Julia 1.4"
Expand Down Expand Up @@ -382,6 +385,10 @@ function qr(A::AbstractMatrix{T}, arg...; kwargs...) where T
copyto!(AA, A)
return qr!(AA, arg...; kwargs...)
end
# TODO: remove in Julia v2.0
@deprecate qr(A::AbstractMatrix, ::Val{false}; kwargs...) qr(A, NoPivot(); kwargs...)
@deprecate qr(A::AbstractMatrix, ::Val{true}; kwargs...) qr(A, ColumnNorm(); kwargs...)

qr(x::Number) = qr(fill(x,1,1))
function qr(v::AbstractVector)
require_one_based_indexing(v)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ end
D = Diagonal(randn(5))
Q = qr(randn(5, 5)).Q
@test D * Q' == Array(D) * Q'
Q = qr(randn(5, 5), Val(true)).Q
Q = qr(randn(5, 5), ColumnNorm()).Q
@test_throws ArgumentError lmul!(Q, D)
end

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,13 @@ LinearAlgebra.Transpose(a::ModInt{n}) where {n} = transpose(a)
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
b = [ModInt{2}(1), ModInt{2}(0)]

@test A*(lu(A, Val(false))\b) == b
@test A*(lu(A, NoPivot())\b) == b

# Needed for pivoting:
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k

@test A*(lu(A, Val(true))\b) == b
@test A*(lu(A, RowMaximum())\b) == b
end

@testset "Issue 18742" begin
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/test/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ rectangularQ(Q::LinearAlgebra.LQPackedQ) = convert(Array, Q)
lqa = lq(a)
x = lqa\b
l,q = lqa.L, lqa.Q
qra = qr(a, Val(true))
qra = qr(a, ColumnNorm())
@testset "Basic ops" begin
@test size(lqa,1) == size(a,1)
@test size(lqa,3) == 1
Expand Down
24 changes: 12 additions & 12 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dimg = randn(n)/2
lua = factorize(a)
@test_throws ErrorException lua.Z
l,u,p = lua.L, lua.U, lua.p
ll,ul,pl = lu(a)
ll,ul,pl = @inferred lu(a)
@test ll * ul a[pl,:]
@test l*u a[p,:]
@test (l*u)[invperm(p),:] a
Expand All @@ -85,9 +85,9 @@ dimg = randn(n)/2
end
κd = cond(Array(d),1)
@testset "Tridiagonal LU" begin
lud = lu(d)
lud = @inferred lu(d)
@test LinearAlgebra.issuccess(lud)
@test lu(lud) == lud
@test @inferred(lu(lud)) == lud
@test_throws ErrorException lud.Z
@test lud.L*lud.U lud.P*Array(d)
@test lud.L*lud.U Array(d)[lud.p,:]
Expand Down Expand Up @@ -199,14 +199,14 @@ dimg = randn(n)/2
@test lua.L*lua.U lua.P*a[:,1:n1]
end
@testset "Fat LU" begin
lua = lu(a[1:n1,:])
lua = @inferred lu(a[1:n1,:])
@test lua.L*lua.U lua.P*a[1:n1,:]
end
end

@testset "LU of Symmetric/Hermitian" begin
for HS in (Hermitian(a'a), Symmetric(a'a))
luhs = lu(HS)
luhs = @inferred lu(HS)
@test luhs.L*luhs.U luhs.P*Matrix(HS)
end
end
Expand All @@ -229,12 +229,12 @@ end
@test_throws SingularException lu!(copy(A); check = true)
@test !issuccess(lu(A; check = false))
@test !issuccess(lu!(copy(A); check = false))
@test_throws ZeroPivotException lu(A, Val(false))
@test_throws ZeroPivotException lu!(copy(A), Val(false))
@test_throws ZeroPivotException lu(A, Val(false); check = true)
@test_throws ZeroPivotException lu!(copy(A), Val(false); check = true)
@test !issuccess(lu(A, Val(false); check = false))
@test !issuccess(lu!(copy(A), Val(false); check = false))
@test_throws ZeroPivotException lu(A, NoPivot())
@test_throws ZeroPivotException lu!(copy(A), NoPivot())
@test_throws ZeroPivotException lu(A, NoPivot(); check = true)
@test_throws ZeroPivotException lu!(copy(A), NoPivot(); check = true)
@test !issuccess(lu(A, NoPivot(); check = false))
@test !issuccess(lu!(copy(A), NoPivot(); check = false))
F = lu(A; check = false)
@test sprint((io, x) -> show(io, "text/plain", x), F) ==
"Failed factorization of type $(typeof(F))"
Expand Down Expand Up @@ -320,7 +320,7 @@ include("trickyarithmetic.jl")
@testset "lu with type whose sum is another type" begin
A = TrickyArithmetic.A[1 2; 3 4]
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
B = lu(A, Val(false))
B = lu(A, NoPivot())
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
end

Expand Down
Loading

0 comments on commit 8559e47

Please sign in to comment.