Skip to content

Commit

Permalink
inference: add error/throwness checks for setfield!
Browse files Browse the repository at this point in the history
In a similar spirit to #43587, this commit introduces error check for
`setfield!`. We can bail out from inference if we can prove either of:
- the object is not mutable type
- the object is `Module` object
- the value being assigned is incompatible with the declared type of
  object field

This commit also adds the throwness check for `setfield!` (i.e. `setfield!_nothrow`).
This throwness check won't be used in the current native compilation
pipeline since `setfield!` call can't be eliminated even if we can prove
that it never throws. But this throwness check would be used by
EscapeAnalysis.jl integration and so I'd like to include it in Base.
  • Loading branch information
aviatesk committed Jan 4, 2022
1 parent 754ce5d commit ee72a2b
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 11 deletions.
83 changes: 73 additions & 10 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,8 @@ end

getfield_tfunc(s00, name, boundscheck_or_order) = (@nospecialize; getfield_tfunc(s00, name))
getfield_tfunc(s00, name, order, boundscheck) = (@nospecialize; getfield_tfunc(s00, name))
function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(s00, name, false)
function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool)
s = unwrap_unionall(s00)
if isa(s, Union)
return tmerge(getfield_tfunc(rewrap_unionall(s.a, s00), name),
Expand All @@ -774,6 +775,7 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
if isa(name, Const)
nv = name.val
if isa(sv, Module)
setfield && return Bottom
if isa(nv, Symbol)
return abstract_eval_global(sv, nv)
end
Expand Down Expand Up @@ -817,9 +819,8 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
return Bottom
end
if s <: Module
if !(Symbol <: widenconst(name))
return Bottom
end
setfield && return Bottom
hasintersect(widenconst(name), Symbol) || return Bottom
return Any
end
if s.name === _NAMEDTUPLE_NAME && !isconcretetype(s)
Expand All @@ -840,9 +841,10 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
return getfield_tfunc(_ts, name)
end
ftypes = datatype_fieldtypes(s)
nf = length(ftypes)
# If no value has this type, then this statement should be unreachable.
# Bail quickly now.
if !has_concrete_subtype(s) || isempty(ftypes)
if !has_concrete_subtype(s) || nf == 0
return Bottom
end
if isa(name, Conditional)
Expand All @@ -853,12 +855,14 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
if !(Int <: name || Symbol <: name)
return Bottom
end
if length(ftypes) == 1
if nf == 1
return rewrap_unionall(unwrapva(ftypes[1]), s00)
end
# union together types of all fields
t = Bottom
for _ft in ftypes
for i in 1:nf
_ft = ftypes[i]
setfield && isconst(s, i) && continue
t = tmerge(t, rewrap_unionall(unwrapva(_ft), s00))
t === Any && break
end
Expand All @@ -871,12 +875,13 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
if !isa(fld, Int)
return Bottom
end
nf = length(ftypes)
if s <: Tuple && fld >= nf && isvarargtype(ftypes[nf])
return rewrap_unionall(unwrapva(ftypes[nf]), s00)
end
if fld < 1 || fld > nf
return Bottom
elseif setfield && isconst(s, fld)
return Bottom
end
R = ftypes[fld]
if isempty(s.parameters)
Expand All @@ -885,8 +890,66 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
return rewrap_unionall(R, s00)
end

setfield!_tfunc(o, f, v, order) = (@nospecialize; v)
setfield!_tfunc(o, f, v) = (@nospecialize; v)
function setfield!_tfunc(o, f, v, order)
@nospecialize
if !isvarargtype(order)
hasintersect(widenconst(order), Symbol) || return Bottom
end
return setfield!_tfunc(o, f, v)
end
function setfield!_tfunc(o, f, v)
@nospecialize
mutability_errorcheck(o) || return Bottom
ft = _getfield_tfunc(o, f, true)
ft === Bottom && return Bottom
hasintersect(widenconst(v), widenconst(ft)) || return Bottom
return v
end
function mutability_errorcheck(@nospecialize obj)
objt0 = widenconst(obj)
objt = unwrap_unionall(objt0)
if isa(objt, Union)
return mutability_errorcheck(rewrap_unionall(objt.a, objt0)) ||
mutability_errorcheck(rewrap_unionall(objt.b, objt0))
elseif isa(objt, DataType)
# Can't say anything about abstract types
isabstracttype(objt) && return true
return ismutabletype(objt)
end
return true
end

function setfield!_nothrow(argtypes::Vector{Any})
if length(argtypes) == 4
order = argtypes[4]
order === Const(:non_atomic) || return false # TODO: this is assuming not atomic
else
length(argtypes) == 3 || return false
end
return setfield!_nothrow(argtypes[1], argtypes[2], argtypes[3])
end
function setfield!_nothrow(s00, name, v)
@nospecialize
s0 = widenconst(s00)
s = unwrap_unionall(s0)
if isa(s, Union)
return setfield!_nothrow(rewrap_unionall(s.a, s00), name, v) &&
setfield!_nothrow(rewrap_unionall(s.b, s00), name, v)
elseif isa(s, DataType)
# Can't say anything about abstract types
isabstracttype(s) && return false
ismutabletype(s) || return false
s.name.atomicfields == C_NULL || return false # TODO: currently we're only testing for ordering == :not_atomic
isa(name, Const) || return false
field = try_compute_fieldidx(s, name.val)
field === nothing && return false
# `try_compute_fieldidx` already check for field index bound.
isconst(s, field) && return false
v_expected = fieldtype(s0, field)
return v v_expected
end
return false
end

swapfield!_tfunc(o, f, v, order) = (@nospecialize; getfield_tfunc(o, f))
swapfield!_tfunc(o, f, v) = (@nospecialize; getfield_tfunc(o, f))
Expand Down
108 changes: 107 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,112 @@ import Core.Compiler.getfield_tfunc
@test getfield_tfunc(NamedTuple{<:Any, T} where {T <: Tuple{Int, Union{Float64, Missing}}},
Const(:x)) == Union{Missing, Float64, Int}

import Core.Compiler: setfield!_tfunc, setfield!_nothrow, Const
mutable struct XY{X,Y}
x::X
y::Y
end
mutable struct ABCDconst
const a
const b::Int
c
const d::Union{Int,Nothing}
end
@test setfield!_tfunc(Base.RefValue{Int}, Const(:x), Int) === Int
@test setfield!_tfunc(Base.RefValue{Int}, Const(:x), Int, Symbol) === Int
@test setfield!_tfunc(Base.RefValue{Int}, Const(1), Int) === Int
@test setfield!_tfunc(Base.RefValue{Int}, Const(1), Int, Symbol) === Int
@test setfield!_tfunc(Base.RefValue{Int}, Int, Int) === Int
@test setfield!_tfunc(Base.RefValue{Any}, Const(:x), Int) === Int
@test setfield!_tfunc(Base.RefValue{Any}, Const(:x), Int, Symbol) === Int
@test setfield!_tfunc(Base.RefValue{Any}, Const(1), Int) === Int
@test setfield!_tfunc(Base.RefValue{Any}, Const(1), Int, Symbol) === Int
@test setfield!_tfunc(Base.RefValue{Any}, Int, Int) === Int
@test setfield!_tfunc(XY{Any,Any}, Const(1), Int) === Int
@test setfield!_tfunc(XY{Any,Any}, Const(2), Float64) === Float64
@test setfield!_tfunc(XY{Int,Float64}, Const(1), Int) === Int
@test setfield!_tfunc(XY{Int,Float64}, Const(2), Float64) === Float64
@test setfield!_tfunc(ABCDconst, Const(:c), Any) === Any
@test setfield!_tfunc(ABCDconst, Const(3), Any) === Any
@test setfield!_tfunc(ABCDconst, Symbol, Any) === Any
@test setfield!_tfunc(ABCDconst, Int, Any) === Any
@test setfield!_tfunc(Union{Base.RefValue{Any},Some{Any}}, Const(:x), Int) === Int
@test setfield!_tfunc(Union{Base.RefValue,Some{Any}}, Const(:x), Int) === Int
@test setfield!_tfunc(Union{Base.RefValue{Any},Some{Any}}, Const(1), Int) === Int
@test setfield!_tfunc(Union{Base.RefValue,Some{Any}}, Const(1), Int) === Int
@test setfield!_tfunc(Union{Base.RefValue{Any},Some{Any}}, Symbol, Int) === Int
@test setfield!_tfunc(Union{Base.RefValue,Some{Any}}, Symbol, Int) === Int
@test setfield!_tfunc(Union{Base.RefValue{Any},Some{Any}}, Int, Int) === Int
@test setfield!_tfunc(Union{Base.RefValue,Some{Any}}, Int, Int) === Int
@test setfield!_tfunc(Any, Symbol, Int) === Int
@test setfield!_tfunc(Any, Int, Int) === Int
@test setfield!_tfunc(Any, Any, Int) === Int
@test setfield!_tfunc(Base.RefValue{Int}, Const(:x), Float64) === Union{}
@test setfield!_tfunc(Base.RefValue{Int}, Const(:x), Float64, Symbol) === Union{}
@test setfield!_tfunc(Base.RefValue{Int}, Const(1), Float64) === Union{}
@test setfield!_tfunc(Base.RefValue{Int}, Const(1), Float64, Symbol) === Union{}
@test setfield!_tfunc(Base.RefValue{Int}, Int, Float64) === Union{}
@test setfield!_tfunc(Base.RefValue{Any}, Const(:y), Int) === Union{}
@test setfield!_tfunc(Base.RefValue{Any}, Const(:y), Int, Bool) === Union{}
@test setfield!_tfunc(Base.RefValue{Any}, Const(2), Int) === Union{}
@test setfield!_tfunc(Base.RefValue{Any}, Const(2), Int, Bool) === Union{}
@test setfield!_tfunc(Base.RefValue{Any}, String, Int) === Union{}
@test setfield!_tfunc(Some{Any}, Const(:value), Int) === Union{}
@test setfield!_tfunc(Some, Const(:value), Int) === Union{}
@test setfield!_tfunc(Some{Any}, Const(1), Int) === Union{}
@test setfield!_tfunc(Some, Const(1), Int) === Union{}
@test setfield!_tfunc(Some{Any}, Symbol, Int) === Union{}
@test setfield!_tfunc(Some, Symbol, Int) === Union{}
@test setfield!_tfunc(Some{Any}, Int, Int) === Union{}
@test setfield!_tfunc(Some, Int, Int) === Union{}
@test setfield!_tfunc(Const(@__MODULE__), Const(:v), Int) === Union{}
@test setfield!_tfunc(Const(@__MODULE__), Int, Int) === Union{}
@test setfield!_tfunc(Module, Const(:v), Int) === Union{}
@test setfield!_tfunc(ABCDconst, Const(:a), Any) === Union{}
@test setfield!_tfunc(ABCDconst, Const(:b), Any) === Union{}
@test setfield!_tfunc(ABCDconst, Const(:d), Any) === Union{}
@test setfield!_tfunc(ABCDconst, Const(1), Any) === Union{}
@test setfield!_tfunc(ABCDconst, Const(2), Any) === Union{}
@test setfield!_tfunc(ABCDconst, Const(4), Any) === Union{}
@test setfield!_nothrow(Base.RefValue{Int}, Const(:x), Int)
@test setfield!_nothrow(Base.RefValue{Int}, Const(1), Int)
@test setfield!_nothrow(Base.RefValue{Any}, Const(:x), Int)
@test setfield!_nothrow(Base.RefValue{Any}, Const(1), Int)
@test setfield!_nothrow(XY{Any,Any}, Const(:x), Int)
@test setfield!_nothrow(XY{Any,Any}, Const(:x), Any)
@test setfield!_nothrow(XY{Int,Float64}, Const(:x), Int)
@test setfield!_nothrow(ABCDconst, Const(:c), Any)
@test setfield!_nothrow(ABCDconst, Const(3), Any)
@test !setfield!_nothrow(XY{Int,Float64}, Symbol, Any)
@test !setfield!_nothrow(XY{Int,Float64}, Int, Any)
@test !setfield!_nothrow(Base.RefValue{Int}, Const(:x), Any)
@test !setfield!_nothrow(Base.RefValue{Int}, Const(1), Any)
@test !setfield!_nothrow(Any[Base.RefValue{Any}, Const(:x), Int, Symbol])
@test !setfield!_nothrow(Base.RefValue{Any}, Symbol, Int)
@test !setfield!_nothrow(Base.RefValue{Any}, Int, Int)
@test !setfield!_nothrow(XY{Int,Float64}, Const(:y), Int)
@test !setfield!_nothrow(XY{Int,Float64}, Symbol, Int)
@test !setfield!_nothrow(XY{Int,Float64}, Int, Int)
@test !setfield!_nothrow(ABCDconst, Const(:a), Any)
@test !setfield!_nothrow(ABCDconst, Const(:b), Any)
@test !setfield!_nothrow(ABCDconst, Const(:d), Any)
@test !setfield!_nothrow(ABCDconst, Symbol, Any)
@test !setfield!_nothrow(ABCDconst, Const(1), Any)
@test !setfield!_nothrow(ABCDconst, Const(2), Any)
@test !setfield!_nothrow(ABCDconst, Const(4), Any)
@test !setfield!_nothrow(ABCDconst, Int, Any)
@test !setfield!_nothrow(Union{Base.RefValue{Any},Some{Any}}, Const(:x), Int)
@test !setfield!_nothrow(Union{Base.RefValue,Some{Any}}, Const(:x), Int)
@test !setfield!_nothrow(Union{Base.RefValue{Any},Some{Any}}, Const(1), Int)
@test !setfield!_nothrow(Union{Base.RefValue,Some{Any}}, Const(1), Int)
@test !setfield!_nothrow(Union{Base.RefValue{Any},Some{Any}}, Symbol, Int)
@test !setfield!_nothrow(Union{Base.RefValue,Some{Any}}, Symbol, Int)
@test !setfield!_nothrow(Union{Base.RefValue{Any},Some{Any}}, Int, Int)
@test !setfield!_nothrow(Union{Base.RefValue,Some{Any}}, Int, Int)
@test !setfield!_nothrow(Any, Symbol, Int)
@test !setfield!_nothrow(Any, Int, Int)
@test !setfield!_nothrow(Any, Any, Int)

struct Foo_22708
x::Ptr{Foo_22708}
end
Expand Down Expand Up @@ -3164,7 +3270,7 @@ end
@test Core.Compiler.return_type(apply26826, Tuple{typeof(===), Any, Vararg}) == Bool
@test Core.Compiler.return_type(apply26826, Tuple{typeof(===), Any, Any, Vararg}) == Bool
@test Core.Compiler.return_type(apply26826, Tuple{typeof(===), Any, Any, Any, Vararg}) == Union{}
@test Core.Compiler.return_type(apply26826, Tuple{typeof(setfield!), Vararg{Symbol}}) == Symbol
@test Core.Compiler.return_type(apply26826, Tuple{typeof(setfield!), Vararg{Symbol}}) == Union{}
@test Core.Compiler.return_type(apply26826, Tuple{typeof(setfield!), Any, Vararg{Symbol}}) == Symbol
@test Core.Compiler.return_type(apply26826, Tuple{typeof(setfield!), Any, Symbol, Vararg{Integer}}) == Integer
@test Core.Compiler.return_type(apply26826, Tuple{typeof(setfield!), Any, Symbol, Integer, Vararg}) == Integer
Expand Down

0 comments on commit ee72a2b

Please sign in to comment.