From 0ea9693f80ef986967c96fd2512f5868b0088217 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Sat, 9 Nov 2024 20:14:31 +0100 Subject: [PATCH] improved function partial application design This replaces `Fix` (xref #54653) with `fix`. The usage is similar: use `fix(i)(f, x)` instead of `Fix{i}(f, x)`. Benefits: * Improved type safety: creating an invalid type such as `Fix{:some_symbol}` or `Fix{-7}` is not possible. * The design should be friendlier to future extensions. E.g., suppose that publicly-facing functionality for fixing a keyword (instead of positional) argument was desired, it could be achieved by adding a new method to `fix` taking a `Symbol`, instead of adding new public names. Lots of changes are shared with PR #56425, if one of them gets merged the other will be greatly simplified. --- NEWS.md | 2 +- base/Base_compiler.jl | 1 + base/operators.jl | 158 ++++++++++++++++++++++++++++-------- base/public.jl | 2 +- base/tuple.jl | 19 +++-- base/typedomainnumbers.jl | 167 ++++++++++++++++++++++++++++++++++++++ doc/src/base/base.md | 2 +- stdlib/REPL/test/repl.jl | 8 +- test/choosetests.jl | 2 +- test/functional.jl | 57 +++++++------ test/typedomainnumbers.jl | 34 ++++++++ 11 files changed, 378 insertions(+), 74 deletions(-) create mode 100644 base/typedomainnumbers.jl create mode 100644 test/typedomainnumbers.jl diff --git a/NEWS.md b/NEWS.md index 74cda05e9d0e15..8f017dca6585b7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -87,7 +87,7 @@ New library functions * `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]). * `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]). * `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims` -* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]). +* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]). New library features -------------------- diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index a8604144546347..8a84e87c1f0b10 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -213,6 +213,7 @@ include("error.jl") include("bool.jl") include("number.jl") include("int.jl") +include("typedomainnumbers.jl") include("operators.jl") include("pointer.jl") include("refvalue.jl") diff --git a/base/operators.jl b/base/operators.jl index d01902e3023596..926bbc92e4db9c 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -1153,55 +1153,149 @@ julia> filter(!isletter, str) !(f::Function) = (!) ∘ f !(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f +const _PositiveInteger = _TypeDomainNumbers.PositiveIntegers.PositiveInteger + +struct PartiallyAppliedFunction{Position <: _PositiveInteger, Func, Arg} <: Function + partially_applied_argument_position::Position + f::Func + x::Arg + + function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _PositiveInteger, Func} + Pos = Position::DataType + pos = Pos.instance + new{Pos, _stable_typeof(f), _stable_typeof(x)}(pos, func, arg) + end +end + +function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol) + getfield(v, s) +end # avoid overspecialization + +function Base.show( + (@nospecialize io::Base.IO), + (@nospecialize unused::Type{PartiallyAppliedFunction{Position}}), +) where {Position <: _PositiveInteger} + if Position isa DataType + print(io, "fix(") + show(io, Position.instance) + print(io, ')') + else + show(io, PartiallyAppliedFunction) + print(io, '{') + show(io, Position) + print(io, '}') + end +end + +function Base.show( + (@nospecialize io::Base.IO), + (@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func}}), +) where {Position <: _PositiveInteger, Func} + show(io, PartiallyAppliedFunction{Position}) + print(io, '{') + show(io, Func) + print(io, '}') +end + +function Base.show( + (@nospecialize io::Base.IO), + (@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func, Arg}}), +) where {Position <: _PositiveInteger, Func, Arg} + show(io, PartiallyAppliedFunction{Position, Func}) + print(io, '{') + show(io, Arg) + print(io, '}') +end + +function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction) + print(io, "fix(") + show(io, p.partially_applied_argument_position) + print(io, ")(") + show(io, p.f) + print(io, ", ") + show(io, p.x) + print(io, ')') +end + +function _partially_applied_function_check(m::Int, nm1::Int) + if m < nm1 + throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m))) + end +end + +function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M} + n = partial.partially_applied_argument_position + nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n) + _partially_applied_function_check(M, Int(nm1)) + (args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1) + partial.f(args_left..., partial.x, args_right...; kws...) +end + """ - Fix{N}(f, x) + fix(::Integer)::UnionAll + +Return a [`UnionAll`](@ref) type such that: +* It's a constructor taking two arguments: + 1. A function to be partially applied + 2. An argument of the above function to be fixed +* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed. -A type representing a partially-applied version of a function `f`, with the argument -`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to -`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`. +For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`. + +See also: [`Fix1`](@ref), [`Fix2`](@ref). !!! compat "Julia 1.12" - This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2` - are available earlier. + Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too). !!! note - When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current + When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current available arguments, rather than an absolute ordering on the target function. For example, - `Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)` + `fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)` fixes the first and third arg. -""" -struct Fix{N,F,T} <: Function - f::F - x::T - function Fix{N}(f::F, x) where {N,F} - if !(N isa Int) - throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`"))) - elseif N < 1 - throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N))) - end - new{N,_stable_typeof(f),_stable_typeof(x)}(f, x) - end -end +### Examples -function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M} - M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M))) - return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...) -end +```jldoctest +julia> Base.fix(2)(-, 3)(7) +4 -# Special cases for improved constant propagation -(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...) -(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...) +julia> Base.fix(2) === Base.Fix2 +true +julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5) +true +``` """ -Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix). +function fix(@nospecialize m::Integer) + n = Int(m)::Int + if n ≤ 0 + throw(ArgumentError("the index of the partially applied argument must be positive")) + end + k = _TypeDomainNumbers.Utils.from_abs_int(n) + PartiallyAppliedFunction{typeof(k)} +end + +""" + Fix1::UnionAll + +[`fix(1)`](@ref Base.fix). """ -const Fix1{F,T} = Fix{1,F,T} +const Fix1 = fix(1) """ -Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix). + Fix2::UnionAll + +[`fix(2)`](@ref Base.fix). """ -const Fix2{F,T} = Fix{2,F,T} +const Fix2 = fix(2) + +# Special cases for improved constant propagation +function (partial::Fix1)(x; kws...) + partial.f(partial.x, x; kws...) +end +function (partial::Fix2)(x; kws...) + partial.f(x, partial.x; kws...) +end """ diff --git a/base/public.jl b/base/public.jl index 8777a454c920ac..f81c7de9c229e6 100644 --- a/base/public.jl +++ b/base/public.jl @@ -14,7 +14,7 @@ public AsyncCondition, CodeUnits, Event, - Fix, + fix, Fix1, Fix2, Generator, diff --git a/base/tuple.jl b/base/tuple.jl index 3791d74bfc6983..a3e9c92252ff92 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -1,5 +1,18 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +module _TupleTypeByLength + export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore + const Tuple1OrMore = Tuple{Any, Vararg} + const Tuple2OrMore = Tuple{Any, Any, Vararg} + const Tuple32OrMore = Tuple{ + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Vararg{Any, N}, + } where {N} +end + # Document NTuple here where we have everything needed for the doc system """ NTuple{N, T} @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2]))) map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3]))) map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...)) # stop inlining after some number of arguments to avoid code blowup -const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Vararg{Any,N}} +const Any32 = _TupleTypeByLength.Tuple32OrMore const All32{T,N} = Tuple{T,T,T,T,T,T,T,T, T,T,T,T,T,T,T,T, T,T,T,T,T,T,T,T, diff --git a/base/typedomainnumbers.jl b/base/typedomainnumbers.jl new file mode 100644 index 00000000000000..f00532a1e0035d --- /dev/null +++ b/base/typedomainnumbers.jl @@ -0,0 +1,167 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Adapted from the TypeDomainNaturalNumbers.jl package. +module _TypeDomainNumbers + module Zeros + export Zero + struct Zero end + end + + module PositiveIntegers + module RecursiveStep + using ...Zeros + export recursive_step + function recursive_step(@nospecialize t::Type) + Union{Zero, t} + end + end + module UpperBounds + using ..RecursiveStep + abstract type A end + abstract type B{P <: recursive_step(A)} <: A end + abstract type C{P <: recursive_step(B)} <: B{P} end + abstract type D{P <: recursive_step(C)} <: C{P} end + end + using .RecursiveStep + const PositiveIntegerUpperBound = UpperBounds.A + const PositiveIntegerUpperBoundTighter = UpperBounds.D + export + natural_successor, natural_predecessor, + NonnegativeInteger, NonnegativeIntegerUpperBound, + PositiveInteger, PositiveIntegerUpperBound + struct PositiveInteger{ + Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter), + } <: PositiveIntegerUpperBoundTighter{Predecessor} + predecessor::Predecessor + global const NonnegativeInteger = recursive_step(PositiveInteger) + global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound) + global function natural_successor(p::P) where {P <: NonnegativeInteger} + new{P}(p) + end + end + function natural_predecessor(@nospecialize o::PositiveInteger) + getfield(o, :predecessor) # avoid specializing `getproperty` for each number + end + end + + module IntegersGreaterThanOne + using ..PositiveIntegers + export + IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound, + natural_predecessor_predecessor + const IntegerGreaterThanOne = let t = PositiveInteger + t{P} where {P <: t} + end + const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound + PositiveIntegers.UpperBounds.B{P} where {P <: t} + end + function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne) + natural_predecessor(natural_predecessor(x)) + end + end + + module Constants + using ..Zeros, ..PositiveIntegers + export n0, n1 + const n0 = Zero() + const n1 = natural_successor(n0) + end + + module Utils + using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants + using Base: @_foldable_meta + function subtracted_nonnegative((@nospecialize l::NonnegativeInteger), @nospecialize r::NonnegativeInteger) + @_foldable_meta + if r isa PositiveIntegerUpperBound + let a = natural_predecessor(l), b = natural_predecessor(r) + subtracted_nonnegative(a, b) + end + else + l + end + end + function abs_decrement(n::Int) + @_foldable_meta + if signbit(n) + n + true + else + n - true + end + end + function to_int(@nospecialize o::NonnegativeInteger) + @_foldable_meta + if o isa PositiveIntegerUpperBound + let p = natural_predecessor(o), t = to_int(p) + t + true + end + else + 0 + end + end + function from_abs_int(n::Int) + @_foldable_meta + ret = n0 + while !iszero(n) + n = abs_decrement(n) + ret = natural_successor(ret) + end + ret + end + end + + module Overloads + using ..PositiveIntegers, ..Utils + function (::Type{Int})(@nospecialize o::NonnegativeInteger) + Utils.to_int(o) + end + function Base.show((@nospecialize io::Base.IO), @nospecialize n::NonnegativeInteger) + i = Int(n) + Base.show(io, i) + end + end +end + +module _TypeDomainNumberTupleUtils + using + .._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne, + .._TypeDomainNumbers.Constants, .._TypeDomainNumbers.Utils, .._TupleTypeByLength + using Base: @_total_meta, @_foldable_meta, front, tail + export tuple_type_domain_length, split_tuple, skip_from_front, skip_from_tail + function tuple_type_domain_length(@nospecialize tup::Tuple) + @_total_meta + if tup isa Tuple1OrMore + let t = tail(tup), rec = tuple_type_domain_length(t) + natural_successor(rec) + end + else + n0 + end + end + function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + @_foldable_meta + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = tail(tup) + @inline skip_from_front(t, cm1) + end + else + tup + end + end + function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + @_foldable_meta + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = front(tup) + @inline skip_from_tail(t, cm1) + end + else + tup + end + end + function split_tuple((@nospecialize tup::Tuple), @nospecialize len_l::NonnegativeInteger) + len = tuple_type_domain_length(tup) + len_r = Utils.subtracted_nonnegative(len, len_l) + tup_l = skip_from_tail(tup, len_r) + tup_r = skip_from_front(tup, len_l) + (tup_l, tup_r) + end +end diff --git a/doc/src/base/base.md b/doc/src/base/base.md index 7181965d9aa816..d276d5d3e7eb8b 100644 --- a/doc/src/base/base.md +++ b/doc/src/base/base.md @@ -285,7 +285,7 @@ Base.:(|>) Base.:(∘) Base.ComposedFunction Base.splat -Base.Fix +Base.fix Base.Fix1 Base.Fix2 ``` diff --git a/stdlib/REPL/test/repl.jl b/stdlib/REPL/test/repl.jl index 809913502c3d75..2ab8157b2f58fd 100644 --- a/stdlib/REPL/test/repl.jl +++ b/stdlib/REPL/test/repl.jl @@ -1156,13 +1156,13 @@ fake_repl() do stdin_write, stdout_read, repl write(stdin_write, " ( 123 , Base.Fix1 , ) \n") s = readuntil(stdout_read, "\n\n") - @test endswith(s, "(123, Base.Fix1)") + @test endswith(s, "(123, Base.fix(1))") repl.mistate.active_module = Base # simulate activate_module(Base) write(stdin_write, " ( 456 , Base.Fix2 , ) \n") s = readuntil(stdout_read, "\n\n") # ".Base" prefix not shown here - @test endswith(s, "(456, Fix2)") + @test endswith(s, "(456, fix(2))") # Close REPL ^D readuntil(stdout_read, "julia> ", keep=true) @@ -1217,9 +1217,9 @@ global some_undef_global @test occursin("does not exist", sprint(show, help_result(".."))) # test that helpmode is sensitive to contextual module @test occursin("No documentation found", sprint(show, help_result("Fix2", Main))) -@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change +@test occursin("fix", # exact string may change sprint(show, help_result("Base.Fix2", Main))) -@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change +@test occursin("fix", # exact string may change sprint(show, help_result("Fix2", Base))) diff --git a/test/choosetests.jl b/test/choosetests.jl index affdee412bd869..07bbdd47f1b85b 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -15,7 +15,7 @@ const TESTNAMES = [ "bitarray", "copy", "math", "fastmath", "functional", "iterators", "operators", "ordering", "path", "ccall", "parse", "loading", "gmp", "sorting", "spawn", "backtrace", "exceptions", - "file", "read", "version", "namedtuple", + "file", "read", "version", "namedtuple", "typedomainnumbers", "mpfr", "broadcast", "complex", "floatapprox", "stdlib", "reflection", "regex", "float16", "combinatorics", "sysinfo", "env", "rounding", "ranges", "mod2pi", diff --git a/test/functional.jl b/test/functional.jl index 84c4098308ebd5..6a480041283440 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -236,7 +236,7 @@ let (:)(a,b) = (i for i in Base.:(:)(1,10) if i%2==0) @test Int8[ i for i = 1:2 ] == [2,4,6,8,10] end -@testset "Basic tests of Fix1, Fix2, and Fix" begin +@testset "Basic tests of Fix1, Fix2, and fix" begin function test_fix1(Fix1=Base.Fix1) increment = Fix1(+, 1) @test increment(5) == 6 @@ -278,48 +278,48 @@ end test_fix2() # Now, repeat the Fix1 and Fix2 tests, but - # with a Fix lambda function used in their place - test_fix1((op, arg) -> Base.Fix{1}(op, arg)) - test_fix2((op, arg) -> Base.Fix{2}(op, arg)) + # with a fix lambda function used in their place + test_fix1((op, arg) -> Base.fix(1)(op, arg)) + test_fix2((op, arg) -> Base.fix(2)(op, arg)) - # Now, we do more complex tests of Fix: - let Fix=Base.Fix + # Now, we do more complex tests of fix: + let fix=Base.fix @testset "Argument Fixation" begin let f = (x, y, z) -> x + y * z - fixed_f1 = Fix{1}(f, 10) + fixed_f1 = fix(1)(f, 10) @test fixed_f1(2, 3) == 10 + 2 * 3 - fixed_f2 = Fix{2}(f, 5) + fixed_f2 = fix(2)(f, 5) @test fixed_f2(1, 4) == 1 + 5 * 4 - fixed_f3 = Fix{3}(f, 3) + fixed_f3 = fix(3)(f, 3) @test fixed_f3(1, 2) == 1 + 2 * 3 end end @testset "Helpful errors" begin let g = (x, y) -> x - y # Test minimum N - fixed_g1 = Fix{1}(g, 100) + fixed_g1 = fix(1)(g, 100) @test fixed_g1(40) == 100 - 40 # Test maximum N - fixed_g2 = Fix{2}(g, 100) + fixed_g2 = fix(2)(g, 100) @test fixed_g2(150) == 150 - 100 # One over - fixed_g3 = Fix{3}(g, 100) - @test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1) + fixed_g3 = fix(3)(g, 100) + @test_throws ArgumentError("expected at least 2 arguments to `fix(3)`, but got 1") fixed_g3(1) end end @testset "Type Stability and Inference" begin let h = (x, y) -> x / y - fixed_h = Fix{2}(h, 2.0) + fixed_h = fix(2)(h, 2.0) @test @inferred(fixed_h(4.0)) == 2.0 end end @testset "Interaction with varargs" begin vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x)) - fixed_vararg_f = Fix{2}(vararg_f, 6) + fixed_vararg_f = fix(2)(vararg_f, 6) # Can call with variable number of arguments: @test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4)) @@ -329,35 +329,34 @@ end end @testset "Errors should propagate normally" begin error_f = (x, y) -> sin(x * y) - fixed_error_f = Fix{2}(error_f, Inf) + fixed_error_f = fix(2)(error_f, Inf) @test_throws DomainError fixed_error_f(10) end - @testset "Chaining Fix together" begin - f1 = Fix{1}(*, "1") - f2 = Fix{1}(f1, "2") - f3 = Fix{1}(f2, "3") + @testset "Chaining fix together" begin + f1 = fix(1)(*, "1") + f2 = fix(1)(f1, "2") + f3 = fix(1)(f2, "3") @test f3() == "123" - g1 = Fix{2}(*, "1") - g2 = Fix{2}(g1, "2") - g3 = Fix{2}(g2, "3") + g1 = fix(2)(*, "1") + g2 = fix(2)(g1, "2") + g3 = fix(2)(g2, "3") @test g3("") == "123" end @testset "Zero arguments" begin - f = Fix{1}(x -> x, 'a') + f = fix(1)(x -> x, 'a') @test f() == 'a' end @testset "Dummy-proofing" begin - @test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1) - @test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1) - @test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1) + @test_throws ArgumentError fix(0) + @test_throws MethodError fix(0.5) end @testset "Specialize to structs not in `Base`" begin struct MyStruct x::Int end - f = Fix{1}(MyStruct, 1) - @test f isa Fix{1,Type{MyStruct},Int} + f = fix(1)(MyStruct, 1) + @test f isa fix(1) end end end diff --git a/test/typedomainnumbers.jl b/test/typedomainnumbers.jl new file mode 100644 index 00000000000000..3ac010b205453b --- /dev/null +++ b/test/typedomainnumbers.jl @@ -0,0 +1,34 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +using + Test, + Base._TypeDomainNumbers.PositiveIntegers, + Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumbers.Constants, + Base._TypeDomainNumberTupleUtils + +@testset "type domain numbers" begin + @test n0 isa NonnegativeInteger + @test n1 isa NonnegativeInteger + @test n1 isa PositiveInteger + @testset "succ" begin + for x ∈ (n0, n1) + @test x === natural_predecessor(@inferred natural_successor(x)) + @test x === natural_predecessor_predecessor(natural_successor(natural_successor(x))) + end + end + @testset "type safety" begin + @test_throws TypeError PositiveInteger{Int} + end + @testset "tuple utils" begin + @test n0 === @inferred tuple_type_domain_length(()) + @test n1 === @inferred tuple_type_domain_length((7,)) + @test ((), ()) === @inferred split_tuple((), n0) + @test ((), (7,)) === @inferred split_tuple((7,), n0) + @test ((7,), ()) === @inferred split_tuple((7,), n1) + @test ((), (3, 7)) === @inferred split_tuple((3, 7), n0) + @test ((3,), (7,)) === @inferred split_tuple((3, 7), n1) + @test ((), (3, 7, 9)) === @inferred split_tuple((3, 7, 9), n0) + @test ((3,), (7, 9)) === @inferred split_tuple((3, 7, 9), n1) + end +end