Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT: Fold more nullchecks #111985

Merged
merged 12 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 102 additions & 21 deletions src/coreclr/jit/assertionprop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4367,6 +4367,50 @@ GenTree* Compiler::optAssertionProp_RelOp(ASSERT_VALARG_TP assertions, GenTree*
return optAssertionPropLocal_RelOp(assertions, tree, stmt);
}

//--------------------------------------------------------------------------------
// optVisitReachingAssertions: given a vn, call the specified callback function on all
// the assertions that reach it via PHI definitions if any.
//
// Arguments:
// vn - The vn to visit all the reaching assertions for
// argVisitor - The callback function to call on the vn and its reaching assertions
//
// Return Value:
// AssertVisit::Aborted - an argVisitor returned AssertVisit::Abort, we stop the walk and return
// AssertVisit::Continue - all argVisitor returned AssertVisit::Continue
//
template <typename TAssertVisitor>
Compiler::AssertVisit Compiler::optVisitReachingAssertions(ValueNum vn, TAssertVisitor argVisitor)
{
VNPhiDef phiDef;
if (!vnStore->GetPhiDef(vn, &phiDef))
{
// We assume that the caller already checked assertions for the current block, so we're
// interested only in assertions for PHI definitions.
return AssertVisit::Abort;
}

LclSsaVarDsc* ssaDef = lvaGetDesc(phiDef.LclNum)->GetPerSsaData(phiDef.SsaDef);
GenTreeLclVarCommon* node = ssaDef->GetDefNode();
if ((node == nullptr) || !node->OperIs(GT_STORE_LCL_VAR) || !node->Data()->OperIs(GT_PHI))
{
return AssertVisit::Abort;
}

for (GenTreePhi::Use& use : node->Data()->AsPhi()->Uses())
{
GenTreePhiArg* phiArg = use.GetNode()->AsPhiArg();
const ValueNum phiArgVN = vnStore->VNConservativeNormalValue(phiArg->gtVNPair);
ASSERT_TP assertions = optGetEdgeAssertions(ssaDef->GetBlock(), phiArg->gtPredBB);
if (argVisitor(phiArgVN, assertions) == AssertVisit::Abort)
{
// The visitor wants to abort the walk.
return AssertVisit::Abort;
}
}
return AssertVisit::Continue;
}

//------------------------------------------------------------------------
// optAssertionProp: try and optimize a relop via assertion propagation
//
Expand All @@ -4380,6 +4424,8 @@ GenTree* Compiler::optAssertionProp_RelOp(ASSERT_VALARG_TP assertions, GenTree*
//
GenTree* Compiler::optAssertionPropGlobal_RelOp(ASSERT_VALARG_TP assertions, GenTree* tree, Statement* stmt)
{
assert(!optLocalAssertionProp);

GenTree* newTree = tree;
GenTree* op1 = tree->AsOp()->gtOp1;
GenTree* op2 = tree->AsOp()->gtOp2;
Expand Down Expand Up @@ -4463,6 +4509,24 @@ GenTree* Compiler::optAssertionPropGlobal_RelOp(ASSERT_VALARG_TP assertions, Gen
return nullptr;
}

// See if we have "PHI ==/!= null" tree. If so, we iterate over all PHI's arguments,
// and if all of them are known to be non-null, we can bash the comparison to true/false.
if (op2->IsIntegralConst(0) && op1->TypeIs(TYP_REF))
{
auto visitor = [this](ValueNum reachingVN, ASSERT_TP reachingAssertions) {
return optAssertionVNIsNonNull(reachingVN, reachingAssertions) ? AssertVisit::Continue : AssertVisit::Abort;
};

ValueNum op1vn = vnStore->VNConservativeNormalValue(op1->gtVNPair);
if (optVisitReachingAssertions(op1vn, visitor) == AssertVisit::Continue)
{
JITDUMP("... all of PHI's arguments are never null!\n");
assert(newTree->OperIs(GT_EQ, GT_NE));
newTree = tree->OperIs(GT_EQ) ? gtNewIconNode(0) : gtNewIconNode(1);
return optAssertionProp_Update(newTree, tree, stmt);
}
}

// Find an equal or not equal assertion involving "op1" and "op2".
index = optGlobalAssertionIsEqualOrNotEqual(assertions, op1, op2);

Expand Down Expand Up @@ -5083,31 +5147,19 @@ bool Compiler::optAssertionVNIsNonNull(ValueNum vn, ASSERT_VALARG_TP assertions)
return true;
}

// Check each assertion to find if we have a vn != null assertion.
//
BitVecOps::Iter iter(apTraits, assertions);
unsigned index = 0;
while (iter.NextElem(&index))
if (!BitVecOps::MayBeUninit(assertions))
{
AssertionIndex assertionIndex = GetAssertionIndex(index);
if (assertionIndex > optAssertionCount)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant condition, I guess it wasn't deleted when we moved all "walk asserts" pieces to BitVecOps::Iter.

{
break;
}
AssertionDsc* curAssertion = optGetAssertion(assertionIndex);
if (!curAssertion->CanPropNonNull())
{
continue;
}

if (curAssertion->op1.vn != vn)
BitVecOps::Iter iter(apTraits, assertions);
unsigned index = 0;
while (iter.NextElem(&index))
{
continue;
AssertionDsc* curAssertion = optGetAssertion(GetAssertionIndex(index));
if (curAssertion->CanPropNonNull() && curAssertion->op1.vn == vn)
{
return true;
}
}

return true;
}

return false;
}

