Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Test a weak dependency #521

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.58"
version = "0.10.59"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -21,6 +21,12 @@ TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
KernelFunctionsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Compat = "3.7, 4"
Expand Down
283 changes: 283 additions & 0 deletions ext/KernelFunctionsTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
module KernelFunctionsTestExt

using KernelFunctions
using KernelFunctions: TestUtils, LinearAlgebra, Random
using Test

"""
test_interface(
k::Kernel,
x0::AbstractVector,
x1::AbstractVector,
x2::AbstractVector;
rtol=1e-6,
atol=rtol,
)

Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`.
`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should
be of different lengths.

These tests are intended to pick up on really substantial issues with a kernel implementation
(e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to
test the numerics in detail, which can be kernel-specific.
"""
function TestUtils.test_interface(

Check warning on line 25 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L25

Added line #L25 was not covered by tests
k::Kernel,
x0::AbstractVector,
x1::AbstractVector,
x2::AbstractVector;
rtol=1e-6,
atol=rtol,
)
# Ensure that we have the required inputs.
@assert length(x0) == length(x1)
@assert length(x0) ≠ length(x2)

Check warning on line 35 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L34-L35

Added lines #L34 - L35 were not covered by tests

# Check that kernelmatrix_diag basically works.
@test kernelmatrix_diag(k, x0, x1) isa AbstractVector
@test length(kernelmatrix_diag(k, x0, x1)) == length(x0)

Check warning on line 39 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L38-L39

Added lines #L38 - L39 were not covered by tests

# Check that pairwise basically works.
@test kernelmatrix(k, x0, x2) isa AbstractMatrix
@test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2))

Check warning on line 43 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L42-L43

Added lines #L42 - L43 were not covered by tests

# Check that elementwise is consistent with pairwise.
@test kernelmatrix_diag(k, x0, x1) ≈ LinearAlgebra.diag(kernelmatrix(k, x0, x1)) atol =

Check warning on line 46 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L46

Added line #L46 was not covered by tests
atol rtol = rtol

# Check additional binary elementwise properties for kernels.
@test kernelmatrix_diag(k, x0, x1) ≈ kernelmatrix_diag(k, x1, x0)
@test kernelmatrix(k, x0, x2) ≈ permutedims(kernelmatrix(k, x2, x0)) atol = atol rtol =

Check warning on line 51 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
rtol

# Check that unary elementwise basically works.
@test kernelmatrix_diag(k, x0) isa AbstractVector
@test length(kernelmatrix_diag(k, x0)) == length(x0)

Check warning on line 56 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L55-L56

Added lines #L55 - L56 were not covered by tests

# Check that unary pairwise basically works.
@test kernelmatrix(k, x0) isa AbstractMatrix
@test size(kernelmatrix(k, x0)) == (length(x0), length(x0))
@test kernelmatrix(k, x0) ≈ permutedims(kernelmatrix(k, x0)) atol = atol rtol = rtol

Check warning on line 61 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L59-L61

Added lines #L59 - L61 were not covered by tests

# Check that unary elementwise is consistent with unary pairwise.
@test kernelmatrix_diag(k, x0) ≈ LinearAlgebra.diag(kernelmatrix(k, x0)) atol = atol rtol =

Check warning on line 64 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L64

Added line #L64 was not covered by tests
rtol

# Check that unary pairwise produces a positive definite matrix (approximately).
@test LinearAlgebra.eigmin(Matrix(kernelmatrix(k, x0))) > -atol

Check warning on line 68 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L68

Added line #L68 was not covered by tests

# Check that unary elementwise / pairwise are consistent with the binary versions.
@test kernelmatrix_diag(k, x0) ≈ kernelmatrix_diag(k, x0, x0) atol = atol rtol = rtol
@test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0, x0) atol = atol rtol = rtol

Check warning on line 72 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L71-L72

Added lines #L71 - L72 were not covered by tests

# Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`.
@test k(first(x0), first(x1)) isa Real
@test kernelmatrix(k, x0, x2) ≈ [k(xl, xr) for xl in x0, xr in x2]

Check warning on line 76 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L75-L76

Added lines #L75 - L76 were not covered by tests

tmp = Matrix{Float64}(undef, length(x0), length(x2))
@test kernelmatrix!(tmp, k, x0, x2) ≈ kernelmatrix(k, x0, x2)

Check warning on line 79 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L78-L79

