Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get better type info from partially generated functions #31025

Merged
merged 2 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
if world < min_world(method) || world > max_world(method)
return nothing
end
if isdefined(method, :generator) && !isdispatchtuple(atypes)
# don't call staged functions on abstract types.
# (see issues #8504, #10230)
# we can't guarantee that their type behavior is monotonic.
if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams)
return nothing
end
if preexisting
Expand Down
71 changes: 68 additions & 3 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,15 +941,80 @@ struct CodegenParams
emit_function, emitted_function)
end

const SLOT_USED = 0x8
ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1)

"""
may_invoke_generator(method, atypes, sparams)

Computes whether or not we may invoke the generator for the given `method` on
the given atypes and sparams. For correctness, all generated function are
required to return monotonic answers. However, since we don't expect users to
be able to successfully implement this criterion, we only call generated
functions on concrete types. The one exception to this is that we allow calling
generators with abstract types if the generator does not use said abstract type
(and thus cannot incorrectly use it to break monotonicity). This function
computes whether we are in either of these cases.
"""
function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector)
# If we have complete information, we may always call the generator
isdispatchtuple(atypes) && return true

# We don't have complete information, but it is possible that the generator
# syntactically doesn't make use of the information we don't have. Check
# for that.

# For now, only handle the (common, generated by the frontend case) that the
# generator only has one method
isa(method.generator, Core.GeneratedFunctionStub) || return false
generator_mt = typeof(method.generator.gen).name.mt
length(generator_mt) == 1 || return false

generator_method = first(MethodList(generator_mt))
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it's introducing a bit of excess design debt, making it harder to fix #14919. The mt field only exists as a implementation detail, and you shouldn't depend on it being meaning for anything. Also first(MethodList(generator_mt)) seems like a clear design smell...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just use methods from reflection.jl?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was concerned that that would try to do a method lookup that would be more expensive, than just doing this, which is O(1).

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let the runtime worry about that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right, will do.

nsparams = length(sparams)
isdefined(generator_method, :source) || return false
code = generator_method.source
nslots = ccall(:jl_ast_nslots, Int, (Any,), code)
at = unwrap_unionall(atypes)
(nslots >= 1 + length(sparams) + length(at.parameters)) || return false

for i = 1:nsparams
if isa(sparams[i], TypeVar)
if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0
return false
end
end
end
for i = 1:length(at.parameters)
if !isdispatchelem(at.parameters[i])
if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0
return false
end
end
end
return true
end

# give a decent error message if we try to instantiate a staged function on non-leaf types
function func_for_method_checked(m::Method, @nospecialize types)
function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector)
if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
return m
end

function func_for_method_checked(m::Method, @nospecialize(types))
depwarn("The two argument form of `func_for_method_checked` is deprecated. Pass sparams in addition.",
:func_for_method_checked)
if isdefined(m, :generator) && !isdispatchtuple(types)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
return m
end


