Skip to content

Commit

Permalink
Avoid copying out/inout args when inlining (#4877)
Browse files Browse the repository at this point in the history
- If the actual argument for an out or inout argument is local to the
  calling functions, we can use it directly in the inlined function
  instead of introducing extra copies in and out.

Signed-off-by: Chris Dodd <cdodd@nvidia.com>
  • Loading branch information
ChrisDodd committed Sep 20, 2024
1 parent d63f2e8 commit 6497f39
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 34 deletions.
74 changes: 61 additions & 13 deletions frontends/p4/functionsInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,48 @@ Visitor::profile_t FunctionsInliner::init_apply(const IR::Node *node) {
return rv;
}

class FunctionsInliner::isLocalExpression : public Inspector, ResolutionContext {
bool done = false;
bool result = true;

profile_t init_apply(const IR::Node *node) override {
BUG_CHECK(!done, "isLocalExpression can only be applied once");
return Inspector::init_apply(node);
}
void end_apply() override { done = true; }
bool preorder(const IR::Node *) override { return result; }
bool preorder(const IR::Path *p) override {
if (p->absolute) return result = false;
// cribbed from ResolutionContext::resolve -- we want to resolve the name,
// except we want to know what scope it is found in, not what it resolves to
const Context *ctxt = nullptr;
while (auto scope = findOrigCtxt<IR::INamespace>(ctxt)) {
if (scope->is<IR::P4Control>() || scope->is<IR::P4Parser>() ||
scope->is<IR::P4Program>() || scope->is<IR::V1Control>() ||
scope->is<IR::V1Parser>() || scope->is<IR::V1Program>()) {
// these are "global" things that may contain functions
return result = false;
}
if (!lookup(scope, p->name, P4::ResolutionType::Any).empty()) return result;
if (scope->is<IR::Function>() || scope->is<IR::P4Action>()) {
// no need to look further as can't have nested functions
return result = false;
}
}
BUG("failed to reach global scope");
return result;
}

public:
isLocalExpression(const IR::Expression *expr, const Visitor_Context *ctxt) {
expr->apply(*this, ctxt);
}
explicit operator bool() {
BUG_CHECK(done, "isLocalExpression not computed");
return result;
}
};

bool FunctionsInliner::preCaller() {
LOG2("Visiting: " << dbp(getOriginal()));
if (toInline->sites.count(getOriginal()) == 0) {
Expand Down Expand Up @@ -195,26 +237,35 @@ const IR::Statement *FunctionsInliner::inlineBefore(const IR::Node *calleeNode,
BUG_CHECK(callee, "%1%: expected a function", calleeNode);

IR::IndexedVector<IR::StatOrDecl> body;
ParameterSubstitution subst;
ParameterSubstitution subst; // rewrites for params
TypeVariableSubstitution tvs; // empty

std::map<const IR::Parameter *, cstring> paramRename;
ParameterSubstitution substitution;
ParameterSubstitution substitution; // map params to actual arguments
substitution.populate(callee->type->parameters, mce->arguments);

// parameters that need copyout
std::vector<std::pair<cstring, const IR::Argument *>> needCopyout;

// evaluate in and inout parameters in order
for (auto param : callee->type->parameters->parameters) {
auto argument = substitution.lookup(param);
cstring newName = nameGen->newName(param->name.name.string_view());
paramRename.emplace(param, newName);
if (param->direction == IR::Direction::In || param->direction == IR::Direction::InOut) {
if ((param->direction == IR::Direction::Out || param->direction == IR::Direction::InOut) &&
isLocalExpression(argument->expression, getChildContext())) {
// If the actual parameter is local to the caller, we can just rewrite the callee
// to access it directly, without the overhead of copying it in or out
subst.add(param, argument);
} else if (param->direction == IR::Direction::In ||
param->direction == IR::Direction::InOut) {
auto vardecl = new IR::Declaration_Variable(newName, param->annotations, param->type);
body.push_back(vardecl);
auto copyin =
new IR::AssignmentStatement(new IR::PathExpression(newName), argument->expression);
body.push_back(copyin);
subst.add(param, new IR::Argument(argument->srcInfo, argument->name,
new IR::PathExpression(newName)));
if (param->direction == IR::Direction::InOut)
needCopyout.emplace_back(newName, argument);
} else if (param->direction == IR::Direction::None) {
// This works because there can be no side-effects in the evaluation of this
// argument.
Expand All @@ -225,6 +276,7 @@ const IR::Statement *FunctionsInliner::inlineBefore(const IR::Node *calleeNode,
subst.add(param, new IR::Argument(argument->srcInfo, argument->name,
new IR::PathExpression(newName)));
body.push_back(vardecl);
needCopyout.emplace_back(newName, argument);
}
}

Expand All @@ -242,14 +294,10 @@ const IR::Statement *FunctionsInliner::inlineBefore(const IR::Node *calleeNode,
auto retExpr = cloneBody(funclone->body->components, body);

// copy out and inout parameters
for (auto param : callee->type->parameters->parameters) {
auto left = substitution.lookup(param);
if (param->direction == IR::Direction::InOut || param->direction == IR::Direction::Out) {
cstring newName = ::P4::get(paramRename, param);
auto right = new IR::PathExpression(newName);
auto copyout = new IR::AssignmentStatement(left->expression, right);
body.push_back(copyout);
}
for (auto [newName, argument] : needCopyout) {
auto right = new IR::PathExpression(newName);
auto copyout = new IR::AssignmentStatement(argument->expression, right);
body.push_back(copyout);
}

if (auto assign = statement->to<IR::AssignmentStatement>()) {
Expand Down
1 change: 1 addition & 0 deletions frontends/p4/functionsInlining.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class FunctionsInliner : public AbstractInliner<FunctionsInlineList, FunctionsIn
const IR::Node *postCaller(const IR::Node *caller);
const ReplacementMap *getReplacementMap() const;
void dumpReplacementMap() const;
class isLocalExpression; // functor to test actual arguments scope use

public:
FunctionsInliner() = default;
Expand Down
17 changes: 13 additions & 4 deletions frontends/p4/parameterSubstitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,21 @@ class ParameterSubstitution : public IHasDbPrint {
}

void dbprint(std::ostream &out) const {
bool brief = (DBPrint::dbgetflags(out) & DBPrint::Brief);
if (paramList != nullptr) {
for (auto s : *paramList->getEnumerator())
out << dbp(s) << "=>" << dbp(lookup(s)) << std::endl;
if (!brief) out << "paramList:" << Log::endl;
for (auto s : *paramList->getEnumerator()) {
out << dbp(s) << "=>" << dbp(lookup(s));
if (!brief) out << " " << lookup(s);
out << Log::endl;
}
} else {
for (auto s : parametersByName)
out << dbp(s.second) << "=>" << dbp(lookupByName(s.first)) << std::endl;
if (!brief) out << "parametersByName:" << Log::endl;
for (auto s : parametersByName) {
out << dbp(s.second) << "=>" << dbp(lookupByName(s.first));
if (!brief) out << " " << lookupByName(s.first);
out << Log::endl;
}
}
}

Expand Down
10 changes: 2 additions & 8 deletions testdata/p4_16_samples_outputs/issue2345-2-frontend.p4
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,15 @@ parser p(packet_in pkt, out Headers hdr, inout Meta m, inout standard_metadata_t

control ingress(inout Headers h, inout Meta m, inout standard_metadata_t sm) {
@name("ingress.val_0") Headers val;
@name("ingress.val1_0") Headers val1;
@name("ingress.val1_1") Headers val1_2;
@name("ingress.simple_action") action simple_action() {
if (h.eth_hdr.eth_type == 16w1) {
;
} else {
h.eth_hdr.src_addr = 48w1;
val = h;
val1 = val;
val1.eth_hdr.dst_addr = val1.eth_hdr.dst_addr + 48w3;
val = val1;
val.eth_hdr.dst_addr = val.eth_hdr.dst_addr + 48w3;
val.eth_hdr.eth_type = 16w2;
val1_2 = val;
val1_2.eth_hdr.dst_addr = val1_2.eth_hdr.dst_addr + 48w3;
val = val1_2;
val.eth_hdr.dst_addr = val.eth_hdr.dst_addr + 48w3;
h = val;
h.eth_hdr.dst_addr = h.eth_hdr.dst_addr + 48w4;
}
Expand Down
12 changes: 3 additions & 9 deletions testdata/p4_16_samples_outputs/issue2345-2-midend.p4
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,16 @@ parser p(packet_in pkt, out Headers hdr, inout Meta m, inout standard_metadata_t

control ingress(inout Headers h, inout Meta m, inout standard_metadata_t sm) {
ethernet_t val_eth_hdr;
ethernet_t val1_eth_hdr;
ethernet_t val1_2_eth_hdr;
@name("ingress.simple_action") action simple_action() {
if (h.eth_hdr.eth_type == 16w1) {
;
} else {
h.eth_hdr.src_addr = 48w1;
val_eth_hdr = h.eth_hdr;
val1_eth_hdr = h.eth_hdr;
val1_eth_hdr.dst_addr = val1_eth_hdr.dst_addr + 48w3;
val_eth_hdr = val1_eth_hdr;
val_eth_hdr.dst_addr = val_eth_hdr.dst_addr + 48w3;
val_eth_hdr.eth_type = 16w2;
val1_2_eth_hdr = val_eth_hdr;
val1_2_eth_hdr.dst_addr = val1_2_eth_hdr.dst_addr + 48w3;
val_eth_hdr = val1_2_eth_hdr;
h.eth_hdr = val1_2_eth_hdr;
val_eth_hdr.dst_addr = val_eth_hdr.dst_addr + 48w3;
h.eth_hdr = val_eth_hdr;
h.eth_hdr.dst_addr = h.eth_hdr.dst_addr + 48w4;
}
}
Expand Down

0 comments on commit 6497f39

Please sign in to comment.