Skip to content

Commit

Permalink
Add a macro to opt into aggressive constprop
Browse files Browse the repository at this point in the history
Right now aggressive constprop is essentially tied to the inlining
threshold (or to their name being `getproperty` or `setproperty!`
respectively, which can be both somewhat brittle if the inlining cost
changes and insufficient when you do really know that const prop
would be beneficial even if the function is not inlineable. This
adds a simple macro that can be used to manually annotate methods
to force aggressive constprop on them.
  • Loading branch information
Keno committed Jan 22, 2021
1 parent 770d0d5 commit f62c2ef
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 12 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
istopfunction(f, :<<) || istopfunction(f, :>>))
return Any
end
force_inference = allconst || InferenceParams(interp).aggressive_constant_propagation
force_inference = allconst || method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation
if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!)
force_inference = true
end
Expand Down
13 changes: 13 additions & 0 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,19 @@ macro pure(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :pure) : ex)
end

"""
@aggressive_constprop ex
@aggressive_constprop(ex)
`@aggressive_constprop` requests more aggressive interprocedural constant
propagation for the annotated function. For a method where the return type
depends on the value of the arguments, this can yield improved inference results
at the cost of additional compile time.
"""
macro aggressive_constprop(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
end

"""
@propagate_inbounds
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
jl_sym_t *aggressive_constprop_sym;
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
Expand Down Expand Up @@ -385,6 +386,7 @@ void jl_init_common_symbols(void)
noinline_sym = jl_symbol("noinline");
polly_sym = jl_symbol("polly");
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
aggressive_constprop_sym = jl_symbol("aggressive_constprop");
isdefined_sym = jl_symbol("isdefined");
nospecialize_sym = jl_symbol("nospecialize");
specialize_sym = jl_symbol("specialize");
Expand Down
2 changes: 2 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_int8(s->s, m->isva);
write_int8(s->s, m->pure);
write_int8(s->s, m->is_for_opaque_closure);
write_int8(s->s, m->aggressive_constprop);
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
jl_serialize_value(s, (jl_value_t*)m->roots);
jl_serialize_value(s, (jl_value_t*)m->ccallable);
Expand Down Expand Up @@ -1442,6 +1443,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->isva = read_int8(s->s);
m->pure = read_int8(s->s);
m->is_for_opaque_closure = read_int8(s->s);
m->aggressive_constprop = read_int8(s->s);
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
jl_gc_wb(m, m->slot_syms);
m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots);
Expand Down
4 changes: 3 additions & 1 deletion src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
jl_get_ptls_states()
};

uint8_t flags = (code->inferred << 3)
uint8_t flags = (code->aggressive_constprop << 4)
| (code->inferred << 3)
| (code->inlineable << 2)
| (code->propagate_inbounds << 1)
| (code->pure << 0);
Expand Down Expand Up @@ -787,6 +788,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t

jl_code_info_t *code = jl_new_code_info_uninit();
uint8_t flags = read_uint8(s.s);
code->aggressive_constprop = !!(flags & (1 << 4));
code->inferred = !!(flags & (1 << 3));
code->inlineable = !!(flags & (1 << 2));
code->propagate_inbounds = !!(flags & (1 << 1));
Expand Down
18 changes: 11 additions & 7 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(18,
jl_perm_symsvec(19,
"code",
"codelocs",
"ssavaluetypes",
Expand All @@ -2168,8 +2168,9 @@ void jl_init_types(void) JL_GC_DISABLED
"inferred",
"inlineable",
"propagate_inbounds",
"pure"),
jl_svec(18,
"pure",
"aggressive_constprop"),
jl_svec(19,
jl_array_any_type,
jl_array_int32_type,
jl_any_type,
Expand All @@ -2187,13 +2188,14 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 18);
0, 1, 19);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(23,
jl_perm_symsvec(24,
"name",
"module",
"file",
Expand All @@ -2216,8 +2218,9 @@ void jl_init_types(void) JL_GC_DISABLED
"nkw",
"isva",
"pure",
"is_for_opaque_closure"),
jl_svec(23,
"is_for_opaque_closure",
"aggressive_constprop"),
jl_svec(24,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
Expand All @@ -2240,6 +2243,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_int32_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 10);

Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ typedef struct _jl_code_info_t {
uint8_t inlineable;
uint8_t propagate_inbounds;
uint8_t pure;
uint8_t aggressive_constprop;
} jl_code_info_t;

// This type describes a single method definition, and stores data
Expand Down Expand Up @@ -328,6 +329,7 @@ typedef struct _jl_method_t {
uint8_t isva;
uint8_t pure;
uint8_t is_for_opaque_closure;
uint8_t aggressive_constprop;

// hidden fields:
// lock for modifications to the method
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
extern jl_sym_t *aggressive_constprop_sym;
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;
Expand Down
3 changes: 3 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
li->inlineable = 1;
else if (ma == (jl_value_t*)propagate_inbounds_sym)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)aggressive_constprop_sym)
li->aggressive_constprop = 1;
else
jl_array_ptr_set(meta, ins++, ma);
}
Expand Down Expand Up @@ -528,6 +530,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
}
m->called = called;
m->pure = src->pure;
m->aggressive_constprop = src->aggressive_constprop;
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);

jl_array_t *copy = NULL;
Expand Down
17 changes: 14 additions & 3 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ mutable struct Serializer{I<:IO} <: AbstractSerializer
table::IdDict{Any,Any}
pending_refs::Vector{Int}
known_object_data::Dict{UInt64,Any}
Serializer{I}(io::I) where I<:IO = new(io, 0, IdDict(), Int[], Dict{UInt64,Any}())
version::Int
Serializer{I}(io::I) where I<:IO = new(io, 0, IdDict(), Int[], Dict{UInt64,Any}(), ser_version)
end

Serializer(io::IO) = Serializer{typeof(io)}(io)
Expand Down Expand Up @@ -78,7 +79,7 @@ const TAGS = Any[

@assert length(TAGS) == 255

const ser_version = 13 # do not make changes without bumping the version #!
const ser_version = 14 # do not make changes without bumping the version #!

const NTAGS = length(TAGS)

Expand Down Expand Up @@ -414,6 +415,7 @@ function serialize(s::AbstractSerializer, meth::Method)
serialize(s, meth.nargs)
serialize(s, meth.isva)
serialize(s, meth.is_for_opaque_closure)
serialize(s, meth.aggressive_constprop)
if isdefined(meth, :source)
serialize(s, Base._uncompressed_ast(meth, meth.source))
else
Expand Down Expand Up @@ -717,6 +719,8 @@ function readheader(s::AbstractSerializer)
error("""Cannot read stream serialized with a newer version of Julia.
Got data version $version > current version $ser_version""")
end
s.version = version
return
end

"""
Expand Down Expand Up @@ -985,12 +989,15 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
else
slot_syms = syms::String
end
nargs = deserialize(s)::Int32
isva = deserialize(s)::Bool
is_for_opaque_closure = false
aggressive_constprop = false
template_or_is_opaque = deserialize(s)
if isa(template_or_is_opaque, Bool)
is_for_opaque_closure = template_or_is_opaque
if version >= 14
aggressive_constprop = deserialize(s)::Bool
end
template = deserialize(s)
else
template = template_or_is_opaque
Expand All @@ -1005,6 +1012,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
meth.nargs = nargs
meth.isva = isva
meth.is_for_opaque_closure = is_for_opaque_closure
meth.aggressive_constprop = aggressive_constprop
if template !== nothing
# TODO: compress template
meth.source = template::CodeInfo
Expand Down Expand Up @@ -1125,6 +1133,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
ci.inlineable = deserialize(s)
ci.propagate_inbounds = deserialize(s)
ci.pure = deserialize(s)
if version >= 14
ci.aggressive_constprop = deserialize(s)::Bool
end
return ci
end

Expand Down
12 changes: 12 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3011,3 +3011,15 @@ g38888() = S38888(Base.inferencebarrier(3), nothing)

f_inf_error_bottom(x::Vector) = isempty(x) ? error(x[1]) : x
@test Core.Compiler.return_type(f_inf_error_bottom, Tuple{Vector{Any}}) == Vector{Any}

# @aggressive_constprop
@noinline g_nonaggressive(y, x) = Val{x}()
@noinline @Base.aggressive_constprop g_aggressive(y, x) = Val{x}()

f_nonaggressive(x) = g_nonaggressive(x, 1)
f_aggressive(x) = g_aggressive(x, 1)

# The first test just makes sure that improvements to the compiler don't
# render the annotation effectless.
@test Base.return_types(f_nonaggressive, Tuple{Int})[1] == Val
@test Base.return_types(f_aggressive, Tuple{Int})[1] == Val{1}

0 comments on commit f62c2ef

Please sign in to comment.