Skip to content

Commit

Permalink
Remove checks in tril, triu, and diag (#28480)
Browse files Browse the repository at this point in the history
* Stop throwing for out of bounds diagonals in tril, triu and diag
defined in LinearAlgebra

* Remove checks from tril, triu and diag in SparseArrays
  • Loading branch information
andreasnoack authored and KristofferC committed Feb 11, 2019
1 parent 9228291 commit b9ad92e
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 95 deletions.
33 changes: 7 additions & 26 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,10 @@ julia> triu!(M, 1)
function triu!(M::AbstractMatrix, k::Integer)
@assert !has_offset_axes(M)
m, n = size(M)
if !(-m + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-m + 1) and at most $(n + 1) in an $m-by-$n matrix")))
end
idx = 1
for j = 0:n-1
ii = min(max(0, j+1-k), m)
for i = (idx+ii):(idx+m-1)
M[i] = zero(M[i])
for j in 1:min(n, n + k)
for i in max(1, j - k + 1):m
M[i,j] = zero(M[i,j])
end
idx += m
end
M
end
Expand Down Expand Up @@ -216,17 +209,10 @@ julia> tril!(M, 2)
function tril!(M::AbstractMatrix, k::Integer)
@assert !has_offset_axes(M)
m, n = size(M)
if !(-m - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-m - 1) and at most $(n - 1) in an $m-by-$n matrix")))
end
idx = 1
for j = 0:n-1
ii = min(max(0, j-k), m)
for i = idx:(idx+ii-1)
M[i] = zero(M[i])
for j in max(1, k + 1):n
@inbounds for i in 1:min(j - k - 1, m)
M[i,j] = zero(M[i,j])
end
idx += m
end
M
end
Expand All @@ -249,13 +235,8 @@ function fillband!(A::AbstractMatrix{T}, x, l, u) where T
return A
end

function diagind(m::Integer, n::Integer, k::Integer=0)
if !(-m <= k <= n)
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-m) and ",
"at most $n in an $m-by-$n matrix")))
end
diagind(m::Integer, n::Integer, k::Integer=0) =
k <= 0 ? range(1-k, step=m+1, length=min(m+k, n)) : range(k*m+1, step=m+1, length=min(m, n-k))
end

