Skip to content

Commit

Permalink
Strips the op name suffix added from `PopulateFunctionalToRegionPatte…
Browse files Browse the repository at this point in the history
…rns` inside `PopulateRegionToFunctionalPatterns` to reduce model size and improve debugging experience.

PiperOrigin-RevId: 532845336
  • Loading branch information
tensorflower-gardener committed May 17, 2023
1 parent d3d9f17 commit 9ffdba6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/transforms/region_to_functional/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"//tensorflow/core/ir:Dialect",
"//tensorflow/core/ir/types:Dialect",
"//tensorflow/core/transforms:utils",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
21 changes: 21 additions & 0 deletions tensorflow/core/transforms/region_to_functional/impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -151,6 +153,9 @@ class BasePattern {
ArrayAttr GetControlRetAttrs(ValueRange ctls, ValueRange args,
NameUniquer *name_uniquer) const;

// Strip out added names.
void StripAddedSuffix(Region &region) const;

// Create a function with the given name and attributes. Use the types of the
// block arguments and the given results types. Take the body of the region.
GraphFuncOp CreateFunc(Location loc, const Twine &sym_name, Region &region,
Expand Down Expand Up @@ -598,6 +603,20 @@ ArrayAttr BasePattern::GetControlRetAttrs(ValueRange ctls, ValueRange args,
return ArrayAttr::get(ctx_, ctl_ret_attrs);
}

void BasePattern::StripAddedSuffix(Region &region) const {
StringAttr name_id = dialect_.getNameAttrIdentifier();
for (Operation &op : region.getOps()) {
if (auto name = op.getAttrOfType<StringAttr>(name_id)) {
if (absl::StrContains(name.getValue().str(), "_tfg_inlined_")) {
std::vector<std::string> name_components =
absl::StrSplit(name.getValue().str(), "_tfg_inlined_");
auto new_name = StringAttr::get(op.getContext(), name_components[0]);
op.setAttr(name_id, new_name);
}
}
}
}

GraphFuncOp BasePattern::CreateFunc(Location loc, const Twine &sym_name,
Region &region, TypeRange res_types,
NamedAttrList attrs) const {
Expand Down Expand Up @@ -644,6 +663,8 @@ FuncAttr BasePattern::Outline(Operation *op, PatternRewriter &rewriter,
ValueRange args, Region &region,
RegionAttr preserved, DictionaryAttr attrs,
const Twine &func_name) const {
StripAddedSuffix(region);

// Create a name scope for the function.
NameUniquer name_uniquer(ctx_);

Expand Down

0 comments on commit 9ffdba6

Please sign in to comment.