Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

small optimization to subtyping #41672

Merged
merged 1 commit into from
Aug 19, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 65 additions & 36 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ extern "C" {
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed
typedef struct {
int depth;
int more;
int16_t depth;
int16_t more;
int16_t used;
uint32_t stack[100]; // stack of bits represented as a bit vector
} jl_unionstate_t;

typedef struct {
int16_t depth;
int16_t more;
int16_t used;
void *stack;
} jl_saved_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
// is pushed for each UnionAll type we encounter. `lb` and `ub` are updated
// during the computation.
Expand All @@ -68,14 +76,14 @@ typedef struct jl_varbinding_t {
// and we would need to return `intersect(var,other)`. in this case
// we choose to over-estimate the intersection by returning the var.
int8_t constraintkind;
int depth0; // # of invariant constructors nested around the UnionAll type for this var
int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// when this variable's integer value is compared to that of another,
// it equals `other + offset`. used by vararg length parameters.
int offset;
int16_t offset;
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
jl_array_t *innervars;
int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
struct jl_varbinding_t *prev;
} jl_varbinding_t;

Expand Down Expand Up @@ -129,6 +137,23 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
st->stack[i>>5] &= ~(1u<<(i&31));
}

#define push_unionstate(saved, src) \
do { \
(saved)->depth = (src)->depth; \
(saved)->more = (src)->more; \
(saved)->used = (src)->used; \
(saved)->stack = alloca(((src)->used+7)/8); \
memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \
} while (0);

#define pop_unionstate(dst, saved) \
do { \
(dst)->depth = (saved)->depth; \
(dst)->more = (saved)->more; \
(dst)->used = (saved)->used; \
memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \
} while (0);

typedef struct {
int8_t *buf;
int rdepth;
Expand Down Expand Up @@ -486,6 +511,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
Expand Down Expand Up @@ -514,20 +543,19 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
return 1;
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_unionstate_t oldLunions = e->Lunions;
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

e->Runions = oldRunions;
e->Lunions = oldLunions;
pop_unionstate(&e->Runions, &oldRunions);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}

Expand Down Expand Up @@ -731,8 +759,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -1148,6 +1176,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
Expand Down Expand Up @@ -1310,21 +1342,21 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

jl_unionstate_t oldLunions = e->Lunions;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
jl_unionstate_t oldRunions = e->Runions;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
}
else {
int lastset = 0;
Expand All @@ -1342,13 +1374,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
}

e->Lunions = oldLunions;
pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
Expand Down Expand Up @@ -1379,7 +1411,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
JL_GC_PUSH1(&saved);
save_env(e, &saved, &se);

memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
e->Lunions.used = 0;
int lastset = 0;
int sub;
while (1) {
Expand Down Expand Up @@ -1415,6 +1447,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->emptiness_only = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
}

// subtyping entry points
Expand Down Expand Up @@ -2084,14 +2117,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e,
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;

jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

jl_value_t *res = intersect_all(x, y, e);

e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
return res;
Expand All @@ -2102,10 +2135,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
jl_value_t *a=NULL, *b=NULL;
JL_GC_PUSH2(&a, &b);
jl_unionstate_t oldRunions = e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
e->Runions = oldRunions;
pop_unionstate(&e->Runions, &oldRunions);
jl_value_t *i = simple_join(a,b);
JL_GC_POP();
return i;
Expand Down Expand Up @@ -2600,8 +2633,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -3159,7 +3192,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
e->Runions.more = 0;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
jl_value_t **is;
JL_GC_PUSHARGS(is, 3);
jl_value_t **saved = &is[2];
Expand All @@ -3176,11 +3209,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
save_env(e, saved, &se);
}
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type) {
free_env(&se);
JL_GC_POP();
return ii;
}
if (e->emptiness_only && ii != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
Expand Down Expand Up @@ -3209,9 +3239,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
total_iter++;
if (niter > 3 || total_iter > 400000) {
free_env(&se);
JL_GC_POP();
return y;
ii = y;
break;
}
}
free_env(&se);
Expand Down