"""
diagind(M, k::Integer=0)
Expand Down
20 changes: 4 additions & 16 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,7 @@ istriu(A::Transpose) = istril(A.parent)

function tril!(A::UpperTriangular, k::Integer=0)
n = size(A,1)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif k < 0
if k < 0
fill!(A.data,0)
return A
elseif k == 0
Expand All @@ -262,10 +259,7 @@ triu!(A::UpperTriangular, k::Integer=0) = UpperTriangular(triu!(A.data,k))

function tril!(A::UnitUpperTriangular{T}, k::Integer=0) where T
n = size(A,1)
if !(-n - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n - 1) and at most $(n - 1) in an $n-by-$n matrix")))
elseif k < 0
if k < 0
fill!(A.data, zero(T))
return UpperTriangular(A.data)
elseif k == 0
Expand All @@ -291,10 +285,7 @@ end

function triu!(A::LowerTriangular, k::Integer=0)
n = size(A,1)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif k > 0
if k > 0
fill!(A.data,0)
return A
elseif k == 0
Expand All @@ -311,10 +302,7 @@ tril!(A::LowerTriangular, k::Integer=0) = LowerTriangular(tril!(A.data,k))

function triu!(A::UnitLowerTriangular{T}, k::Integer=0) where T
n = size(A,1)
if !(-n + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-n + 1) and at most $(n + 1) in an $n-by-$n matrix")))
elseif k > 0
if k > 0
fill!(A.data, zero(T))
return LowerTriangular(A.data)
elseif k == 0
Expand Down
26 changes: 13 additions & 13 deletions stdlib/LinearAlgebra/test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ bimg = randn(n,2)/2
end
end # for eltya

@testset "test triu/tril bounds checking" begin
@testset "test out of bounds triu/tril" begin
local m, n = 5, 7
ainit = rand(m, n)
for a in (copy(ainit), view(ainit, 1:m, 1:n))
@test_throws ArgumentError triu(a, -m)
@test_throws ArgumentError triu(a, n + 2)
@test_throws ArgumentError tril(a, -m - 2)
@test_throws ArgumentError tril(a, n)
@test triu(a, -m) == a
@test triu(a, n + 2) == zero(a)
@test tril(a, -m - 2) == zero(a)
@test tril(a, n) == a
end
end

Expand Down Expand Up @@ -651,7 +651,7 @@ end
2 6 10
3 7 11
4 8 12 ]
@test_throws ArgumentError diag(A, -5)
@test diag(A,-5) == []
@test diag(A,-4) == []
@test diag(A,-3) == [4]
@test diag(A,-2) == [3,8]
Expand All @@ -660,21 +660,21 @@ end
@test diag(A, 1) == [5,10]
@test diag(A, 2) == [9]
@test diag(A, 3) == []
@test_throws ArgumentError diag(A, 4)
@test diag(A, 4) == []

@test diag(zeros(0,0)) == []
@test_throws ArgumentError diag(zeros(0,0),1)
@test_throws ArgumentError diag(zeros(0,0),-1)
@test diag(zeros(0,0),1) == []
@test diag(zeros(0,0),-1) == []

@test diag(zeros(1,0)) == []
@test diag(zeros(1,0),-1) == []
@test_throws ArgumentError diag(zeros(1,0),1)
@test_throws ArgumentError diag(zeros(1,0),-2)
@test diag(zeros(1,0),1) == []
@test diag(zeros(1,0),-2) == []

@test diag(zeros(0,1)) == []
@test diag(zeros(0,1),1) == []
@test_throws ArgumentError diag(zeros(0,1),-1)
@test_throws ArgumentError diag(zeros(0,1),2)
@test diag(zeros(0,1),-1) == []
@test diag(zeros(0,1),2) == []
end

@testset "Matrix to real power" for elty in (Float64, Complex{Float64})
Expand Down
16 changes: 8 additions & 8 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,24 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo
@test tril(A1,0) == A1
@test tril(A1,-1) == LowerTriangular(tril(Matrix(A1), -1))
@test tril(A1,1) == t1(tril(tril(Matrix(A1), 1)))
@test_throws ArgumentError tril!(A1, -n - 2)
@test_throws ArgumentError tril!(A1, n)
@test tril(A1, -n - 2) == zeros(size(A1))
@test tril(A1, n) == A1
@test triu(A1,0) == t1(diagm(0 => diag(A1)))
@test triu(A1,-1) == t1(tril(triu(A1.data,-1)))
@test triu(A1,1) == zeros(size(A1)) # or just @test iszero(triu(A1,1))?
@test_throws ArgumentError triu!(A1, -n)
@test_throws ArgumentError triu!(A1, n + 2)
@test triu(A1, -n) == A1
@test triu(A1, n + 2) == zeros(size(A1))
else
@test triu(A1,0) == A1
@test triu(A1,1) == UpperTriangular(triu(Matrix(A1), 1))
@test triu(A1,-1) == t1(triu(triu(Matrix(A1), -1)))
@test_throws ArgumentError triu!(A1, -n)
@test_throws ArgumentError triu!(A1, n + 2)
@test triu(A1, -n) == A1
@test triu(A1, n + 2) == zeros(size(A1))
@test tril(A1,0) == t1(diagm(0 => diag(A1)))
@test tril(A1,1) == t1(triu(tril(A1.data,1)))
@test tril(A1,-1) == zeros(size(A1)) # or just @test iszero(tril(A1,-1))?
@test_throws ArgumentError tril!(A1, -n - 2)
@test_throws ArgumentError tril!(A1, n)
@test tril(A1, -n - 2) == zeros(size(A1))
@test tril(A1, n) == A1
end

# factorize
Expand Down
8 changes: 0 additions & 8 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,6 @@ rdiv!(A::SparseMatrixCSC{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =

function triu(S::SparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
m,n = size(S)
if !(-m + 1 <= k <= n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-m + 1) and at most $(n + 1) in an $m-by-$n matrix")))
end
colptr = Vector{Ti}(undef, n+1)
nnz = 0
for col = 1 : min(max(k+1,1), n+1)
Expand Down Expand Up @@ -504,10 +500,6 @@ end

function tril(S::SparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
m,n = size(S)
if !(-m - 1 <= k <= n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-m - 1) and at most $(n - 1) in an $m-by-$n matrix")))
end
colptr = Vector{Ti}(undef, n+1)
nnz = 0
colptr[1] = 1
Expand Down
18 changes: 2 additions & 16 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1251,20 +1251,10 @@ function fkeep!(A::SparseMatrixCSC, f, trim::Bool = true)
A
end

function tril!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true)
if !(-A.m - 1 <= k <= A.n - 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-A.m - 1) and at most $(A.n - 1) in an $(A.m)-by-$(A.n) matrix")))
end
tril!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true) =
fkeep!(A, (i, j, x) -> i + k >= j, trim)
end
function triu!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true)
if !(-A.m + 1 <= k <= A.n + 1)
throw(ArgumentError(string("the requested diagonal, $k, must be at least ",
"$(-A.m + 1) and at most $(A.n + 1) in an $(A.m)-by-$(A.n) matrix")))
end
triu!(A::SparseMatrixCSC, k::Integer = 0, trim::Bool = true) =
fkeep!(A, (i, j, x) -> j >= i + k, trim)
end

droptol!(A::SparseMatrixCSC, tol; trim::Bool = true) =
fkeep!(A, (i, j, x) -> abs(x) > tol, trim)
Expand Down Expand Up @@ -3354,10 +3344,6 @@ end
function diag(A::SparseMatrixCSC{Tv,Ti}, d::Integer=0) where {Tv,Ti}
m, n = size(A)
k = Int(d)
if !(-m <= k <= n)
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-m) ",
"and at most $n in an $m-by-$n matrix")))
end
l = k < 0 ? min(m+k,n) : min(n-k,m)
r, c = k <= 0 ? (-k, 0) : (0, k) # start row/col -1
ind = Vector{Ti}()
Expand Down
12 changes: 4 additions & 8 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1553,14 +1553,10 @@ end
@test Array(tril(A,1)) == tril(AF,1)
@test Array(triu!(copy(A), 2)) == triu(AF,2)
@test Array(tril!(copy(A), 2)) == tril(AF,2)
@test_throws ArgumentError tril(A, -n - 2)
@test_throws ArgumentError tril(A, n)
@test_throws ArgumentError triu(A, -n)
@test_throws ArgumentError triu(A, n + 2)
@test_throws ArgumentError tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -5)
@test_throws ArgumentError tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), 4)
@test_throws ArgumentError triu!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -3)
@test_throws ArgumentError triu!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), 6)
@test tril(A, -n - 2) == zero(A)
@test tril(A, n) == A
@test triu(A, -n) == A
@test triu(A, n + 2) == zero(A)

# fkeep trim option
@test isequal(length(tril!(sparse([1,2,3], [1,2,3], [1,2,3], 3, 4), -1).rowval), 0)
Expand Down

0 comments on commit b9ad92e

Please sign in to comment.