"""
code_typed(f, types; optimize=true, debuginfo=:default)

Expand Down Expand Up @@ -978,7 +1043,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple);
types = to_tuple_type(types)
asts = []
for x in _methods(f, types, -1, world)
meth = func_for_method_checked(x[3], types)
meth = func_for_method_checked(x[3], types, x[2])
(code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params)
code === nothing && error("inference not successful") # inference disabled?
debuginfo == :none && remove_linenums!(code)
Expand All @@ -997,7 +1062,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple))
world = ccall(:jl_get_world_counter, UInt, ())
params = Core.Compiler.Params(world)
for x in _methods(f, types, -1, world)
meth = func_for_method_checked(x[3], types)
meth = func_for_method_checked(x[3], types, x[2])
ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params)
ty === nothing && error("inference not successful") # inference disabled?
push!(rt, ty)
Expand Down
6 changes: 3 additions & 3 deletions src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@
(define (vinfo:capt v) (< 0 (logand (caddr v) 1)))
(define (vinfo:asgn v) (< 0 (logand (caddr v) 2)))
(define (vinfo:never-undef v) (< 0 (logand (caddr v) 4)))
(define (vinfo:const v) (< 0 (logand (caddr v) 8)))
(define (vinfo:read v) (< 0 (logand (caddr v) 8)))
(define (vinfo:sa v) (< 0 (logand (caddr v) 16)))
(define (set-bit x b val) (if val (logior x b) (logand x (lognot b))))
;; record whether var is captured
Expand All @@ -443,8 +443,8 @@
(define (vinfo:set-asgn! v a) (set-car! (cddr v) (set-bit (caddr v) 2 a)))
;; whether the assignments to var are known to dominate its usages
(define (vinfo:set-never-undef! v a) (set-car! (cddr v) (set-bit (caddr v) 4 a)))
;; whether var is const
(define (vinfo:set-const! v a) (set-car! (cddr v) (set-bit (caddr v) 8 a)))
;; whether var is ever read
(define (vinfo:set-read! v a) (set-car! (cddr v) (set-bit (caddr v) 8 a)))
;; whether var is assigned once
(define (vinfo:set-sa! v a) (set-car! (cddr v) (set-bit (caddr v) 16 a)))
;; occurs undef: mask 32
Expand Down
26 changes: 22 additions & 4 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a)
}

// "magic" string and version header of .ji file
static const int JI_FORMAT_VERSION = 7;
static const int JI_FORMAT_VERSION = 8;
static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature
static const uint16_t BOM = 0xFEFF; // byte-order marker
static void write_header(ios_t *s)
Expand Down Expand Up @@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
size_t nsyms = jl_array_len(code->slotnames);
assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions
write_int32(s.s, nsyms);
assert(nsyms == jl_array_len(code->slotflags));
ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms);

// N.B.: The layout of everything before this point is explicitly referenced
// by the various jl_ast_ accessors. Make sure to adjust those if you change
// the data layout.

for (i = 0; i < nsyms; i++) {
jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i);
assert(jl_is_symbol(name));
Expand All @@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 6; i++) {
if (i == 1) // skip codelocs
continue;
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
Expand Down Expand Up @@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
code->pure = !!(flags & (1 << 0));

size_t nslots = read_int32(&src);
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots);

jl_array_t *syms = jl_alloc_vec_any(nslots);
code->slotnames = syms;
for (i = 0; i < nslots; i++) {
Expand All @@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 6; i++) {
if (i == 1)
continue;
assert(jl_field_isptr(jl_code_info_type, i));
Expand Down Expand Up @@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data)
}
}

JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i)
{
assert(i < jl_ast_nslots(data));
if (jl_is_code_info(data))
return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i];
return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i];
}

JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names)
{
size_t i, nargs = jl_array_len(names);
Expand All @@ -2637,7 +2655,7 @@ JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names)
int nslots = jl_load_unaligned_i32(d + 1);
assert(nslots >= nargs);
(void)nslots;
char *namestr = d + 5;
char *namestr = d + 5 + nslots;
for (i = 0; i < nargs; i++) {
size_t namelen = strlen(namestr);
jl_sym_t *name = jl_symbol_n(namestr, namelen);
Expand Down
7 changes: 6 additions & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2641,7 +2641,12 @@
;; where var-info-lst is a list of var-info records
(define (analyze-vars e env captvars sp)
(if (or (atom? e) (quoted? e))
e
(begin
(if (symbol? e)
(let ((vi (var-info-for e env)))
(if vi
(vinfo:set-read! vi #t))))
e)
(case (car e)
((local-def) ;; a local that we know has an assignment that dominates all usages
(let ((vi (var-info-for (cadr e) env)))
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i);
JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names);

JL_DLLEXPORT int jl_is_operator(char *sym);
Expand Down
7 changes: 4 additions & 3 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
li->ssaflags = jl_alloc_array_1d(jl_array_uint8_type, 0);

// Flags that need to be copied to slotflags
const uint8_t vinfo_mask = 16 | 32 | 64;
const uint8_t vinfo_mask = 8 | 16 | 32 | 64;
int i;
for (i = 0; i < nslots; i++) {
jl_value_t *vi = jl_array_ptr_ref(vis, i);
Expand Down Expand Up @@ -383,7 +383,7 @@ STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
{
JL_TIMING(STAGED_FUNCTION);
jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes;
jl_value_t *tt = linfo->specTypes;
jl_method_t *def = linfo->def.method;
jl_value_t *generator = def->generator;
assert(generator != NULL);
Expand All @@ -402,7 +402,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
ptls->world_age = def->min_world;

// invoke code generator
ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt));
jl_tupletype_t *ttdt = (jl_tupletype_t*)jl_unwrap_unionall(tt);
ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
Expand Down
2 changes: 1 addition & 1 deletion stdlib/InteractiveUtils/src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
t = to_tuple_type(t)
tt = signature_type(f, t)
(ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector
meth = Base.func_for_method_checked(meth, ti)
meth = Base.func_for_method_checked(meth, ti, env)
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world)
# get the code for it
return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params)
Expand Down
23 changes: 22 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t))
tt = Tuple{ft, t.parameters...}
precompile(tt)
(ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig)
meth = Base.func_for_method_checked(meth, tt)
meth = Base.func_for_method_checked(meth, tt, env)
return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, tt, env, world)
end
Expand Down Expand Up @@ -2224,3 +2224,24 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any})
f_with_Type_arg(::Type{T}) where {T} = T
@test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type]
@test Base.return_types(f_with_Type_arg, (Type{Vector{T}} where T,)) == Any[Type{Vector{T}} where T]

# Generated functions that only reference some of their arguments
@inline function my_ntuple(f::F, ::Val{N}) where {F,N}
N::Int
(N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N)))
if @generated
quote
@Base.nexprs $N i -> t_i = f(i)
@Base.ncall $N tuple t
end
else
Tuple(f(i) for i = 1:N)
end
end
call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4))
@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}]
@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1
@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val})

@generated unionall_sig_generated(::Vector{T}, b::Vector{S}) where {T, S} = :($b)
@test length(code_typed(unionall_sig_generated, Tuple{Any, Vector{Int}})) == 1