Added lines #L78 - L79 were not covered by tests

tmp_square = Matrix{Float64}(undef, length(x0), length(x0))
@test kernelmatrix!(tmp_square, k, x0) ≈ kernelmatrix(k, x0)

Check warning on line 82 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L81-L82

Added lines #L81 - L82 were not covered by tests

tmp_diag = Vector{Float64}(undef, length(x0))
@test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0)
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1)

Check warning on line 86 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L84-L86

Added lines #L84 - L86 were not covered by tests
end

"""
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}=Float64; kwargs...) where {T}

Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
randomly generated inputs.
"""
function TestUtils.test_interface(k::Kernel, T::Type=Float64; kwargs...)
return TestUtils.test_interface(Random.default_rng(), k, T; kwargs...)

Check warning on line 101 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L100-L101

Added lines #L100 - L101 were not covered by tests
end

function TestUtils.test_interface(

Check warning on line 104 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L104

Added line #L104 was not covered by tests
rng::Random.AbstractRNG, k::Kernel, T::Type=Float64; kwargs...
)
return TestUtils.test_with_type(TestUtils.test_interface, rng, k, T; kwargs...)

Check warning on line 107 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L107

Added line #L107 was not covered by tests
end

"""
test_type_stability(
k::Kernel,
x0::AbstractVector,
x1::AbstractVector,
x2::AbstractVector,
)

Run type stability checks over `k(x,y)` and the different functions of the API
(`kernelmatrix`, `kernelmatrix_diag`). `x0` and `x1` should be of the same
length with different values, while `x0` and `x2` should be of different lengths.
"""
function TestUtils.test_type_stability(

Check warning on line 122 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L122

Added line #L122 was not covered by tests
k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector
)
# Ensure that we have the required inputs.
@assert length(x0) == length(x1)
@assert length(x0) ≠ length(x2)
@test @inferred(kernelmatrix(k, x0)) isa AbstractMatrix
@test @inferred(kernelmatrix(k, x0, x2)) isa AbstractMatrix
@test @inferred(kernelmatrix_diag(k, x0)) isa AbstractVector
@test @inferred(kernelmatrix_diag(k, x0, x1)) isa AbstractVector

Check warning on line 131 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L126-L131

Added lines #L126 - L131 were not covered by tests
end

function TestUtils.test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) where {T}
return TestUtils.test_type_stability(Random.default_rng(), k, T; kwargs...)

Check warning on line 135 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L134-L135

Added lines #L134 - L135 were not covered by tests
end

function TestUtils.test_type_stability(

Check warning on line 138 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L138

Added line #L138 was not covered by tests
rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...
) where {T}
return TestUtils.test_with_type(TestUtils.test_type_stability, rng, k, T; kwargs...)

Check warning on line 141 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L141

Added line #L141 was not covered by tests
end

"""
test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T}

Run the functions `f`, (for example [`test_interface`](@ref) or
[`test_type_stable`](@ref)) for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of `f` with the
randomly generated inputs.
"""
function TestUtils.test_with_type(

Check warning on line 156 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L156

Added line #L156 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...
) where {T}
@testset "Vector{$T}" begin
TestUtils.test_with_type(f, rng, k, Vector{T}; kwargs...)

Check warning on line 160 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L159-L160

Added lines #L159 - L160 were not covered by tests
end
@testset "ColVecs{$T}" begin
TestUtils.test_with_type(f, rng, k, ColVecs{T}; kwargs...)

Check warning on line 163 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L162-L163

Added lines #L162 - L163 were not covered by tests
end
@testset "RowVecs{$T}" begin
TestUtils.test_with_type(f, rng, k, RowVecs{T}; kwargs...)

Check warning on line 166 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
end
@testset "Vector{Vector{$T}}" begin
TestUtils.test_with_type(f, rng, k, Vector{Vector{T}}; kwargs...)

Check warning on line 169 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L168-L169

Added lines #L168 - L169 were not covered by tests
end
end

function TestUtils.test_with_type(

Check warning on line 173 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L173

Added line #L173 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
) where {T<:Real}
return f(k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...)

Check warning on line 176 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L176

Added line #L176 was not covered by tests
end

