Skip to content

Commit

Permalink
Sort Union types. Implements #22664
Browse files Browse the repository at this point in the history
  • Loading branch information
quinnj committed Jul 14, 2017
1 parent f8ff488 commit 9553558
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
3 changes: 1 addition & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ show(io::IO, ::Core.TypeofBottom) = print(io, "Union{}")

function show(io::IO, x::Union)
print(io, "Union")
sorted_types = sort!(uniontypes(x); by=string)
show_comma_array(io, sorted_types, '{', '}')
show_comma_array(io, uniontypes(x), '{', '}')
end

function print_without_params(x::ANY)
Expand Down
49 changes: 49 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,54 @@ static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, s
}
}

#define datatype_module_name(t) jl_symbol_name(((jl_datatype_t*)t)->name->module->name)

int datatype_name_cmp(jl_value_t *a, jl_value_t *b)
{
const char *aty = jl_typename_str(a);
const char *bty = jl_typename_str(b);
int cmp = aty == NULL ? 1 : bty == NULL ? -1 : strcmp(aty, bty);
if (cmp == 0) {
// datatype name are equal, which should only happen
// if they were defined in different modules
return strcmp(datatype_module_name(a), datatype_module_name(b));
}
else {
return cmp;
}
}

int union_sort_cmp(const void *ap, const void *bp)
{
jl_value_t *a = *(jl_value_t**)ap;
jl_value_t *b = *(jl_value_t**)bp;
if (a == NULL)
return b == NULL ? 0 : 1;
if (b == NULL)
return -1;
if (jl_is_datatype(a)) {
if (!jl_is_datatype(b))
return -1;
if (jl_is_datatype_singleton((jl_datatype_t*)a)) {
if (jl_is_datatype_singleton((jl_datatype_t*)b))
return datatype_name_cmp(a, b);
return -1;
}
else if (jl_is_datatype_singleton((jl_datatype_t*)b)) {
return 1;
}
else {
return datatype_name_cmp(a, b);
}
}
else {
// a is UnionAll
if (jl_is_datatype(b))
return 1;
return datatype_name_cmp(jl_unwrap_unionall(a), jl_unwrap_unionall(b));
}
}

JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
{
if (n == 0) return (jl_value_t*)jl_bottom_type;
Expand Down Expand Up @@ -419,6 +467,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
}
}
}
qsort(temp, nt, sizeof(jl_value_t*), union_sort_cmp);
jl_value_t **ptu = &temp[nt];
*ptu = jl_bottom_type;
int k;
Expand Down
33 changes: 33 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5075,3 +5075,36 @@ f_isdefined_cl_6() = (local x; () -> @isdefined x)
@test !f_isdefined_cl_4()
@test f_isdefined_cl_5()()
@test !f_isdefined_cl_6()()

# Union type sorting
for T in (
(Void, Int8),
(Void, Int64),
(Void, Tuple{Int64, String}),
(Void, Array),
(Float64, Int64),
(Float64, String),
(Float64, Array),
(String, Array),
(Int64, Tuple{Int64, Float64}),
(Tuple{Int64, Float64}, Array)
)
@test Base.uniontypes(Union{T...}) == collect(T)
@test Base.uniontypes(Union{reverse(T)...}) == collect(T)
end
@test Base.uniontypes(Union{Void, Union{Int64, Float64}}) == Any[Void, Float64, Int64]
module AlternativeIntModule
struct Int64
val::UInt64
end
end
@test Base.uniontypes(Union{Int64, AlternativeIntModule.Int64}) == Any[AlternativeIntModule.Int64, Int64]
@test Base.uniontypes(Union{AlternativeIntModule.Int64, Int64}) == Any[AlternativeIntModule.Int64, Int64]
# because DAlternativeIntModule is alphabetically after Core.Int64
module DAlternativeIntModule
struct Int64
val::UInt64
end
end
@test Base.uniontypes(Union{Int64, DAlternativeIntModule.Int64}) == Any[Int64, DAlternativeIntModule.Int64]
@test Base.uniontypes(Union{DAlternativeIntModule.Int64, Int64}) == Any[Int64, DAlternativeIntModule.Int64]

0 comments on commit 9553558

Please sign in to comment.