Skip to content

Commit

Permalink
fix #32386, type intersection bug in var bounds (#32425)
Browse files Browse the repository at this point in the history
might be related to #24333 and/or #21153
  • Loading branch information
JeffBezanson authored Jun 27, 2019
1 parent 8ff014f commit 0d9c72d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
78 changes: 52 additions & 26 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ typedef struct jl_stenv_t {
jl_value_t **envout; // for passing caller the computed bounds of right-side variables
int envsz; // length of envout
int envidx; // current index in envout
int invdepth; // current number of invariant constructors we're nested in
int invdepth; // # of invariant constructors we're nested in on the left
int Rinvdepth; // # of invariant constructors we're nested in on the right
int ignore_free; // treat free vars as black boxes; used during intersection
int intersection; // true iff subtype is being called from intersection
int emptiness_only; // true iff intersection only needs to test for emptiness
Expand Down Expand Up @@ -535,7 +536,7 @@ static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param)
{
if (vb != NULL && param) {
// saturate counters at 2; we don't need values bigger than that
if (param == 2 && e->invdepth > vb->depth0 && vb->occurs_inv < 2)
if (param == 2 && (vb->right ? e->Rinvdepth : e->invdepth) > vb->depth0 && vb->occurs_inv < 2)
vb->occurs_inv++;
else if (vb->occurs_cov < 2)
vb->occurs_cov++;
Expand All @@ -554,7 +555,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y)
return 0;
}

static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d);

// check that type var `b` is <: `a`, and update b's upper bound.
static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
Expand All @@ -572,7 +573,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
// for this to work we need to compute issub(left,right) before issub(right,left),
// since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous.
if (e->intersection) {
jl_value_t *ub = intersect_aside(bb->ub, a, e, bb->depth0);
jl_value_t *ub = intersect_aside(bb->ub, a, e, 0, bb->depth0);
if (ub != (jl_value_t*)b)
bb->ub = ub;
}
Expand Down Expand Up @@ -687,7 +688,8 @@ typedef int (*tvar_callback)(void*, int8_t, jl_stenv_t *, int);

static int with_tvar(tvar_callback callback, void *context, jl_unionall_t *u, int8_t R, jl_stenv_t *e, int param)
{
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -854,8 +856,10 @@ static int check_vararg_length(jl_value_t *v, ssize_t n, jl_stenv_t *e)
jl_value_t *nn = jl_box_long(n);
JL_GC_PUSH1(&nn);
e->invdepth++;
e->Rinvdepth++;
int ans = subtype(nn, N, e, 2) && subtype(N, nn, e, 0);
e->invdepth--;
e->Rinvdepth--;
JL_GC_POP();
if (!ans)
return 0;
Expand Down Expand Up @@ -955,6 +959,7 @@ static int subtype_tuple_varargs(struct subtype_tuple_env *env, jl_stenv_t *e, i

// Vararg{T,N} <: Vararg{T2,N2}; equate N and N2
e->invdepth++;
e->Rinvdepth++;
JL_GC_PUSH2(&xp1, &yp1);
if (jl_is_long(xp1) && env->vx != 1)
xp1 = jl_box_long(jl_unbox_long(xp1) - env->vx + 1);
Expand All @@ -963,6 +968,7 @@ static int subtype_tuple_varargs(struct subtype_tuple_env *env, jl_stenv_t *e, i
int ans = forall_exists_equal(xp1, yp1, e);
JL_GC_POP();
e->invdepth--;
e->Rinvdepth--;
return ans;
}

Expand Down Expand Up @@ -1146,9 +1152,11 @@ static int subtype_naked_vararg(jl_datatype_t *xd, jl_datatype_t *yd, jl_stenv_t
if (!subtype(xp1, yp1, e, 1)) return 0;
if (!subtype(xp1, yp1, e, 1)) return 0;
e->invdepth++;
e->Rinvdepth++;
// Vararg{T,N} <: Vararg{T2,N2}; equate N and N2
int ans = forall_exists_equal(xp2, yp2, e);
e->invdepth--;
e->Rinvdepth--;
return ans;
}

Expand Down Expand Up @@ -1287,13 +1295,15 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
size_t i, np = jl_nparams(xd);
int ans = 1;
e->invdepth++;
e->Rinvdepth++;
for (i=0; i < np; i++) {
jl_value_t *xi = jl_tparam(xd, i), *yi = jl_tparam(yd, i);
if (!(xi == yi || forall_exists_equal(xi, yi, e))) {
ans = 0; break;
}
}
e->invdepth--;
e->Rinvdepth--;
return ans;
}
if (jl_is_type(y))
Expand Down Expand Up @@ -1404,7 +1414,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
if (envsz)
memset(env, 0, envsz*sizeof(void*));
e->envidx = 0;
e->invdepth = 0;
e->invdepth = e->Rinvdepth = 0;
e->ignore_free = 0;
e->intersection = 0;
e->emptiness_only = 0;
Expand Down Expand Up @@ -1750,20 +1760,31 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
return subtype;
}

