Skip to content

Commit

Permalink
inference: introduce more error/throwness checks for array primitives (
Browse files Browse the repository at this point in the history
…#43587)

Obvious errors are usually caught in higher places like `convert`, but
it's better to have error checks at each builtin level in order to enable
early bail out from errorneous code compilation (when somehow it does not
rely on common abstraction). These checks are also useful for third
consumers like JET.
  • Loading branch information
aviatesk authored Jan 8, 2022
1 parent 95cfbcc commit acb6e16
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 17 deletions.
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields]
# known to be effect-free if the are nothrow
const _PURE_OR_ERROR_BUILTINS = [
fieldtype, apply_type, isa, UnionAll,
getfield, arrayref, const_arrayref, isdefined, Core.sizeof,
getfield, arrayref, const_arrayref, arraysize, isdefined, Core.sizeof,
Core.kwfunc, Core.ifelse, Core._typevar, (<:)
]

Expand Down
75 changes: 59 additions & 16 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,25 @@ end
add_tfunc(Core._typevar, 3, 3, typevar_tfunc, 100)
add_tfunc(applicable, 1, INT_INF, (@nospecialize(f), args...)->Bool, 100)
add_tfunc(Core.Intrinsics.arraylen, 1, 1, @nospecialize(x)->Int, 4)
add_tfunc(arraysize, 2, 2, (@nospecialize(a), @nospecialize(d))->Int, 4)

function arraysize_tfunc(@nospecialize(ary), @nospecialize(dim))
hasintersect(widenconst(ary), Array) || return Bottom
hasintersect(widenconst(dim), Int) || return Bottom
return Int
end
add_tfunc(arraysize, 2, 2, arraysize_tfunc, 4)

function arraysize_nothrow(argtypes::Vector{Any})
length(argtypes) == 2 || return false
ary = argtypes[1]
dim = argtypes[2]
ary Array || return false
if isa(dim, Const)
dimval = dim.val
return isa(dimval, Int) && dimval > 0
end
return false
end

function pointer_eltype(@nospecialize(ptr))
a = widenconst(ptr)
Expand Down Expand Up @@ -1505,8 +1523,38 @@ function tuple_tfunc(argtypes::Vector{Any})
return anyinfo ? PartialStruct(typ, argtypes) : typ
end

function arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize i...)
a = widenconst(a)
arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(ary), @nospecialize idxs...) =
_arrayref_tfunc(boundscheck, ary, idxs)
function _arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(ary),
@nospecialize idxs::Tuple)
isempty(idxs) && return Bottom
array_builtin_common_errorcheck(boundscheck, ary, idxs) || return Bottom
return array_elmtype(ary)
end
add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)

function arrayset_tfunc(@nospecialize(boundscheck), @nospecialize(ary), @nospecialize(item),
@nospecialize idxs...)
hasintersect(widenconst(item), _arrayref_tfunc(boundscheck, ary, idxs)) || return Bottom
return ary
end
add_tfunc(arrayset, 4, INT_INF, arrayset_tfunc, 20)

function array_builtin_common_errorcheck(@nospecialize(boundscheck), @nospecialize(ary),
@nospecialize idxs::Tuple)
hasintersect(widenconst(boundscheck), Bool) || return false
hasintersect(widenconst(ary), Array) || return false
for i = 1:length(idxs)
idx = getfield(idxs, i)
idx = isvarargtype(idx) ? unwrapva(idx) : widenconst(idx)
hasintersect(idx, Int) || return false
end
return true
end

function array_elmtype(@nospecialize ary)
a = widenconst(ary)
if !has_free_typevars(a) && a <: Array
a0 = a
if isa(a, UnionAll)
Expand All @@ -1520,13 +1568,6 @@ function arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecial
end
return Any
end
add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
function arrayset_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)
# TODO: we could check that the type-intersect of arrayref_tfunc and v is non-empty or always throws
return a
end
add_tfunc(arrayset, 4, INT_INF, arrayset_tfunc, 20)

function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
@nospecialize(lb), @nospecialize(ub), @nospecialize(source), env::Vector{Any},
Expand Down Expand Up @@ -1563,16 +1604,16 @@ end

