From 77e816eaae89366c5a062c189c9fe151be872c24 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 12 Nov 2014 21:26:29 -0500 Subject: [PATCH] Improve type inference for tuples in static parameters Tuples in static parameters were previously not inferred at all (not even in their own constructors). This came up in Cxx.jl which has a lot of tuples in type parameters. --- base/inference.jl | 74 ++++++++++++++++++++++++++++++++++------------- test/core.jl | 7 +++++ 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/base/inference.jl b/base/inference.jl index 042679c49f955..1029a81d57ee3 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -400,7 +400,39 @@ end t_func[fieldtype] = (2, 2, fieldtype_tfunc) t_func[Box] = (1, 1, (a,)->Box) -valid_tparam(x::ANY) = isa(x,Int) || isa(x,Symbol) || isa(x,Bool) +function valid_tparam(x::ANY) + if isa(x,Int) || isa(x,Symbol) || isa(x,Bool) + return true + elseif isa(x,Tuple) + for t in x + if !valid_tparam(t) + return false + end + end + return true + end + return false +end + +function extract_simple_tparam(Ai) + if (isa(Ai,Int) || isa(Ai,Bool)) + return Ai + elseif isa(Ai,QuoteNode) && valid_tparam(Ai.value) + return Ai.value + elseif isa(inference_stack,CallStack) && isa(Ai,Expr) && + is_known_call(Ai,tuple,inference_stack.sv) + tup = () + for arg in Ai.args[2:end] + val = extract_simple_tparam(arg) + if val === Bottom + return val + end + tup = tuple(tup...,val) + end + return tup + end + return Bottom +end # TODO: handle e.g. apply_type(T, R::Union(Type{Int32},Type{Float64})) const apply_type_tfunc = function (A, args...) @@ -421,28 +453,30 @@ const apply_type_tfunc = function (A, args...) tparams = tuple(tparams..., ai.parameters[1]) elseif isa(ai,Tuple) && all(isType,ai) tparams = tuple(tparams..., map(t->t.parameters[1], ai)) - elseif i<=lA && (isa(A[i],Int) || isa(A[i],Bool)) - tparams = tuple(tparams..., A[i]) - elseif i<=lA && isa(A[i],QuoteNode) && valid_tparam(A[i].value) - tparams = tuple(tparams..., A[i].value) else - if i<=lA && isa(A[i],Symbol) && isa(inference_stack,CallStack) - sp = inference_stack.sv.sp - s = A[i] - found = false - for j=1:2:length(sp) - if is(sp[j].name,s) - # static parameter - val = sp[j+1] - if valid_tparam(val) - tparams = tuple(tparams..., val) - found = true - break + if i<=lA + val = extract_simple_tparam(A[i]) + if val !== Bottom + tparams = tuple(tparams..., val) + continue + elseif isa(inference_stack,CallStack) && isa(A[i],Symbol) + sp = inference_stack.sv.sp + s = A[i] + found = false + for j=1:2:length(sp) + if is(sp[j].name,s) + # static parameter + val = sp[j+1] + if valid_tparam(val) + tparams = tuple(tparams..., val) + found = true + break + end end end - end - if found - continue + if found + continue + end end end if i-1 > length(headtype.parameters) diff --git a/test/core.jl b/test/core.jl index 13c4776c3c510..46c3f537038f5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -2123,3 +2123,10 @@ function f9947() end end @test f9947() == UInt128(1) + +# Type inference for tuple parameters +immutable fooTuple{s}; end +barTuple1() = fooTuple{(:y,)}() +barTuple2() = fooTuple{tuple(:y)}() + +@test Base.return_types(barTuple1,())[1] == Base.return_types(barTuple2,())[1] == fooTuple{(:y,)}