static int subtype_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
static int subtype_in_env_(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int invdepth, int Rinvdepth)
{
jl_stenv_t e2;
init_stenv(&e2, NULL, 0);
e2.vars = e->vars;
e2.intersection = e->intersection;
e2.ignore_free = e->ignore_free;
e2.invdepth = e->invdepth;
e2.invdepth = invdepth;
e2.Rinvdepth = Rinvdepth;
e2.envsz = e->envsz;
e2.envout = e->envout;
e2.envidx = e->envidx;
return forall_exists_subtype(x, y, &e2, 0);
}

static int subtype_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
return subtype_in_env_(x, y, e, e->invdepth, e->Rinvdepth);
}

static int subtype_bounds_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
{
return subtype_in_env_(x, y, e, R ? e->invdepth : d, R ? d : e->Rinvdepth);
}

JL_DLLEXPORT int jl_subtype(jl_value_t *x, jl_value_t *y)
{
return jl_subtype_env(x, y, NULL, 0);
Expand Down Expand Up @@ -1978,22 +1999,24 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);

// intersect in nested union environment, similar to subtype_ccheck
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
{
// band-aid for #30335
if (x == (jl_value_t*)jl_any_type && !jl_is_typevar(y))
return y;
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;

int savedepth = e->invdepth;
jl_unionstate_t oldRunions = e->Runions;
e->invdepth = depth;
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

jl_value_t *res = intersect_all(x, y, e);

e->Runions = oldRunions;
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
return res;
}

Expand Down Expand Up @@ -2037,12 +2060,12 @@ static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_
return v;
}

