Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure P4 expression optimization does not strip away types #4300

Merged
merged 8 commits into from
Jan 8, 2024
4 changes: 2 additions & 2 deletions backends/p4tools/common/compiler/convert_varbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

namespace P4Tools {

/// Converts all existing Type_Varbit types in the program into a custom Size_Type_Varbit type.
/// Sized_Type_Varbit also contains information about the width that was assigned to the type by
/// Converts all existing Type_Varbit types in the program into a custom Extracted_Varbit type.
/// Extracted_Varbit also contains information about the width that was assigned to the type by
/// the extract call.
class ConvertVarbits : public Transform {
public:
Expand Down
5 changes: 3 additions & 2 deletions backends/p4tools/common/lib/symbolic_env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <boost/container/vector.hpp>

#include "backends/p4tools/common/lib/model.h"
#include "frontends/p4/optimizeExpressions.h"
#include "ir/indexed_vector.h"
#include "ir/vector.h"
#include "ir/visitor.h"
Expand All @@ -28,7 +27,9 @@ const IR::Expression *SymbolicEnv::get(const IR::StateVariable &var) const {
bool SymbolicEnv::exists(const IR::StateVariable &var) const { return map.find(var) != map.end(); }

void SymbolicEnv::set(const IR::StateVariable &var, const IR::Expression *value) {
map[var] = P4::optimizeExpression(value);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove the optimization, I'm not sure if the check if type is set should be here too? May not be necessary for all tools that use SymbolicEnv.

BUG_CHECK(value->type && !value->type->is<IR::Type_Unknown>(),
"Cannot set value with unspecified type: %1%", value);
map[var] = value;
}

const IR::Expression *SymbolicEnv::subst(const IR::Expression *expr) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ std::vector<std::pair<IR::StateVariable, const IR::Expression *>> ExprStepper::s
ExecutionState &nextState, const std::vector<IR::StateVariable> &flatFields,
int varBitFieldSize) {
std::vector<std::pair<IR::StateVariable, const IR::Expression *>> fields;
for (const auto &fieldRef : flatFields) {
// Make a copy of the StateVariable so it can be modified in the varbit case (and it is just a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, in this case we are making a copy for the common case (no Varbit). Is that really better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite small class (one pointer + vptr) so I'd say this does not matter either way. Could be even better since there will be no dereference, but could be also exactly the same if the optimizer is smart enough (I haven't checked).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that is a check we should keep in this particular setter. The other code can be moved out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused, you have presubmably replied to a different thread, did you mean that the check for type being set in the expression should be (also) in SymbolicEnv::set?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the comment was meant for #4300 (comment). Not sure what happened.

// pointer wrapper anyway).
for (IR::StateVariable fieldRef : flatFields) {
const auto *fieldType = fieldRef->type;
// If the header had a varbit, the header needs to be updated.
// We assign @param varbitFeldSize to the varbit field.
Expand All @@ -62,15 +64,12 @@ std::vector<std::pair<IR::StateVariable, const IR::Expression *>> ExprStepper::s
// We need to cast the generated variable to the appropriate type.
if (fieldType->is<IR::Extracted_Varbits>()) {
pktVar = new IR::Cast(fieldType, pktVar);
// Update the field and add the field to the return list.
// TODO: Better way to handle varbits here?
auto *newRef = fieldRef->clone();
newRef->type = fieldType;
nextState.set(fieldRef, pktVar);
fields.emplace_back(fieldRef, pktVar);
continue;
}
if (const auto *bits = fieldType->to<IR::Type_Bits>()) {
// Rewrite the type of the field so it matches the extracted varbit type.
// TODO: is there a better way to do this?
auto *newRefExpr = fieldRef->clone();
newRefExpr->type = fieldType;
fieldRef.ref = newRefExpr;
} else if (const auto *bits = fieldType->to<IR::Type_Bits>()) {
if (bits->isSigned) {
pktVar = new IR::Cast(fieldType, pktVar);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ const IR::Expression *TableStepper::computeHit(TableMatchMap *matches) {
const IR::StringLiteral *TableStepper::getTableActionString(
const IR::MethodCallExpression *actionCall) {
cstring actionName = actionCall->method->toString();
return new IR::StringLiteral(actionName);
return new IR::StringLiteral(IR::Type_String::get(), actionName);
}

const IR::Expression *TableStepper::evalTableConstEntries() {
Expand Down
30 changes: 29 additions & 1 deletion backends/p4tools/modules/testgen/lib/execution_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "backends/p4tools/common/lib/taint.h"
#include "backends/p4tools/common/lib/trace_event.h"
#include "backends/p4tools/common/lib/variables.h"
#include "frontends/p4/optimizeExpressions.h"
#include "ir/id.h"
#include "ir/indexed_vector.h"
#include "ir/irutils.h"
Expand Down Expand Up @@ -178,10 +179,37 @@ void ExecutionState::markVisited(const IR::Node *node) {

const P4::Coverage::CoverageSet &ExecutionState::getVisited() const { return visitedNodes; }

/// Compare types, considering Extracted_Varbit and bits equal if the (real/extracted) sizes are
/// equal. This is because the packet expression can be something like 0 ++
/// (Extracted_Varbit<N>)pkt_var. This expression is typed as bit<N>, but the optimizer removes the
/// 0 ++ and makes it into Extracted_Varbit type.
/// TODO: Maybe there is a better way to handle varbit that could allow us to avoid this.
static bool typeEquivSansVarbit(const IR::Type *a, const IR::Type *b) {
if (a->equiv(*b)) {
return true;
}
const auto *abit = a->to<IR::Type_Bits>();
const auto *avar = a->to<IR::Extracted_Varbits>();
const auto *bbit = b->to<IR::Type_Bits>();
const auto *bvar = b->to<IR::Extracted_Varbits>();
return (abit && bvar && abit->width_bits() == bvar->width_bits()) ||
(avar && bbit && avar->width_bits() == bbit->width_bits());
}

void ExecutionState::set(const IR::StateVariable &var, const IR::Expression *value) {
const auto *type = value->type;
BUG_CHECK(type && !type->is<IR::Type_Unknown>(), "Cannot set value with unspecified type: %1%",
value);
if (getProperty<bool>("inUndefinedState")) {
// If we are in an undefined state, the variable we set is tainted.
value = ToolsVariables::getTaintExpression(value->type);
value = ToolsVariables::getTaintExpression(type);
} else {
value = P4::optimizeExpression(value);
BUG_CHECK(value->type && !value->type->is<IR::Type_Unknown>(),
"The P4 expression optimizer stripped a type of %1% (was %2%)", value, type);
BUG_CHECK(typeEquivSansVarbit(type, value->type),
"The P4 expression optimizer had changed type of %1% (%2% -> %3%)", value, type,
value->type);
}
env.set(var, value);
}
Expand Down
7 changes: 4 additions & 3 deletions frontends/common/constantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ const IR::Node *DoConstantFolding::compare(const IR::Operation_Binary *e) {
auto ri = rlist->components.at(i);
const IR::Operation_Binary *tmp;
if (eqTest)
tmp = new IR::Equ(li, ri);
tmp = new IR::Equ(IR::Type_Boolean::get(), li, ri);
else
tmp = new IR::Neq(li, ri);
tmp = new IR::Neq(IR::Type_Boolean::get(), li, ri);
auto cmp = compare(tmp);
auto boolLit = cmp->to<IR::BoolLiteral>();
if (boolLit == nullptr) return e;
Expand Down Expand Up @@ -960,7 +960,8 @@ const IR::Node *DoConstantFolding::postorder(IR::SelectExpression *expression) {
finished = true;
if (someUnknown) {
if (!c->keyset->is<IR::DefaultExpression>()) changes = true;
auto newc = new IR::SelectCase(c->srcInfo, new IR::DefaultExpression(), c->state);
auto newc = new IR::SelectCase(
c->srcInfo, new IR::DefaultExpression(expression->select->type), c->state);
cases.push_back(newc);
} else {
// This is the result.
Expand Down
68 changes: 41 additions & 27 deletions frontends/p4/strengthReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::BAnd *expr) {
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r)
return new IR::Cmpl(expr->type, new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr));
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr));

if (hasSideEffects(expr)) return expr;
if (isZero(expr->left)) return expr->left;
Expand All @@ -95,7 +96,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::BOr *expr) {
if (isZero(expr->right)) return expr->left;
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r) return new IR::Cmpl(new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr));
if (l && r)
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr));
if (hasSideEffects(expr)) return expr;
if (expr->left->equiv(*expr->right)) return expr->left;
return expr;
Expand Down Expand Up @@ -143,15 +146,15 @@ const IR::Node *DoStrengthReduction::postorder(IR::Equ *expr) {
if (isTrue(expr->left)) return expr->right;
if (isTrue(expr->right)) return expr->left;
// a == false is the same as !a
if (isFalse(expr->left)) return new IR::LNot(expr->right);
if (isFalse(expr->right)) return new IR::LNot(expr->left);
if (isFalse(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right);
if (isFalse(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left);
return expr;
}

const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) {
// a != true is the same as !a
if (isTrue(expr->left)) return new IR::LNot(expr->right);
if (isTrue(expr->right)) return new IR::LNot(expr->left);
if (isTrue(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right);
if (isTrue(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left);
// a != false is the same as a
if (isFalse(expr->left)) return expr->right;
if (isFalse(expr->right)) return expr->left;
Expand All @@ -160,12 +163,18 @@ const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) {

const IR::Node *DoStrengthReduction::postorder(IR::LNot *expr) {
if (auto e = expr->expr->to<IR::LNot>()) return e->expr;
if (auto e = expr->expr->to<IR::Equ>()) return new IR::Neq(e->left, e->right);
if (auto e = expr->expr->to<IR::Neq>()) return new IR::Equ(e->left, e->right);
if (auto e = expr->expr->to<IR::Leq>()) return new IR::Grt(e->left, e->right);
if (auto e = expr->expr->to<IR::Geq>()) return new IR::Lss(e->left, e->right);
if (auto e = expr->expr->to<IR::Lss>()) return new IR::Geq(e->left, e->right);
if (auto e = expr->expr->to<IR::Grt>()) return new IR::Leq(e->left, e->right);
if (auto e = expr->expr->to<IR::Equ>())
return new IR::Neq(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Neq>())
return new IR::Equ(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Leq>())
return new IR::Grt(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Geq>())
return new IR::Lss(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Lss>())
return new IR::Geq(expr->srcInfo, expr->type, e->left, e->right);
if (auto e = expr->expr->to<IR::Grt>())
return new IR::Leq(expr->srcInfo, expr->type, e->left, e->right);
return expr;
}

Expand All @@ -176,13 +185,13 @@ const IR::Node *DoStrengthReduction::postorder(IR::Sub *expr) {
if (expr->right->is<IR::Constant>()) {
auto cst = expr->right->to<IR::Constant>();
auto neg = new IR::Constant(cst->srcInfo, cst->type, -cst->value, cst->base, true);
auto result = new IR::Add(expr->srcInfo, expr->left, neg);
auto result = new IR::Add(expr->srcInfo, expr->type, expr->left, neg);
return result;
}
if (hasSideEffects(expr)) return expr;
if (expr->left->equiv(*expr->right) && expr->left->type &&
!expr->left->type->is<IR::Type_Unknown>())
return new IR::Constant(expr->left->type, 0);
return new IR::Constant(expr->srcInfo, expr->left->type, 0);
return expr;
}

Expand Down Expand Up @@ -230,14 +239,14 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mul *expr) {
if (isOne(expr->right)) return expr->left;
auto exp = isPowerOf2(expr->left);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shl(expr->srcInfo, expr->right, amt);
auto amt = new IR::Constant(expr->left->srcInfo, exp);
auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->right, amt);
return sh;
}
exp = isPowerOf2(expr->right);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shl(expr->srcInfo, expr->left, amt);
auto amt = new IR::Constant(expr->right->srcInfo, exp);
auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
if (hasSideEffects(expr)) return expr;
Expand All @@ -254,8 +263,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::Div *expr) {
if (isOne(expr->right)) return expr->left;
auto exp = isPowerOf2(expr->right);
if (exp >= 0) {
auto amt = new IR::Constant(exp);
auto sh = new IR::Shr(expr->srcInfo, expr->left, amt);
auto amt = new IR::Constant(expr->right->srcInfo, exp);
auto sh = new IR::Shr(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
if (isZero(expr->left) && !hasSideEffects(expr->right)) return expr->left;
Expand All @@ -272,8 +281,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mod *expr) {
if (exp >= 0) {
big_int mask = 1;
mask = (mask << exp) - 1;
auto amt = new IR::Constant(expr->right->to<IR::Constant>()->type, mask);
auto sh = new IR::BAnd(expr->srcInfo, expr->left, amt);
auto amt =
new IR::Constant(expr->right->srcInfo, expr->right->to<IR::Constant>()->type, mask);
auto sh = new IR::BAnd(expr->srcInfo, expr->type, expr->left, amt);
return sh;
}
return expr;
Expand Down Expand Up @@ -301,7 +311,7 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mux *expr) {
if (isTrue(expr->e1) && isFalse(expr->e2))
return expr->e0;
else if (isFalse(expr->e1) && isTrue(expr->e2))
return new IR::LNot(expr->e0);
return new IR::LNot(expr->srcInfo, expr->type, expr->e0);
else if (const auto *lnot = expr->e0->to<IR::LNot>()) {
expr->e0 = lnot->expr;
const auto *tmp = expr->e1;
Expand Down Expand Up @@ -372,8 +382,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) {
expr->e0 = shift_of;
expr->e1 = new IR::Constant(hi + shift_amt);
expr->e2 = new IR::Constant(0);
return new IR::Concat(expr->type, expr,
new IR::Constant(IR::Type_Bits::get(-(lo + shift_amt)), 0));
return new IR::Concat(
expr->srcInfo, expr->type, expr,
new IR::Constant(expr->srcInfo, IR::Type_Bits::get(-(lo + shift_amt)), 0));
}
}

Expand All @@ -393,8 +404,11 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) {
else
break;
} else {
return new IR::Concat(expr->type, new IR::Slice(cat->left, expr->getH() - rwidth, 0),
new IR::Slice(cat->right, rwidth - 1, expr->getL()));
return new IR::Concat(
expr->srcInfo, expr->type,
// type of slice is calculated by its constructor
new IR::Slice(cat->left->srcInfo, cat->left, expr->getH() - rwidth, 0),
new IR::Slice(cat->right->srcInfo, cat->right, rwidth - 1, expr->getL()));
}
}

Expand Down
Loading