Skip to content

Commit

Permalink
Sort Union types. Implements #22664
Browse files Browse the repository at this point in the history
  • Loading branch information
quinnj committed Jul 28, 2017
1 parent c6e51d0 commit e839095
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 2 deletions.
3 changes: 1 addition & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ show(io::IO, ::Core.TypeofBottom) = print(io, "Union{}")

function show(io::IO, x::Union)
print(io, "Union")
sorted_types = sort!(uniontypes(x); by=string)
show_comma_array(io, sorted_types, '{', '}')
show_comma_array(io, uniontypes(x), '{', '}')
end

function print_without_params(@nospecialize(x))
Expand Down
93 changes: 93 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,98 @@ static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, s
}
}

STATIC_INLINE const char *datatype_module_name(jl_value_t *t)
{
return jl_symbol_name(((jl_datatype_t*)t)->name->module->name);
}

STATIC_INLINE const char *str_(const char *s)
{
return s == NULL ? "" : s;
}

STATIC_INLINE int cmp_(int a, int b)
{
return a < b ? -1 : a > b;
}

// a/b are jl_datatype_t* & not NULL
int datatype_name_cmp(jl_value_t *a, jl_value_t *b)
{
if (!jl_is_datatype(a) && jl_is_datatype(b))
return 1;
if (!jl_is_datatype(b))
return -1;
int cmp = strcmp(str_(datatype_module_name(a)), str_(datatype_module_name(b)));
if (cmp != 0)
return cmp;
cmp = strcmp(str_(jl_typename_str(a)), str_(jl_typename_str(b)));
if (cmp != 0)
return cmp;
cmp = cmp_(jl_nparams(a), jl_nparams(b));
if (cmp != 0)
return cmp;
// compare up to 3 type parameters
for (int i = 0; i < 3 && i < jl_nparams(a); i++) {
jl_value_t *ap = jl_tparam(a, i);
jl_value_t *bp = jl_tparam(b, i);
if (ap == bp) {
continue;
}
else if (jl_is_datatype(ap) && jl_is_datatype(bp)) {
cmp = datatype_name_cmp(ap, bp);
}
else if (jl_is_unionall(ap) && jl_is_unionall(bp)) {
cmp = datatype_name_cmp(jl_unwrap_unionall(ap), jl_unwrap_unionall(bp));
}
else {
// give up
cmp = 0;
}
}
return cmp;
}

// sort singletons first, then DataTypes, then UnionAlls,
// ties broken alphabetically including module name & type parameters
int union_sort_cmp(const void *ap, const void *bp)
{
jl_value_t *a = *(jl_value_t**)ap;
jl_value_t *b = *(jl_value_t**)bp;
if (a == NULL)
return b == NULL ? 0 : 1;
if (b == NULL)
return -1;
if (jl_is_datatype(a)) {
if (!jl_is_datatype(b))
return -1;
if (jl_is_datatype_singleton((jl_datatype_t*)a)) {
if (jl_is_datatype_singleton((jl_datatype_t*)b))
return datatype_name_cmp(a, b);
return -1;
}
else if (jl_is_datatype_singleton((jl_datatype_t*)b)) {
return 1;
}
else if (jl_isbits(a)) {
if (jl_isbits(b))
return datatype_name_cmp(a, b);
return -1;
}
else if (jl_isbits(b)) {
return 1;
}
else {
return datatype_name_cmp(a, b);
}
}
else {
if (jl_is_datatype(b))
return 1;
return datatype_name_cmp(jl_unwrap_unionall(a), jl_unwrap_unionall(b));
}
}

JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
{
if (n == 0) return (jl_value_t*)jl_bottom_type;
Expand Down Expand Up @@ -417,6 +509,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
}
}
}
qsort(temp, nt, sizeof(jl_value_t*), union_sort_cmp);
jl_value_t **ptu = &temp[nt];
*ptu = jl_bottom_type;
int k;
Expand Down
38 changes: 38 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5089,3 +5089,41 @@ m22929_2.x = m22929_1
@test !isdefined_22929_x(m22929_1)
@test isdefined_22929_1(m22929_2)
@test isdefined_22929_x(m22929_2)

# Union type sorting
for T in (
(Void, Int8),
(Void, Int64),
(Void, Tuple{Int64, String}),
(Void, Array),
(Float64, Int64),
(Float64, String),
(Float64, Array),
(String, Array),
(Int64, Tuple{Int64, Float64}),
(Tuple{Int64, Float64}, Array)
)
@test Base.uniontypes(Union{T...}) == collect(T)
@test Base.uniontypes(Union{reverse(T)...}) == collect(T)
end
@test Base.uniontypes(Union{Void, Union{Int64, Float64}}) == Any[Void, Float64, Int64]
module AlternativeIntModule
struct Int64
val::UInt64
end
end
@test Base.uniontypes(Union{Int64, AlternativeIntModule.Int64}) == Any[AlternativeIntModule.Int64, Int64]
@test Base.uniontypes(Union{AlternativeIntModule.Int64, Int64}) == Any[AlternativeIntModule.Int64, Int64]
# because DAlternativeIntModule is alphabetically after Core.Int64
module DAlternativeIntModule
struct Int64
val::UInt64
end
end
@test Base.uniontypes(Union{Int64, DAlternativeIntModule.Int64}) == Any[Int64, DAlternativeIntModule.Int64]
@test Base.uniontypes(Union{DAlternativeIntModule.Int64, Int64}) == Any[Int64, DAlternativeIntModule.Int64]
@test Base.uniontypes(Union{Vector{Int8}, Vector{Int16}}) == Base.uniontypes(Union{Vector{Int16}, Vector{Int8}})
mutable struct ANonIsBitsType
v::Int64
end
@test Base.uniontypes(Union{Int64, ANonIsBitsType}) == Base.uniontypes(Union{ANonIsBitsType, Int64})

0 comments on commit e839095

Please sign in to comment.