From f3a36d74eeb1f8c6439affcc33e2a304550dc217 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 11 Oct 2024 23:16:26 +0800 Subject: [PATCH] Subtype: some performance tuning. (#56007) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The main motivation of this PR is to fix #55807. dc689fe8700f70f4a4e2dbaaf270f26b87e79e04 tries to remove the slow `may_contain_union_decision` check by re-organizing the code path. Now the fast path has been removed and most of its optimization has been integrated into the preserved slow path. Since the slow path stores all inner ∃ decisions on the outer most R stack, there might be overflow risk. aee69a41441b4306ba3ee5e845bc96cb45d9b327 should fix that concern. The reported MWE now becomes ```julia 0.000002 seconds 0.000040 seconds (105 allocations: 4.828 KiB, 52.00% compilation time) 0.000023 seconds (105 allocations: 4.828 KiB, 49.36% compilation time) 0.000026 seconds (105 allocations: 4.828 KiB, 50.38% compilation time) 0.000027 seconds (105 allocations: 4.828 KiB, 54.95% compilation time) 0.000019 seconds (106 allocations: 4.922 KiB, 49.73% compilation time) 0.000024 seconds (105 allocations: 4.828 KiB, 52.24% compilation time) ``` Local bench also shows that 72855cd slightly accelerates `OmniPackage.jl`'s loading ```julia julia> @time using OmniPackage # v1.11rc4 20.525278 seconds (25.36 M allocations: 1.606 GiB, 8.48% gc time, 12.89% compilation time: 77% of which was recompilation) # v1.11rc4+aee69a4+72855cd 19.527871 seconds (24.92 M allocations: 1.593 GiB, 8.88% gc time, 15.13% compilation time: 82% of which was recompilation) ``` --- src/subtype.c | 298 +++++++++++++++++++++++++++++--------------------- 1 file changed, 173 insertions(+), 125 deletions(-) diff --git a/src/subtype.c b/src/subtype.c index 65ee4d5916bce..5edcd100ee8e0 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -39,20 +39,24 @@ extern "C" { // Union type decision points are discovered while the algorithm works. // If a new Union decision is encountered, the `more` flag is set to tell // the forall/exists loop to grow the stack. -// TODO: the stack probably needs to be artificially large because of some -// deeper problem (see #21191) and could be shrunk once that is fixed + +typedef struct jl_bits_stack_t { + uint32_t data[16]; + struct jl_bits_stack_t *next; +} jl_bits_stack_t; + typedef struct { int16_t depth; int16_t more; int16_t used; - uint32_t stack[100]; // stack of bits represented as a bit vector + jl_bits_stack_t stack; } jl_unionstate_t; typedef struct { int16_t depth; int16_t more; int16_t used; - void *stack; + uint8_t *stack; } jl_saved_unionstate_t; // Linked list storing the type variable environment. A new jl_varbinding_t @@ -131,37 +135,111 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J } #endif +// union-stack tools + static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT { - assert(i >= 0 && i < sizeof(st->stack) * 8); + assert(i >= 0 && i <= 32767); // limited by the depth bit. // get the `i`th bit in an array of 32-bit words - return (st->stack[i>>5] & (1u<<(i&31))) != 0; + jl_bits_stack_t *stack = &st->stack; + while (i >= sizeof(stack->data) * 8) { + // We should have set this bit. + assert(stack->next); + stack = stack->next; + i -= sizeof(stack->data) * 8; + } + return (stack->data[i>>5] & (1u<<(i&31))) != 0; } static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT { - assert(i >= 0 && i < sizeof(st->stack) * 8); + assert(i >= 0 && i <= 32767); // limited by the depth bit. + jl_bits_stack_t *stack = &st->stack; + while (i >= sizeof(stack->data) * 8) { + if (__unlikely(stack->next == NULL)) { + stack->next = (jl_bits_stack_t *)malloc(sizeof(jl_bits_stack_t)); + stack->next->next = NULL; + } + stack = stack->next; + i -= sizeof(stack->data) * 8; + } if (val) - st->stack[i>>5] |= (1u<<(i&31)); + stack->data[i>>5] |= (1u<<(i&31)); else - st->stack[i>>5] &= ~(1u<<(i&31)); + stack->data[i>>5] &= ~(1u<<(i&31)); +} + +#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0) + +static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; + if (state->more == 0) + return 0; + // reset `used` and let `pick_union_decision` clean the stack. + state->used = state->more; + statestack_set(state, state->used - 1, 1); + return 1; } -#define push_unionstate(saved, src) \ - do { \ - (saved)->depth = (src)->depth; \ - (saved)->more = (src)->more; \ - (saved)->used = (src)->used; \ - (saved)->stack = alloca(((src)->used+7)/8); \ - memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \ +static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; + if (state->depth >= state->used) { + statestack_set(state, state->used, 0); + state->used++; + } + int ui = statestack_get(state, state->depth); + state->depth++; + if (ui == 0) + state->more = state->depth; // memorize that this was the deepest available choice + return ui; +} + +static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + do { + if (pick_union_decision(e, R)) + u = ((jl_uniontype_t*)u)->b; + else + u = ((jl_uniontype_t*)u)->a; + } while (jl_is_uniontype(u)); + return u; +} + +#define push_unionstate(saved, src) \ + do { \ + (saved)->depth = (src)->depth; \ + (saved)->more = (src)->more; \ + (saved)->used = (src)->used; \ + jl_bits_stack_t *srcstack = &(src)->stack; \ + int pushbits = ((saved)->used+7)/8; \ + (saved)->stack = (uint8_t *)alloca(pushbits); \ + for (int n = 0; n < pushbits; n += sizeof(srcstack->data)) { \ + assert(srcstack != NULL); \ + int rest = pushbits - n; \ + if (rest > sizeof(srcstack->data)) \ + rest = sizeof(srcstack->data); \ + memcpy(&(saved)->stack[n], &srcstack->data, rest); \ + srcstack = srcstack->next; \ + } \ } while (0); -#define pop_unionstate(dst, saved) \ - do { \ - (dst)->depth = (saved)->depth; \ - (dst)->more = (saved)->more; \ - (dst)->used = (saved)->used; \ - memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \ +#define pop_unionstate(dst, saved) \ + do { \ + (dst)->depth = (saved)->depth; \ + (dst)->more = (saved)->more; \ + (dst)->used = (saved)->used; \ + jl_bits_stack_t *dststack = &(dst)->stack; \ + int popbits = ((saved)->used+7)/8; \ + for (int n = 0; n < popbits; n += sizeof(dststack->data)) { \ + assert(dststack != NULL); \ + int rest = popbits - n; \ + if (rest > sizeof(dststack->data)) \ + rest = sizeof(dststack->data); \ + memcpy(&dststack->data, &(saved)->stack[n], rest); \ + dststack = dststack->next; \ + } \ } while (0); static int current_env_length(jl_stenv_t *e) @@ -264,6 +342,18 @@ static void free_env(jl_savedenv_t *se) JL_NOTSAFEPOINT se->buf = NULL; } +static void free_stenv(jl_stenv_t *e) JL_NOTSAFEPOINT +{ + for (int R = 0; R < 2; R++) { + jl_bits_stack_t *temp = R ? e->Runions.stack.next : e->Lunions.stack.next; + while (temp != NULL) { + jl_bits_stack_t *next = temp->next; + free(temp); + temp = next; + } + } +} + static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPOINT { jl_value_t **roots = NULL; @@ -587,44 +677,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi) static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param); -#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0) - -static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; - if (state->more == 0) - return 0; - // reset `used` and let `pick_union_decision` clean the stack. - state->used = state->more; - statestack_set(state, state->used - 1, 1); - return 1; -} - -static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; - if (state->depth >= state->used) { - statestack_set(state, state->used, 0); - state->used++; - } - int ui = statestack_get(state, state->depth); - state->depth++; - if (ui == 0) - state->more = state->depth; // memorize that this was the deepest available choice - return ui; -} - -static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - do { - if (pick_union_decision(e, R)) - u = ((jl_uniontype_t*)u)->b; - else - u = ((jl_uniontype_t*)u)->a; - } while (jl_is_uniontype(u)); - return u; -} - static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow); // subtype for variable bounds consistency check. needs its own forall/exists environment. @@ -1513,37 +1565,12 @@ static int is_definite_length_tuple_type(jl_value_t *x) return k == JL_VARARG_NONE || k == JL_VARARG_INT; } -static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore); - -static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT +static int is_exists_typevar(jl_value_t *x, jl_stenv_t *e) { - if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type) - return 0; - if (jl_is_unionall(x)) - return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log); - if (jl_is_datatype(x)) { - jl_datatype_t *xd = (jl_datatype_t *)x; - for (int i = 0; i < jl_nparams(xd); i++) { - jl_value_t *param = jl_tparam(xd, i); - if (jl_is_vararg(param)) - param = jl_unwrap_vararg(param); - if (may_contain_union_decision(param, e, log)) - return 1; - } - return 0; - } if (!jl_is_typevar(x)) - return jl_is_type(x); - jl_typeenv_t *t = log; - while (t != NULL) { - if (x == (jl_value_t *)t->var) - return 1; - t = t->prev; - } - jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log }; - jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x); - return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) || - may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog); + return 0; + jl_varbinding_t *vb = lookup(e, (jl_tvar_t *)x); + return vb && vb->right; } static int has_exists_typevar(jl_value_t *x, jl_stenv_t *e) JL_NOTSAFEPOINT @@ -1574,31 +1601,9 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t int kindy = !jl_has_free_typevars(y); if (kindx && kindy) return jl_subtype(x, y); - if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) { - jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); - e->Lunions.used = e->Runions.used = 0; - e->Lunions.depth = e->Runions.depth = 0; - e->Lunions.more = e->Runions.more = 0; - int count = 0, noRmore = 0; - sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore); - pop_unionstate(&e->Runions, &oldRunions); - // We could skip the slow path safely if - // 1) `_∀_∃_subtype` has tested all cases - // 2) `_∀_∃_subtype` returns 1 && `x` and `y` contain no ∃ typevar - // Once `limit_slow == 1`, also skip it if - // 1) `_∀_∃_subtype` returns 0 - // 2) the left `Union` looks big - // TODO: `limit_slow` ignores complexity from inner `local_∀_exists_subtype`. - if (limit_slow == -1) - limit_slow = kindx || kindy; - int skip = noRmore || (limit_slow && (count > 3 || !sub)) || - (sub && (kindx || !has_exists_typevar(x, e)) && - (kindy || !has_exists_typevar(y, e))); - if (skip) - e->Runions.more = oldRmore; - } - else { - // slow path + int has_exists = (!kindx && has_exists_typevar(x, e)) || + (!kindy && has_exists_typevar(y, e)); + if (has_exists && (is_exists_typevar(x, e) != is_exists_typevar(y, e))) { e->Lunions.used = 0; while (1) { e->Lunions.more = 0; @@ -1607,7 +1612,51 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t if (!sub || !next_union_state(e, 0)) break; } + return sub; } + if (limit_slow == -1) + limit_slow = kindx || kindy; + jl_savedenv_t se; + save_env(e, &se, has_exists); + int count, limited = 0, ini_count = 0; + jl_saved_unionstate_t latestLunions = {0, 0, 0, NULL}; + while (1) { + count = ini_count; + if (ini_count == 0) + e->Lunions.used = 0; + else + pop_unionstate(&e->Lunions, &latestLunions); + while (1) { + e->Lunions.more = 0; + e->Lunions.depth = 0; + if (count < 4) count++; + sub = subtype(x, y, e, param); + if (limit_slow && count == 4) + limited = 1; + if (!sub || !next_union_state(e, 0)) + break; + if (limited || !has_exists || e->Runions.more == oldRmore) { + // re-save env and freeze the ∃decision for previous ∀Union + // Note: We could ignore the rest `∃Union` decisions if `x` and `y` + // contain no ∃ typevar, as they have no effect on env. + ini_count = count; + push_unionstate(&latestLunions, &e->Lunions); + re_save_env(e, &se, has_exists); + e->Runions.more = oldRmore; + } + } + if (sub || e->Runions.more == oldRmore) + break; + assert(e->Runions.more > oldRmore); + next_union_state(e, 1); + restore_env(e, &se, has_exists); // also restore Rdepth here + e->Runions.more = oldRmore; + } + if (!sub) + assert(e->Runions.more == oldRmore); + else if (limited || !has_exists) + e->Runions.more = oldRmore; + free_env(&se); return sub; } @@ -1677,7 +1726,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_savede } } -static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore) +static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param) { // The depth recursion has the following shape, after simplification: // ∀₁ @@ -1689,12 +1738,8 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i e->Lunions.used = 0; int sub; - if (count) *count = 0; - if (noRmore) *noRmore = 1; while (1) { sub = exists_subtype(x, y, e, &se, param); - if (count) *count = (*count < 4) ? *count + 1 : 4; - if (noRmore) *noRmore = *noRmore && e->Runions.more == 0; if (!sub || !next_union_state(e, 0)) break; re_save_env(e, &se, 1); @@ -1704,11 +1749,6 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i return sub; } -static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param) -{ - return _forall_exists_subtype(x, y, e, param, NULL, NULL); -} - static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz) { e->vars = NULL; @@ -1728,6 +1768,8 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz) e->Lunions.depth = 0; e->Runions.depth = 0; e->Lunions.more = 0; e->Runions.more = 0; e->Lunions.used = 0; e->Runions.used = 0; + e->Lunions.stack.next = NULL; + e->Runions.stack.next = NULL; } // subtyping entry points @@ -2157,6 +2199,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env, } init_stenv(&e, env, envsz); int subtype = forall_exists_subtype(x, y, &e, 0); + free_stenv(&e); assert(obvious_subtype == 3 || obvious_subtype == subtype || jl_has_free_typevars(x) || jl_has_free_typevars(y)); #ifndef NDEBUG if (obvious_subtype == 0 || (obvious_subtype == 1 && envsz == 0)) @@ -2249,6 +2292,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) { init_stenv(&e, NULL, 0); int subtype = forall_exists_subtype(a, b, &e, 0); + free_stenv(&e); assert(subtype_ab == 3 || subtype_ab == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b)); #ifndef NDEBUG if (subtype_ab != 0 && subtype_ab != 1) // ensures that running in a debugger doesn't change the result @@ -2265,6 +2309,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) { init_stenv(&e, NULL, 0); int subtype = forall_exists_subtype(b, a, &e, 0); + free_stenv(&e); assert(subtype_ba == 3 || subtype_ba == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b)); #ifndef NDEBUG if (subtype_ba != 0 && subtype_ba != 1) // ensures that running in a debugger doesn't change the result @@ -4230,7 +4275,9 @@ static jl_value_t *intersect_types(jl_value_t *x, jl_value_t *y, int emptiness_o init_stenv(&e, NULL, 0); e.intersection = e.ignore_free = 1; e.emptiness_only = emptiness_only; - return intersect_all(x, y, &e); + jl_value_t *ans = intersect_all(x, y, &e); + free_stenv(&e); + return ans; } JL_DLLEXPORT jl_value_t *jl_intersect_types(jl_value_t *x, jl_value_t *y) @@ -4407,6 +4454,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t * memset(env, 0, szb*sizeof(void*)); e.envsz = szb; *ans = intersect_all(a, b, &e); + free_stenv(&e); if (*ans == jl_bottom_type) goto bot; // TODO: code dealing with method signatures is not able to handle unions, so if // `a` and `b` are both tuples, we need to be careful and may not return a union,