Skip to content

Commit

Permalink
fix: backward jump type context compatibility check (#10)
Browse files Browse the repository at this point in the history
Co-authored-by: Jules <57632293+JuliaPoo@users.noreply.github.com>
  • Loading branch information
Fidget-Spinner and JuliaPoo authored Jun 3, 2023
1 parent cd81ce6 commit 197c0c9
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 61 deletions.
230 changes: 169 additions & 61 deletions Python/tier2.c
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,61 @@ typenode_get_type(_Py_TYPENODE_t node)
}
}

/**
* @brief Gets the location of the node in the type context
* @param ctx Pointer to type context to look in
* @param node A node in type context `ctx`
* @param is_localvar Pointer to boolean.
If `node` is inside `ctx->type_locals`, this will return true
* @return The index into the array the node is in.
* If `*is_localvar` is true, the array is `ctx->type_locals`.
* Otherwise it is `ctx->type_stack`
*/
static int
typenode_get_location(_PyTier2TypeContext *ctx, _Py_TYPENODE_t *node, bool *is_localvar)
{
// Search locals
int nlocals = ctx->type_locals_len;
int offset = (int)(node - ctx->type_locals);
if (offset >= 0 && offset < nlocals) {
*is_localvar = true;
return offset;
}

// Search stack
int nstack = ctx->type_stack_len;
offset = (int)(node - ctx->type_stack);
for (int i = 0; i < nstack; i++) {
*is_localvar = false;
return offset;
}

Py_UNREACHABLE();
}

/**
* @brief Check if two nodes in a type context are in the same tree
* @param src
*/
static bool typenode_is_same_tree(_Py_TYPENODE_t *x, _Py_TYPENODE_t *y)
{
_Py_TYPENODE_t *x_rootref = x;
_Py_TYPENODE_t *y_rootref = y;
uintptr_t x_tag = _Py_TYPENODE_GET_TAG(*x);
uintptr_t y_tag = _Py_TYPENODE_GET_TAG(*y);
switch (y_tag) {
case TYPE_REF: y_rootref = __typenode_get_rootptr(*y);
case TYPE_ROOT: break;
default: Py_UNREACHABLE();
}
switch (x_tag) {
case TYPE_REF: x_rootref = __typenode_get_rootptr(*x);
case TYPE_ROOT: break;
default: Py_UNREACHABLE();
}
return x_rootref == y_rootref;
}