static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e)
static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e, int R, int d)
{
jl_value_t *root=NULL; jl_savedenv_t se;
JL_GC_PUSH1(&root);
save_env(e, &root, &se);
int ret = subtype_in_env(a, b, e);
int ret = subtype_bounds_in_env(a, b, e, R, d);
restore_env(e, root, &se);
free(se.buf);
JL_GC_POP();
Expand All @@ -2064,7 +2087,7 @@ static void set_bound(jl_value_t **bound, jl_value_t *val, jl_tvar_t *v, jl_sten
}

// subtype, treating all vars as existential
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
{
jl_varbinding_t *v = e->vars;
int len = 0;
Expand All @@ -2083,7 +2106,7 @@ static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *
v->right = 1;
v = v->prev;
}
int issub = subtype_in_env(x, y, e);
int issub = subtype_bounds_in_env(x, y, e, R, d);
n = 0; v = e->vars;
while (n < len) {
assert(v != NULL);
Expand All @@ -2098,15 +2121,15 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
{
jl_varbinding_t *bb = lookup(e, b);
if (bb == NULL)
return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0);
return R ? intersect_aside(a, b->ub, e, 1, 0) : intersect_aside(b->ub, a, e, 0, 0);
if (bb->lb == bb->ub && jl_is_typevar(bb->lb) && bb->lb != (jl_value_t*)b)
return intersect(a, bb->lb, e, param);
if (!jl_is_type(a) && !jl_is_typevar(a))
return set_var_to_const(bb, a, NULL);
int d = bb->depth0;
jl_value_t *root=NULL; jl_savedenv_t se;
if (param == 2) {
if (!(subtype_in_env_existential(bb->lb, a, e) && subtype_in_env_existential(a, bb->ub, e)))
if (!(subtype_in_env_existential(bb->lb, a, e, 0, d) && subtype_in_env_existential(a, bb->ub, e, 1, d)))
return jl_bottom_type;
jl_value_t *ub = a;
if (ub != (jl_value_t*)b) {
Expand All @@ -2131,17 +2154,17 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
else if (bb->constraintkind == 0) {
if (!jl_is_typevar(bb->ub) && !jl_is_typevar(a)) {
if (try_subtype_in_env(bb->ub, a, e))
if (try_subtype_in_env(bb->ub, a, e, 0, d))
return (jl_value_t*)b;
}
return R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
return R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
}
else if (bb->concrete || bb->constraintkind == 1) {
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
JL_GC_PUSH1(&ub);
if (ub == jl_bottom_type ||
// this fixes issue #30122. TODO: better fix for R flag.
(!R && !subtype_in_env(bb->lb, a, e))) {
(!R && !subtype_bounds_in_env(bb->lb, a, e, 0, d))) {
JL_GC_POP();
return jl_bottom_type;
}
Expand All @@ -2152,7 +2175,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
else if (bb->constraintkind == 2) {
// TODO: removing this case fixes many test_brokens in test/subtype.jl
// but breaks other tests.
if (!subtype_in_env(a, bb->ub, e)) {
if (!subtype_bounds_in_env(a, bb->ub, e, 1, d)) {
// mark var as unsatisfiable by making it circular
bb->lb = (jl_value_t*)b;
return jl_bottom_type;
Expand All @@ -2162,7 +2185,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
return a;
}
assert(bb->constraintkind == 3);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (jl_is_typevar(a))
Expand All @@ -2180,7 +2203,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
root = NULL;
JL_GC_PUSH2(&root, &ub);
save_env(e, &root, &se);
jl_value_t *ii = R ? intersect_aside(a, bb->lb, e, d) : intersect_aside(bb->lb, a, e, d);
jl_value_t *ii = R ? intersect_aside(a, bb->lb, e, 1, d) : intersect_aside(bb->lb, a, e, 0, d);
if (ii == jl_bottom_type) {
restore_env(e, root, &se);
ii = (jl_value_t*)b;
Expand Down Expand Up @@ -2414,7 +2437,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0, e->invdepth, 0, NULL, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, NULL, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -2619,8 +2643,10 @@ static jl_value_t *intersect_invariant(jl_value_t *x, jl_value_t *y, jl_stenv_t
return (jl_subtype(x,y) && jl_subtype(y,x)) ? y : NULL;
}
e->invdepth++;
e->Rinvdepth++;
jl_value_t *ii = intersect(x, y, e, 2);
e->invdepth--;
e->Rinvdepth--;
if (jl_is_typevar(x) && jl_is_typevar(y) && (jl_is_typevar(ii) || !jl_is_type(ii)))
return ii;
if (ii == jl_bottom_type) {
Expand Down Expand Up @@ -2748,7 +2774,7 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
}
jl_value_t *ub=NULL, *lb=NULL;
JL_GC_PUSH2(&lb, &ub);
ub = intersect_aside(xub, yub, e, xx ? xx->depth0 : 0);
ub = intersect_aside(xub, yub, e, 0, xx ? xx->depth0 : 0);
lb = simple_join(xlb, ylb);
if (yy) {
if (lb != y)
Expand Down
5 changes: 5 additions & 0 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1606,3 +1606,8 @@ end
# S = Type{T} where T<:Tuple{E, Vararg{E}} where E
# @test @elapsed (@test T != S) < 5
#end

# issue #32386
# TODO: intersect currently returns a bad answer here (it has free typevars)
@test typeintersect(Type{S} where S<:(Array{Pair{_A,N} where N, 1} where _A),
Type{Vector{T}} where T) != Union{}

0 comments on commit 0d9c72d

Please sign in to comment.