Skip to content

Commit

Permalink
Aggressive constprop in istriu/istril for structured matrices (#54437)
Browse files Browse the repository at this point in the history
This makes the following evaluate at compile-time:
```julia
julia> U = UpperTriangular(rand(2,2));

julia> @code_typed istriu(U)
CodeInfo(
1 ─     return true
) => Bool
```
Also, this reduces latency in this operation:
```julia
julia> @time (U -> istriu(U))(U)
  0.069995 seconds (158.88 k allocations: 8.715 MiB, 83.72% compilation time) # nightly
  0.035610 seconds (156.62 k allocations: 8.594 MiB, 68.18% compilation time) # This PR
```

Similar methods are annotated for other structured matrix types, where
the results may be trivially obtained from the structure for certain
values of the band index `k`.
  • Loading branch information
jishnub authored May 12, 2024
1 parent d01d256 commit 25c8128
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ end

iszero(M::Bidiagonal) = iszero(M.dv) && iszero(M.ev)
isone(M::Bidiagonal) = all(isone, M.dv) && iszero(M.ev)
function istriu(M::Bidiagonal, k::Integer=0)
Base.@constprop :aggressive function istriu(M::Bidiagonal, k::Integer=0)
if M.uplo == 'U'
if k <= 0
return true
Expand All @@ -328,7 +328,7 @@ function istriu(M::Bidiagonal, k::Integer=0)
end
end
end
function istril(M::Bidiagonal, k::Integer=0)
Base.@constprop :aggressive function istril(M::Bidiagonal, k::Integer=0)
if M.uplo == 'U'
if k >= 1
return true
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ iszero(D::Diagonal) = all(iszero, D.diag)
isone(D::Diagonal) = all(isone, D.diag)
isdiag(D::Diagonal) = all(isdiag, D.diag)
isdiag(D::Diagonal{<:Number}) = true
istriu(D::Diagonal, k::Integer=0) = k <= 0 || iszero(D.diag) ? true : false
istril(D::Diagonal, k::Integer=0) = k >= 0 || iszero(D.diag) ? true : false
Base.@constprop :aggressive istriu(D::Diagonal, k::Integer=0) = k <= 0 || iszero(D.diag) ? true : false
Base.@constprop :aggressive istril(D::Diagonal, k::Integer=0) = k >= 0 || iszero(D.diag) ? true : false
function triu!(D::Diagonal{T}, k::Integer=0) where T
n = size(D,1)
if !(-n + 1 <= k <= n + 1)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ real(H::UpperHessenberg{<:Real}) = H
real(H::UpperHessenberg{<:Complex}) = UpperHessenberg(triu!(real(H.data),-1))
imag(H::UpperHessenberg) = UpperHessenberg(triu!(imag(H.data),-1))

function istriu(A::UpperHessenberg, k::Integer=0)
Base.@constprop :aggressive function istriu(A::UpperHessenberg, k::Integer=0)
k <= -1 && return true
return _istriu(A, k)
end
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,11 @@ function Base.replace_in_print_matrix(A::Union{LowerTriangular,UnitLowerTriangul
return i >= j ? s : Base.replace_with_centered_mark(s)
end

function istril(A::Union{LowerTriangular,UnitLowerTriangular}, k::Integer=0)
Base.@constprop :aggressive function istril(A::Union{LowerTriangular,UnitLowerTriangular}, k::Integer=0)
k >= 0 && return true
return _istril(A, k)
end
function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0)
Base.@constprop :aggressive function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0)
k <= 0 && return true
return _istriu(A, k)
end
Expand Down
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ end

# tril and triu

function istriu(M::SymTridiagonal, k::Integer=0)
Base.@constprop :aggressive function istriu(M::SymTridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
Expand All @@ -341,7 +341,7 @@ function istriu(M::SymTridiagonal, k::Integer=0)
return iszero(_evview(M)) && iszero(M.dv)
end
end
istril(M::SymTridiagonal, k::Integer) = istriu(M, -k)
Base.@constprop :aggressive istril(M::SymTridiagonal, k::Integer) = istriu(M, -k)
iszero(M::SymTridiagonal) = iszero(_evview(M)) && iszero(M.dv)
isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(_evview(M))
Expand Down Expand Up @@ -718,7 +718,7 @@ end

iszero(M::Tridiagonal) = iszero(M.dl) && iszero(M.d) && iszero(M.du)
isone(M::Tridiagonal) = iszero(M.dl) && all(isone, M.d) && iszero(M.du)
function istriu(M::Tridiagonal, k::Integer=0)
Base.@constprop :aggressive function istriu(M::Tridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
Expand All @@ -729,7 +729,7 @@ function istriu(M::Tridiagonal, k::Integer=0)
return iszero(M.dl) && iszero(M.d) && iszero(M.du)
end
end
function istril(M::Tridiagonal, k::Integer=0)
Base.@constprop :aggressive function istril(M::Tridiagonal, k::Integer=0)
if k >= 1
return true
elseif k == 0
Expand Down

0 comments on commit 25c8128

Please sign in to comment.