/**
* @brief Performs TYPE_SET operation. dst tree becomes part of src tree
*
Expand Down Expand Up @@ -279,55 +334,23 @@ __type_propagate_TYPE_SET(
}
#endif

if (!src_is_new) {
// Check if they are the same tree
_Py_TYPENODE_t *srcrootref = src;
_Py_TYPENODE_t *dstrootref = dst;
uintptr_t dsttag = _Py_TYPENODE_GET_TAG(*dst);
uintptr_t srctag = _Py_TYPENODE_GET_TAG(*src);
switch (dsttag) {
case TYPE_REF: dstrootref = __typenode_get_rootptr(*dst);
case TYPE_ROOT:
switch (srctag) {
case TYPE_REF: srcrootref = __typenode_get_rootptr(*src);
case TYPE_ROOT:
if (srcrootref == dstrootref) {
// Same tree, no point swapping
return;
}
break;
default:
Py_UNREACHABLE();
}
break;
default:
Py_UNREACHABLE();
}
if (!src_is_new && typenode_is_same_tree(src, dst)) {
return;
}

uintptr_t tag = _Py_TYPENODE_GET_TAG(*dst);
_Py_TYPENODE_t *rootref = dst;
switch (tag) {
case TYPE_ROOT: {
case TYPE_REF: rootref = __typenode_get_rootptr(*dst);
case TYPE_ROOT:
if (!src_is_new) {
// Make dst a reference to src
*dst = _Py_TYPENODE_MAKE_REF((_Py_TYPENODE_t)src);
break;
}
// Make dst the src
*dst = (_Py_TYPENODE_t)src;
break;
}
case TYPE_REF: {
_Py_TYPENODE_t *rootref = __typenode_get_rootptr(*dst);
if (!src_is_new) {
// Traverse up to the root of dst, make root a reference to src
*rootref = _Py_TYPENODE_MAKE_REF((_Py_TYPENODE_t)src);
break;
}
// Make root of dst the src
// Make dst the src
*rootref = (_Py_TYPENODE_t)src;
break;
}
default:
Py_UNREACHABLE();
}
Expand Down Expand Up @@ -365,6 +388,10 @@ __type_propagate_TYPE_OVERWRITE(
}
#endif

if (!src_is_new && typenode_is_same_tree(src, dst)) {
return;
}

uintptr_t tag = _Py_TYPENODE_GET_TAG(*dst);
switch (tag) {
case TYPE_ROOT: {
Expand Down Expand Up @@ -487,28 +514,8 @@ __type_propagate_TYPE_SWAP(
_PyTier2TypeContext *type_context,
_Py_TYPENODE_t *src, _Py_TYPENODE_t *dst)
{
// Check if they are the same tree
_Py_TYPENODE_t *srcrootref = src;
_Py_TYPENODE_t *dstrootref = dst;
uintptr_t dsttag = _Py_TYPENODE_GET_TAG(*dst);
uintptr_t srctag = _Py_TYPENODE_GET_TAG(*src);
switch (dsttag) {
case TYPE_REF: dstrootref = __typenode_get_rootptr(*dst);
case TYPE_ROOT:
switch (srctag) {
case TYPE_REF: srcrootref = __typenode_get_rootptr(*src);
case TYPE_ROOT:
if (srcrootref == dstrootref) {
// Same tree, no point swapping
return;
}
break;
default:
Py_UNREACHABLE();
}
break;
default:
Py_UNREACHABLE();
if (typenode_is_same_tree(src, dst)) {
return;
}

// src and dst are different tree,
Expand Down Expand Up @@ -2658,6 +2665,93 @@ _PyTier2_GenerateNextBB(
return metadata->tier2_start;
}

/**
* @brief Helper funnction of typecontext_is_compatible. See that for why we need this.
* @param ctx1 A pointer to a type context.
* @param ctx2 A pointer to a different type context.
* @param ctx1_node A pointer to type node belonging to ctx1.
* @param ctx2_node A pointer to type node belonging to ctx2t.
* @return If the type nodes' parent trees are compatible.
*/
static bool
typenode_is_compatible(
_PyTier2TypeContext *ctx1, _PyTier2TypeContext *ctx2,
_Py_TYPENODE_t *ctx1_node, _Py_TYPENODE_t *ctx2_node)
{
_Py_TYPENODE_t *root1 = ctx1_node;
_Py_TYPENODE_t *root2 = ctx2_node;
switch (_Py_TYPENODE_GET_TAG(*ctx1_node)) {
case TYPE_REF: root1 = __typenode_get_rootptr(*ctx1_node);
case TYPE_ROOT: break;
default: Py_UNREACHABLE();
}
switch (_Py_TYPENODE_GET_TAG(*ctx2_node)) {
case TYPE_REF: root2 = __typenode_get_rootptr(*ctx2_node);
case TYPE_ROOT: break;
default: Py_UNREACHABLE();
}

// Get location of each root
bool is_local1, is_local2;
int node_idx1 = typenode_get_location(ctx1, root1, &is_local1);
int node_idx2 = typenode_get_location(ctx2, root2, &is_local2);

// Map each root to the corresponding location in the other tree
_Py_TYPENODE_t* mappedroot1 = is_local1
? &ctx2->type_locals[node_idx1]
: &ctx2->type_stack[node_idx1];
_Py_TYPENODE_t* mappedroot2 = is_local2
? &ctx1->type_locals[node_idx2]
: &ctx1->type_stack[node_idx2];

return typenode_is_same_tree(mappedroot1, root2)
&& typenode_is_same_tree(mappedroot2, root1);
}