function TestUtils.test_with_type(

Check warning on line 179 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L179

Added line #L179 was not covered by tests
f,
rng::Random.AbstractRNG,
k::MOKernel,
::Type{Vector{Tuple{T,Int}}};
dim_out=3,
kwargs...,
) where {T<:Real}
return f(

Check warning on line 187 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L187

Added line #L187 was not covered by tests
k,
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11],
[(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:13];
kwargs...,
)
end

function TestUtils.test_with_type(

Check warning on line 196 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L196

Added line #L196 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...
) where {T<:Real}
return f(

Check warning on line 199 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L199

Added line #L199 was not covered by tests
k,
ColVecs(randn(rng, T, dim_in, 11)),
ColVecs(randn(rng, T, dim_in, 11)),
ColVecs(randn(rng, T, dim_in, 13));
kwargs...,
)
end

function TestUtils.test_with_type(

Check warning on line 208 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L208

Added line #L208 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...
) where {T<:Real}
return f(

Check warning on line 211 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L211

Added line #L211 was not covered by tests
k,
RowVecs(randn(rng, T, 11, dim_in)),
RowVecs(randn(rng, T, 11, dim_in)),
RowVecs(randn(rng, T, 13, dim_in));
kwargs...,
)
end

function TestUtils.test_with_type(

Check warning on line 220 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L220

Added line #L220 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs...
) where {T<:Real}
return f(

Check warning on line 223 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L223

Added line #L223 was not covered by tests
k,
[randn(rng, T, dim_in) for _ in 1:11],
[randn(rng, T, dim_in) for _ in 1:11],
[randn(rng, T, dim_in) for _ in 1:13];
kwargs...,
)
end

function TestUtils.test_with_type(

Check warning on line 232 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L232

Added line #L232 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...
)
return f(

Check warning on line 235 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L235

Added line #L235 was not covered by tests
k,
[Random.randstring(rng) for _ in 1:3],
[Random.randstring(rng) for _ in 1:3],
[Random.randstring(rng) for _ in 1:4];
kwargs...,
)
end

function test_with_type(

Check warning on line 244 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L244

Added line #L244 was not covered by tests
f, rng::Random.AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs...
)
return f(

Check warning on line 247 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L247

Added line #L247 was not covered by tests
k,
ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:3]),
ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:3]),
ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:4]);
kwargs...,
)
end

function TestUtils.test_with_type(f, k::Kernel, T::Type{<:Real}; kwargs...)
return TestUtils.test_with_type(f, Random.default_rng(), k, T; kwargs...)

Check warning on line 257 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L256-L257

Added lines #L256 - L257 were not covered by tests
end

"""
example_inputs(rng::AbstractRNG, type)

Return a tuple of 4 inputs of type `type`. See `methods(example_inputs)` for information
around supported types. It is recommended that you utilise `StableRNGs.jl` for `rng` here
to ensure consistency across Julia versions.
"""
function TestUtils.example_inputs(rng::Random.AbstractRNG, ::Type{Vector{Float64}})
return map(n -> randn(rng, Float64, n), (1, 2, 3, 4))

Check warning on line 268 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L267-L268

Added lines #L267 - L268 were not covered by tests
end

function TestUtils.example_inputs(

Check warning on line 271 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L271

Added line #L271 was not covered by tests
rng::Random.AbstractRNG, ::Type{ColVecs{Float64,Matrix{Float64}}}; dim::Int=2
)
return map(n -> ColVecs(randn(rng, dim, n)), (1, 2, 3, 4))

Check warning on line 274 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L274

Added line #L274 was not covered by tests
end

function TestUtils.example_inputs(

Check warning on line 277 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L277

Added line #L277 was not covered by tests
rng::Random.AbstractRNG, ::Type{RowVecs{Float64,Matrix{Float64}}}; dim::Int=2
)
return map(n -> RowVecs(randn(rng, n, dim)), (1, 2, 3, 4))

Check warning on line 280 in ext/KernelFunctionsTestExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/KernelFunctionsTestExt.jl#L280

Added line #L280 was not covered by tests
end

end # module
4 changes: 4 additions & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ using CompositionsBase
using Distances
using FillArrays
using Functors
using Random
using LinearAlgebra
using Requires
using SpecialFunctions: loggamma, besselk, polygamma
Expand Down Expand Up @@ -125,6 +126,9 @@ include("chainrules.jl")
include("zygoterules.jl")

include("TestUtils.jl")
if !isdefined(Base, :get_extension)
include("../ext/KernelFunctionsTestExt.jl")
end

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
Expand Down
Loading
Loading