function array_builtin_common_nothrow(argtypes::Vector{Any}, first_idx_idx::Int)
length(argtypes) >= 4 || return false
boundcheck = argtypes[1]
boundscheck = argtypes[1]
arytype = argtypes[2]
array_builtin_common_typecheck(boundcheck, arytype, argtypes, first_idx_idx) || return false
array_builtin_common_typecheck(boundscheck, arytype, argtypes, first_idx_idx) || return false
# If we could potentially throw undef ref errors, bail out now.
arytype = widenconst(arytype)
array_type_undefable(arytype) && return false
# If we have @inbounds (first argument is false), we're allowed to assume
# we don't throw bounds errors.
if isa(boundcheck, Const)
!(boundcheck.val::Bool) && return true
if isa(boundscheck, Const)
!(boundscheck.val::Bool) && return true
end
# Else we can't really say anything here
# TODO: In the future we may be able to track the shapes of arrays though
Expand All @@ -1581,9 +1622,9 @@ function array_builtin_common_nothrow(argtypes::Vector{Any}, first_idx_idx::Int)
end

function array_builtin_common_typecheck(
@nospecialize(boundcheck), @nospecialize(arytype),
@nospecialize(boundscheck), @nospecialize(arytype),
argtypes::Vector{Any}, first_idx_idx::Int)
(boundcheck Bool && arytype Array) || return false
(boundscheck Bool && arytype Array) || return false
for i = first_idx_idx:length(argtypes)
argtypes[i] Int || return false
end
Expand All @@ -1609,6 +1650,8 @@ function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecializ
return arrayset_typecheck(argtypes[2], argtypes[3])
elseif f === arrayref || f === const_arrayref
return array_builtin_common_nothrow(argtypes, 3)
elseif f === arraysize
return arraysize_nothrow(argtypes)
elseif f === Core._expr
length(argtypes) >= 1 || return false
return argtypes[1] Symbol
Expand Down
28 changes: 28 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,34 @@ using Core.Compiler: typeof_tfunc
f_typeof_tfunc(x) = typeof(x)
@test Base.return_types(f_typeof_tfunc, (Union{<:T, Int} where T<:Complex,)) == Any[Union{Type{Int}, Type{Complex{T}} where T<:Real}]

# arrayref / arrayset / arraysize
import Core.Compiler: Const, arrayref_tfunc, arrayset_tfunc, arraysize_tfunc
@test arrayref_tfunc(Const(true), Vector{Int}, Int) === Int
@test arrayref_tfunc(Const(true), Vector{<:Integer}, Int) === Integer
@test arrayref_tfunc(Const(true), Vector, Int) === Any
@test arrayref_tfunc(Const(true), Vector{Int}, Int, Vararg{Int}) === Int
@test arrayref_tfunc(Const(true), Vector{Int}, Vararg{Int}) === Int
@test arrayref_tfunc(Const(true), Vector{Int}) === Union{}
@test arrayref_tfunc(Const(true), String, Int) === Union{}
@test arrayref_tfunc(Const(true), Vector{Int}, Float64) === Union{}
@test arrayref_tfunc(Int, Vector{Int}, Int) === Union{}
@test arrayset_tfunc(Const(true), Vector{Int}, Int, Int) === Vector{Int}
let ua = Vector{<:Integer}
@test arrayset_tfunc(Const(true), ua, Int, Int) === ua
end
@test arrayset_tfunc(Const(true), Vector, Int, Int) === Vector
@test arrayset_tfunc(Const(true), Any, Int, Int) === Any
@test arrayset_tfunc(Const(true), Vector{String}, String, Int, Vararg{Int}) === Vector{String}
@test arrayset_tfunc(Const(true), Vector{String}, String, Vararg{Int}) === Vector{String}
@test arrayset_tfunc(Const(true), Vector{String}, String) === Union{}
@test arrayset_tfunc(Const(true), String, Char, Int) === Union{}
@test arrayset_tfunc(Const(true), Vector{Int}, Int, Float64) === Union{}
@test arrayset_tfunc(Int, Vector{Int}, Int, Int) === Union{}
@test arrayset_tfunc(Const(true), Vector{Int}, Float64, Int) === Union{}
@test arraysize_tfunc(Vector, Int) === Int
@test arraysize_tfunc(Vector, Float64) === Union{}
@test arraysize_tfunc(String, Int) === Union{}

function f23024(::Type{T}, ::Int) where T
1 + 1
end
Expand Down

0 comments on commit acb6e16

Please sign in to comment.