diff --git a/backends/p4tools/common/compiler/convert_varbits.h b/backends/p4tools/common/compiler/convert_varbits.h index 330c18190a..42fe678246 100644 --- a/backends/p4tools/common/compiler/convert_varbits.h +++ b/backends/p4tools/common/compiler/convert_varbits.h @@ -10,8 +10,8 @@ namespace P4Tools { -/// Converts all existing Type_Varbit types in the program into a custom Size_Type_Varbit type. -/// Sized_Type_Varbit also contains information about the width that was assigned to the type by +/// Converts all existing Type_Varbit types in the program into a custom Extracted_Varbit type. +/// Extracted_Varbit also contains information about the width that was assigned to the type by /// the extract call. class ConvertVarbits : public Transform { public: diff --git a/backends/p4tools/common/lib/symbolic_env.cpp b/backends/p4tools/common/lib/symbolic_env.cpp index 95d9d6888c..0578338f6e 100644 --- a/backends/p4tools/common/lib/symbolic_env.cpp +++ b/backends/p4tools/common/lib/symbolic_env.cpp @@ -8,7 +8,6 @@ #include #include "backends/p4tools/common/lib/model.h" -#include "frontends/p4/optimizeExpressions.h" #include "ir/indexed_vector.h" #include "ir/vector.h" #include "ir/visitor.h" @@ -28,7 +27,9 @@ const IR::Expression *SymbolicEnv::get(const IR::StateVariable &var) const { bool SymbolicEnv::exists(const IR::StateVariable &var) const { return map.find(var) != map.end(); } void SymbolicEnv::set(const IR::StateVariable &var, const IR::Expression *value) { - map[var] = P4::optimizeExpression(value); + BUG_CHECK(value->type && !value->type->is(), + "Cannot set value with unspecified type: %1%", value); + map[var] = value; } const IR::Expression *SymbolicEnv::subst(const IR::Expression *expr) const { diff --git a/backends/p4tools/modules/testgen/core/small_step/extern_stepper.cpp b/backends/p4tools/modules/testgen/core/small_step/extern_stepper.cpp index 9cc12ece11..6cbca2acd1 100644 --- a/backends/p4tools/modules/testgen/core/small_step/extern_stepper.cpp +++ b/backends/p4tools/modules/testgen/core/small_step/extern_stepper.cpp @@ -37,7 +37,9 @@ std::vector> ExprStepper::s ExecutionState &nextState, const std::vector &flatFields, int varBitFieldSize) { std::vector> fields; - for (const auto &fieldRef : flatFields) { + // Make a copy of the StateVariable so it can be modified in the varbit case (and it is just a + // pointer wrapper anyway). + for (IR::StateVariable fieldRef : flatFields) { const auto *fieldType = fieldRef->type; // If the header had a varbit, the header needs to be updated. // We assign @param varbitFeldSize to the varbit field. @@ -62,15 +64,12 @@ std::vector> ExprStepper::s // We need to cast the generated variable to the appropriate type. if (fieldType->is()) { pktVar = new IR::Cast(fieldType, pktVar); - // Update the field and add the field to the return list. - // TODO: Better way to handle varbits here? - auto *newRef = fieldRef->clone(); - newRef->type = fieldType; - nextState.set(fieldRef, pktVar); - fields.emplace_back(fieldRef, pktVar); - continue; - } - if (const auto *bits = fieldType->to()) { + // Rewrite the type of the field so it matches the extracted varbit type. + // TODO: is there a better way to do this? + auto *newRefExpr = fieldRef->clone(); + newRefExpr->type = fieldType; + fieldRef.ref = newRefExpr; + } else if (const auto *bits = fieldType->to()) { if (bits->isSigned) { pktVar = new IR::Cast(fieldType, pktVar); } diff --git a/backends/p4tools/modules/testgen/core/small_step/table_stepper.cpp b/backends/p4tools/modules/testgen/core/small_step/table_stepper.cpp index a1a7740bc1..037f3317a1 100644 --- a/backends/p4tools/modules/testgen/core/small_step/table_stepper.cpp +++ b/backends/p4tools/modules/testgen/core/small_step/table_stepper.cpp @@ -149,7 +149,7 @@ const IR::Expression *TableStepper::computeHit(TableMatchMap *matches) { const IR::StringLiteral *TableStepper::getTableActionString( const IR::MethodCallExpression *actionCall) { cstring actionName = actionCall->method->toString(); - return new IR::StringLiteral(actionName); + return new IR::StringLiteral(IR::Type_String::get(), actionName); } const IR::Expression *TableStepper::evalTableConstEntries() { diff --git a/backends/p4tools/modules/testgen/lib/execution_state.cpp b/backends/p4tools/modules/testgen/lib/execution_state.cpp index 0b3996e4e2..f34c534fe7 100644 --- a/backends/p4tools/modules/testgen/lib/execution_state.cpp +++ b/backends/p4tools/modules/testgen/lib/execution_state.cpp @@ -21,6 +21,7 @@ #include "backends/p4tools/common/lib/taint.h" #include "backends/p4tools/common/lib/trace_event.h" #include "backends/p4tools/common/lib/variables.h" +#include "frontends/p4/optimizeExpressions.h" #include "ir/id.h" #include "ir/indexed_vector.h" #include "ir/irutils.h" @@ -178,10 +179,37 @@ void ExecutionState::markVisited(const IR::Node *node) { const P4::Coverage::CoverageSet &ExecutionState::getVisited() const { return visitedNodes; } +/// Compare types, considering Extracted_Varbit and bits equal if the (real/extracted) sizes are +/// equal. This is because the packet expression can be something like 0 ++ +/// (Extracted_Varbit)pkt_var. This expression is typed as bit, but the optimizer removes the +/// 0 ++ and makes it into Extracted_Varbit type. +/// TODO: Maybe there is a better way to handle varbit that could allow us to avoid this. +static bool typeEquivSansVarbit(const IR::Type *a, const IR::Type *b) { + if (a->equiv(*b)) { + return true; + } + const auto *abit = a->to(); + const auto *avar = a->to(); + const auto *bbit = b->to(); + const auto *bvar = b->to(); + return (abit && bvar && abit->width_bits() == bvar->width_bits()) || + (avar && bbit && avar->width_bits() == bbit->width_bits()); +} + void ExecutionState::set(const IR::StateVariable &var, const IR::Expression *value) { + const auto *type = value->type; + BUG_CHECK(type && !type->is(), "Cannot set value with unspecified type: %1%", + value); if (getProperty("inUndefinedState")) { // If we are in an undefined state, the variable we set is tainted. - value = ToolsVariables::getTaintExpression(value->type); + value = ToolsVariables::getTaintExpression(type); + } else { + value = P4::optimizeExpression(value); + BUG_CHECK(value->type && !value->type->is(), + "The P4 expression optimizer stripped a type of %1% (was %2%)", value, type); + BUG_CHECK(typeEquivSansVarbit(type, value->type), + "The P4 expression optimizer had changed type of %1% (%2% -> %3%)", value, type, + value->type); } env.set(var, value); } diff --git a/frontends/common/constantFolding.cpp b/frontends/common/constantFolding.cpp index d4fa416314..f05d6d1dfc 100644 --- a/frontends/common/constantFolding.cpp +++ b/frontends/common/constantFolding.cpp @@ -425,9 +425,9 @@ const IR::Node *DoConstantFolding::compare(const IR::Operation_Binary *e) { auto ri = rlist->components.at(i); const IR::Operation_Binary *tmp; if (eqTest) - tmp = new IR::Equ(li, ri); + tmp = new IR::Equ(IR::Type_Boolean::get(), li, ri); else - tmp = new IR::Neq(li, ri); + tmp = new IR::Neq(IR::Type_Boolean::get(), li, ri); auto cmp = compare(tmp); auto boolLit = cmp->to(); if (boolLit == nullptr) return e; @@ -960,7 +960,8 @@ const IR::Node *DoConstantFolding::postorder(IR::SelectExpression *expression) { finished = true; if (someUnknown) { if (!c->keyset->is()) changes = true; - auto newc = new IR::SelectCase(c->srcInfo, new IR::DefaultExpression(), c->state); + auto newc = new IR::SelectCase( + c->srcInfo, new IR::DefaultExpression(expression->select->type), c->state); cases.push_back(newc); } else { // This is the result. diff --git a/frontends/p4/strengthReduction.cpp b/frontends/p4/strengthReduction.cpp index f80ff1cdc6..af24c67860 100644 --- a/frontends/p4/strengthReduction.cpp +++ b/frontends/p4/strengthReduction.cpp @@ -81,7 +81,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::BAnd *expr) { auto l = expr->left->to(); auto r = expr->right->to(); if (l && r) - return new IR::Cmpl(expr->type, new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr)); + return new IR::Cmpl(expr->srcInfo, expr->type, + new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr)); if (hasSideEffects(expr)) return expr; if (isZero(expr->left)) return expr->left; @@ -95,7 +96,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::BOr *expr) { if (isZero(expr->right)) return expr->left; auto l = expr->left->to(); auto r = expr->right->to(); - if (l && r) return new IR::Cmpl(new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr)); + if (l && r) + return new IR::Cmpl(expr->srcInfo, expr->type, + new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr)); if (hasSideEffects(expr)) return expr; if (expr->left->equiv(*expr->right)) return expr->left; return expr; @@ -143,15 +146,15 @@ const IR::Node *DoStrengthReduction::postorder(IR::Equ *expr) { if (isTrue(expr->left)) return expr->right; if (isTrue(expr->right)) return expr->left; // a == false is the same as !a - if (isFalse(expr->left)) return new IR::LNot(expr->right); - if (isFalse(expr->right)) return new IR::LNot(expr->left); + if (isFalse(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right); + if (isFalse(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left); return expr; } const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) { // a != true is the same as !a - if (isTrue(expr->left)) return new IR::LNot(expr->right); - if (isTrue(expr->right)) return new IR::LNot(expr->left); + if (isTrue(expr->left)) return new IR::LNot(expr->srcInfo, expr->type, expr->right); + if (isTrue(expr->right)) return new IR::LNot(expr->srcInfo, expr->type, expr->left); // a != false is the same as a if (isFalse(expr->left)) return expr->right; if (isFalse(expr->right)) return expr->left; @@ -160,12 +163,18 @@ const IR::Node *DoStrengthReduction::postorder(IR::Neq *expr) { const IR::Node *DoStrengthReduction::postorder(IR::LNot *expr) { if (auto e = expr->expr->to()) return e->expr; - if (auto e = expr->expr->to()) return new IR::Neq(e->left, e->right); - if (auto e = expr->expr->to()) return new IR::Equ(e->left, e->right); - if (auto e = expr->expr->to()) return new IR::Grt(e->left, e->right); - if (auto e = expr->expr->to()) return new IR::Lss(e->left, e->right); - if (auto e = expr->expr->to()) return new IR::Geq(e->left, e->right); - if (auto e = expr->expr->to()) return new IR::Leq(e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Neq(expr->srcInfo, expr->type, e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Equ(expr->srcInfo, expr->type, e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Grt(expr->srcInfo, expr->type, e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Lss(expr->srcInfo, expr->type, e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Geq(expr->srcInfo, expr->type, e->left, e->right); + if (auto e = expr->expr->to()) + return new IR::Leq(expr->srcInfo, expr->type, e->left, e->right); return expr; } @@ -176,13 +185,13 @@ const IR::Node *DoStrengthReduction::postorder(IR::Sub *expr) { if (expr->right->is()) { auto cst = expr->right->to(); auto neg = new IR::Constant(cst->srcInfo, cst->type, -cst->value, cst->base, true); - auto result = new IR::Add(expr->srcInfo, expr->left, neg); + auto result = new IR::Add(expr->srcInfo, expr->type, expr->left, neg); return result; } if (hasSideEffects(expr)) return expr; if (expr->left->equiv(*expr->right) && expr->left->type && !expr->left->type->is()) - return new IR::Constant(expr->left->type, 0); + return new IR::Constant(expr->srcInfo, expr->left->type, 0); return expr; } @@ -230,14 +239,14 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mul *expr) { if (isOne(expr->right)) return expr->left; auto exp = isPowerOf2(expr->left); if (exp >= 0) { - auto amt = new IR::Constant(exp); - auto sh = new IR::Shl(expr->srcInfo, expr->right, amt); + auto amt = new IR::Constant(expr->left->srcInfo, exp); + auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->right, amt); return sh; } exp = isPowerOf2(expr->right); if (exp >= 0) { - auto amt = new IR::Constant(exp); - auto sh = new IR::Shl(expr->srcInfo, expr->left, amt); + auto amt = new IR::Constant(expr->right->srcInfo, exp); + auto sh = new IR::Shl(expr->srcInfo, expr->type, expr->left, amt); return sh; } if (hasSideEffects(expr)) return expr; @@ -254,8 +263,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::Div *expr) { if (isOne(expr->right)) return expr->left; auto exp = isPowerOf2(expr->right); if (exp >= 0) { - auto amt = new IR::Constant(exp); - auto sh = new IR::Shr(expr->srcInfo, expr->left, amt); + auto amt = new IR::Constant(expr->right->srcInfo, exp); + auto sh = new IR::Shr(expr->srcInfo, expr->type, expr->left, amt); return sh; } if (isZero(expr->left) && !hasSideEffects(expr->right)) return expr->left; @@ -272,8 +281,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mod *expr) { if (exp >= 0) { big_int mask = 1; mask = (mask << exp) - 1; - auto amt = new IR::Constant(expr->right->to()->type, mask); - auto sh = new IR::BAnd(expr->srcInfo, expr->left, amt); + auto amt = + new IR::Constant(expr->right->srcInfo, expr->right->to()->type, mask); + auto sh = new IR::BAnd(expr->srcInfo, expr->type, expr->left, amt); return sh; } return expr; @@ -301,7 +311,7 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mux *expr) { if (isTrue(expr->e1) && isFalse(expr->e2)) return expr->e0; else if (isFalse(expr->e1) && isTrue(expr->e2)) - return new IR::LNot(expr->e0); + return new IR::LNot(expr->srcInfo, expr->type, expr->e0); else if (const auto *lnot = expr->e0->to()) { expr->e0 = lnot->expr; const auto *tmp = expr->e1; @@ -372,8 +382,9 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { expr->e0 = shift_of; expr->e1 = new IR::Constant(hi + shift_amt); expr->e2 = new IR::Constant(0); - return new IR::Concat(expr->type, expr, - new IR::Constant(IR::Type_Bits::get(-(lo + shift_amt)), 0)); + return new IR::Concat( + expr->srcInfo, expr->type, expr, + new IR::Constant(expr->srcInfo, IR::Type_Bits::get(-(lo + shift_amt)), 0)); } } @@ -393,8 +404,11 @@ const IR::Node *DoStrengthReduction::postorder(IR::Slice *expr) { else break; } else { - return new IR::Concat(expr->type, new IR::Slice(cat->left, expr->getH() - rwidth, 0), - new IR::Slice(cat->right, rwidth - 1, expr->getL())); + return new IR::Concat( + expr->srcInfo, expr->type, + // type of slice is calculated by its constructor + new IR::Slice(cat->left->srcInfo, cat->left, expr->getH() - rwidth, 0), + new IR::Slice(cat->right->srcInfo, cat->right, rwidth - 1, expr->getL())); } }