/**
* @brief Checks that type context 2 is compatible with context 1.
* ctx2 is compatible with ctx1 if any execution state with ctx2 can run on code emitted from ctx1
*
* @param ctx1 The base type context.
* @param ctx2 The type context to compare with.
* @return true if compatible, false otherwise.
*/
static bool
typecontext_is_compatible(_PyTier2TypeContext *ctx1, _PyTier2TypeContext *ctx2)
{
// This function does two things:
// 1. Check that the trees are the same "shape" and equivalent. This allows
// ctx1's trees to be a subtree of ctx2.
// 2. Check that the trees resolve to the same root type.

#ifdef Py_DEBUG
// These should be true during runtime
assert(ctx1->type_locals_len == ctx2->type_locals_len);
assert(ctx1->type_stack_len == ctx2->type_stack_len);
int stack_elems1 = (int)(ctx1->type_stack_ptr - ctx1->type_stack);
int stack_elems2 = (int)(ctx2->type_stack_ptr - ctx2->type_stack);
assert(stack_elems1 == stack_elems2);
#endif

// Check the locals
for (int i = 0; i < ctx1->type_locals_len; i++) {
if (!typenode_is_compatible(ctx1, ctx2, &ctx1->type_locals[i],
&ctx2->type_locals[i])) {
return false;
}
}

// Check the type stack
for (int i = 0; i < stack_elems1; i++) {
if (!typenode_is_compatible(ctx1, ctx2, &ctx1->type_stack[i],
&ctx2->type_stack[i])) {
return false;
}
}

return true;
}

/**
* @brief Calculates the difference between two type contexts.
* @param ctx1 The base type context.
Expand All @@ -2672,16 +2766,22 @@ diff_typecontext(_PyTier2TypeContext *ctx1, _PyTier2TypeContext *ctx2)
assert(ctx2 != NULL);
#if BB_DEBUG
fprintf(stderr, " [*] Diffing type contexts\n");
#if TYPEPROP_DEBUG
static void print_typestack(const _PyTier2TypeContext * type_context);
print_typestack(ctx1);
print_typestack(ctx2);
#endif
#endif
assert(ctx1->type_locals_len == ctx2->type_locals_len);
assert(ctx1->type_stack_len == ctx2->type_stack_len);
int stack_elems1 = (int)(ctx1->type_stack_ptr - ctx1->type_stack);
int stack_elems2 = (int)(ctx2->type_stack_ptr - ctx2->type_stack);
assert(stack_elems1 == stack_elems2);

if (!typecontext_is_compatible(ctx1, ctx2)) {
return INT_MAX;
}

int diff = 0;
// Check the difference in the type locals
for (int i = 0; i < ctx1->type_locals_len; i++) {
Expand Down Expand Up @@ -2822,9 +2922,17 @@ _PyTier2_LocateJumpBackwardsBB(_PyInterpreterFrame *frame, uint16_t bb_id_tagged
assert(jump_offset_id >= 0);
assert(candidate_bb_id >= 0);
assert(candidate_bb_tier1_start != NULL);
#if BB_DEBUG
if (matching_bb_id != -1) {
fprintf(stderr, "Found jump target BB ID: %d\n", matching_bb_id);
}
#endif
// We couldn't find a matching BB to jump to. Time to generate our own.
// This also requires rewriting our backwards jump to a forward jump later.
if (matching_bb_id == -1) {
#if BB_DEBUG
fprintf(stderr, "Generating new jump target BB ID: %d\n", matching_bb_id);
#endif
// We should use the type context occuring at the end of the loop.
_PyTier2TypeContext *copied = _PyTier2TypeContext_Copy(curr_type_context);
if (copied == NULL) {
Expand Down Expand Up @@ -2867,7 +2975,7 @@ _PyTier2_LocateJumpBackwardsBB(_PyInterpreterFrame *frame, uint16_t bb_id_tagged
assert(matching_bb_id >= 0);
assert(matching_bb_id <= t2_info->bb_data_curr);
#if BB_DEBUG
fprintf(stderr, "Found jump target BB ID: %d\n", matching_bb_id);
fprintf(stderr, "Using jump target BB ID: %d\n", matching_bb_id);
#endif
_PyTier2BBMetadata *target_metadata = t2_info->bb_data[matching_bb_id];
return target_metadata->tier2_start;
Expand Down
16 changes: 16 additions & 0 deletions tier2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,20 @@ def test_iter_tuple(a):
assert jmp_target.opname == "NOP" # Space for an EXTENDED_ARG
assert insts[instidx + 1].opname == "BB_TEST_ITER_TUPLE" # The loop predicate

######################################################################
# Tests for: Tier 2 backward jump type context compatiblity check #
######################################################################
with TestInfo("type context backwards jump compatibility check"):
# See https://github.com/pylbbv/pylbbv/issues/9 for more information.
def f(y,z,w):
d = z
for _ in [1,2]:
z+z
d+d
d=w


trigger_tier2(f, (1,1,1.))

# As long as it doesn't crash, everything's good.
print("Tests completed ^-^")

0 comments on commit 197c0c9

Please sign in to comment.