diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 7d7f77c282bc0..caa61cf94a52d 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -510,7 +510,9 @@ for (fname, elty) in ((:daxpy_,:Float64), end end end -function axpy!(alpha::Number, x::AbstractArray{T}, y::AbstractArray{T}) where T<:BlasFloat + +#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpy!` and `LinearAlgebra.axpy!` +function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where T<:BlasFloat if length(x) != length(y) throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) end @@ -582,7 +584,8 @@ for (fname, elty) in ((:daxpby_,:Float64), (:saxpby_,:Float32), end end -function axpby!(alpha::Number, x::AbstractArray{T}, beta::Number, y::AbstractArray{T}) where T<:BlasFloat +#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpby!` and `LinearAlgebra.axpby!` +function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::Number, y::Union{DenseArray{T},AbstractVector{T}},) where T<:BlasFloat require_one_based_indexing(x, y) if length(x) != length(y) throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index b56edf9439fe0..cd52d30da6c8d 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -295,6 +295,14 @@ end @test LinearAlgebra.axpy!(α, x, rx, y, ry) == [1 1 1 1; 11 1 1 26] end +@testset "LinearAlgebra.axp(b)y! for non strides input" begin + a = rand(5, 5) + @test LinearAlgebra.axpby!(1, Hermitian(a), 1, zeros(size(a))) == Hermitian(a) + @test_broken LinearAlgebra.axpby!(1, 1.:5, 1, zeros(5)) == 1.:5 + @test LinearAlgebra.axpy!(1, Hermitian(a), zeros(size(a))) == Hermitian(a) + @test LinearAlgebra.axpy!(1, 1.:5, zeros(5)) == 1.:5 +end + @testset "norm and normalize!" begin vr = [3.0, 4.0] for Tr in (Float32, Float64)