Expand Down Expand Up @@ -5810,6 +5862,35 @@ ASSERT_VALRET_TP Compiler::optGetVnMappedAssertions(ValueNum vn)
return BitVecOps::UninitVal();
}

//------------------------------------------------------------------------
// optGetEdgeAssertions: Given a block and its predecessor, get the assertions
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've extracted this function from rangecheck.cpp to re-use in my function

// the predecessor creates for the block.
//
// Arguments:
// block - The block to get the assertions for.
// blockPred - The predecessor of the block (creating the assertions).
//
// Return Value:
// The assertions we have about the value number.
//
ASSERT_VALRET_TP Compiler::optGetEdgeAssertions(const BasicBlock* block, const BasicBlock* blockPred) const
{
if ((blockPred->KindIs(BBJ_COND) && blockPred->TrueTargetIs(block)))
{
if (bbJtrueAssertionOut != nullptr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious if you ever see this being null.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed it could in rangecheck if assertprop didn't run, but not sure such combination is possible. From a quick SPMI run it's not null. I'll convert it into an assert in my next PR to avoid spinning CI for it

{
return bbJtrueAssertionOut[blockPred->bbNum];
}
}

if (blockPred->KindIs(BBJ_ALWAYS, BBJ_COND))
{
return blockPred->bbAssertionOut;
}

return BitVecOps::MakeEmpty(apTraits);
}

/*****************************************************************************
*
* Given a const assertion this method computes the set of implied assertions
Expand Down
10 changes: 10 additions & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -8257,6 +8257,14 @@ class Compiler
bool optNonNullAssertionProp_Ind(ASSERT_VALARG_TP assertions, GenTree* indir);
bool optWriteBarrierAssertionProp_StoreInd(ASSERT_VALARG_TP assertions, GenTreeStoreInd* indir);

enum class AssertVisit
{
Continue,
Abort,
};
template <typename TAssertVisitor>
AssertVisit optVisitReachingAssertions(ValueNum vn, TAssertVisitor argVisitor);

void optAssertionProp_RangeProperties(ASSERT_VALARG_TP assertions,
GenTree* tree,
bool* isKnownNonZero,
Expand All @@ -8276,6 +8284,8 @@ class Compiler
void optDebugCheckAssertions(AssertionIndex AssertionIndex);
#endif

ASSERT_VALRET_TP optGetEdgeAssertions(const BasicBlock* block, const BasicBlock* blockPred) const;

static void optDumpAssertionIndices(const char* header, ASSERT_TP assertions, const char* footer = nullptr);
static void optDumpAssertionIndices(ASSERT_TP assertions, const char* footer = nullptr);

Expand Down
21 changes: 5 additions & 16 deletions src/coreclr/jit/rangecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,26 +977,15 @@ void RangeCheck::MergeAssertion(BasicBlock* block, GenTree* op, Range* pRange DE
ASSERT_TP assertions = BitVecOps::UninitVal();

// If we have a phi arg, we can get to the block from it and use its assertion out.
if (op->gtOper == GT_PHI_ARG)
if (op->OperIs(GT_PHI_ARG))
{
GenTreePhiArg* arg = (GenTreePhiArg*)op;
BasicBlock* pred = arg->gtPredBB;
if (pred->KindIs(BBJ_COND) && pred->FalseTargetIs(block))
const BasicBlock* pred = op->AsPhiArg()->gtPredBB;
assertions = m_pCompiler->optGetEdgeAssertions(block, pred);
if (!BitVecOps::MayBeUninit(assertions))
{
assertions = pred->bbAssertionOut;
JITDUMP("Merge assertions from pred " FMT_BB " edge: ", pred->bbNum);
JITDUMP("Merge assertions created by " FMT_BB " for " FMT_BB "\n", pred->bbNum, block->bbNum);
Compiler::optDumpAssertionIndices(assertions, "\n");
}
else if ((pred->KindIs(BBJ_ALWAYS) && pred->TargetIs(block)) ||
(pred->KindIs(BBJ_COND) && pred->TrueTargetIs(block)))
{
if (m_pCompiler->bbJtrueAssertionOut != nullptr)
{
assertions = m_pCompiler->bbJtrueAssertionOut[pred->bbNum];
JITDUMP("Merge assertions from pred " FMT_BB " JTrue edge: ", pred->bbNum);
Compiler::optDumpAssertionIndices(assertions, "\n");
}
}
}
// Get assertions from bbAssertionIn.
else if (op->IsLocal())
Expand Down
Loading