Skip to content

Commit

Permalink
small optimization to subtyping
Browse files Browse the repository at this point in the history
Zero and copy only the used portion of the union state buffer.
  • Loading branch information
JeffBezanson committed Jul 21, 2021
1 parent 55a873e commit 5642eae
Showing 1 changed file with 46 additions and 29 deletions.
75 changes: 46 additions & 29 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ extern "C" {
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed
typedef struct {
int depth;
int more;
uint32_t stack[100]; // stack of bits represented as a bit vector
int16_t depth;
int16_t more;
int16_t used;
uint32_t stack[80]; // stack of bits represented as a bit vector
} jl_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
Expand All @@ -68,14 +69,14 @@ typedef struct jl_varbinding_t {
// and we would need to return `intersect(var,other)`. in this case
// we choose to over-estimate the intersection by returning the var.
int8_t constraintkind;
int depth0; // # of invariant constructors nested around the UnionAll type for this var
int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// when this variable's integer value is compared to that of another,
// it equals `other + offset`. used by vararg length parameters.
int offset;
int16_t offset;
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
jl_array_t *innervars;
int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
struct jl_varbinding_t *prev;
} jl_varbinding_t;

Expand Down Expand Up @@ -129,6 +130,14 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
st->stack[i>>5] &= ~(1u<<(i&31));
}

static void copy_unionstate(jl_unionstate_t *dst, jl_unionstate_t *src) JL_NOTSAFEPOINT
{
dst->depth = src->depth;
dst->more = src->more;
dst->used = src->used;
memcpy(&dst->stack, &src->stack, (src->used+7)/8);
}

typedef struct {
int8_t *buf;
int rdepth;
Expand Down Expand Up @@ -486,6 +495,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
Expand Down Expand Up @@ -514,20 +527,19 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
return 1;
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_unionstate_t oldLunions = e->Lunions;
jl_unionstate_t oldRunions = e->Runions;
jl_unionstate_t oldLunions; copy_unionstate(&oldLunions, &e->Lunions);
jl_unionstate_t oldRunions; copy_unionstate(&oldRunions, &e->Runions);
int sub;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

e->Runions = oldRunions;
e->Lunions = oldLunions;
copy_unionstate(&e->Runions, &oldRunions);
copy_unionstate(&e->Lunions, &oldLunions);
return sub;
}

Expand Down Expand Up @@ -731,8 +743,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -1148,6 +1160,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
Expand Down Expand Up @@ -1310,21 +1326,21 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

jl_unionstate_t oldLunions = e->Lunions;
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
jl_unionstate_t oldLunions; copy_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
jl_unionstate_t oldRunions = e->Runions;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
jl_unionstate_t oldRunions; copy_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Runions = oldRunions;
copy_unionstate(&e->Runions, &oldRunions);
}
else {
int lastset = 0;
Expand All @@ -1342,13 +1358,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
}

e->Lunions = oldLunions;
copy_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
Expand Down Expand Up @@ -1379,7 +1395,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
JL_GC_PUSH1(&saved);
save_env(e, &saved, &se);

memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
e->Lunions.used = 0;
int lastset = 0;
int sub;
while (1) {
Expand Down Expand Up @@ -1415,6 +1431,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->emptiness_only = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
}

// subtyping entry points
Expand Down Expand Up @@ -2084,14 +2101,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e,
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;

jl_unionstate_t oldRunions = e->Runions;
jl_unionstate_t oldRunions; copy_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

jl_value_t *res = intersect_all(x, y, e);

e->Runions = oldRunions;
copy_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
return res;
Expand All @@ -2102,10 +2119,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
jl_value_t *a=NULL, *b=NULL;
JL_GC_PUSH2(&a, &b);
jl_unionstate_t oldRunions = e->Runions;
jl_unionstate_t oldRunions; copy_unionstate(&oldRunions, &e->Runions);
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
e->Runions = oldRunions;
copy_unionstate(&e->Runions, &oldRunions);
jl_value_t *i = simple_join(a,b);
JL_GC_POP();
return i;
Expand Down Expand Up @@ -2580,8 +2597,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -3152,7 +3169,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
e->Runions.more = 0;
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
e->Runions.used = 0;
jl_value_t **is;
JL_GC_PUSHARGS(is, 3);
jl_value_t **saved = &is[2];
Expand Down

0 comments on commit 5642eae

Please sign in to comment.