Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
Unify name mangling in TVM (apache#12066)
Browse files Browse the repository at this point in the history
* Add NameSupply and GlobalVarSupply

* Build GlobalVarSupply from IRModules instead of having it attached to an IRModule.

* Pass GlobalVarSupply when lowering shape funcs

* Partially replace instantiations of GlobalVar with GlobalVarSupply

* Construct GlobalVarSupply from IRModule

* Add tests for supply

* Add documentation for NameSupply and GlobalVarSupply

Co-authored-by: Florin-Gabriel Blanaru <fgb@system76-pc.localdomain>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent 2a7fca5 commit 98b0ee2
Show file tree
Hide file tree
Showing 35 changed files with 1,052 additions and 350 deletions.
12 changes: 9 additions & 3 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#ifndef TVM_DRIVER_DRIVER_API_H_
#define TVM_DRIVER_DRIVER_API_H_

#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
Expand Down Expand Up @@ -99,14 +100,15 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/

TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
Expand All @@ -115,13 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
Expand All @@ -130,10 +133,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module and when creating
* GlobalVars.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
Expand Down
125 changes: 125 additions & 0 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/global_var_supply.h
* \brief GlobalVarSupply that can be used to generate unique \class GlobalVar.
*/
#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
#define TVM_IR_GLOBAL_VAR_SUPPLY_H_

#include <string>
#include <unordered_map>

#include "tvm/ir/expr.h"
#include "tvm/ir/module.h"
#include "tvm/ir/name_supply.h"

namespace tvm {

/*!
* \brief GlobalVarSupply can be used to generate unique GlobalVars.
*/
class GlobalVarSupplyNode : public Object {
public:
/*!
* \brief Empty constructor. Will use an empty NameSupply.
*/
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}

/*!
* \brief Constructor.
* \param name_supply The NameSupply to use for generating the names of fresh GlobalVars.
* \param name_to_var_map An optional map.
*/
explicit GlobalVarSupplyNode(NameSupply name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
* \brief Generates a unique GlobalVar from this supply.
* \param name The name from which the name of the GlobalVar is derived.
* \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended
* to the name. \return A unique GlobalVar.
*/
GlobalVar FreshGlobal(String name, bool add_prefix = true);

/*!
* \brief Looks up for a GlobalVar with the given name in this supply.
* If no entry is found, creates one, places it in the cache and returns it.
* \param name The name of the GlobalVar to search for.
* \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to
* the name before performing the search. \return A cached GlobalVar.
*/
GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

/*!
* \brief Reserves an existing GlobalVar with this supply.
* \param var The GlobalVar to be registered.
* \param allow_conflict Allow conflict with other GlobalVars that have the same name.
*/
void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);

void VisitAttrs(AttrVisitor* v) {}

/*! \brief The NameSupply used to generate unique name hints to GlobalVars. */
NameSupply name_supply_;

static constexpr const char* _type_key = "GlobalVarSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);

private:
std::unordered_map<std::string, GlobalVar> name_to_var_map_;
};

/*!
* \brief Managed reference class to GlobalVarSupplyNode.
* \sa GlobalVarSupplyNode
*/
class GlobalVarSupply : public ObjectRef {
public:
/*!
* \brief Constructor.
* \param name_supply The NameSupply to be used when generating new GlobalVars.
* \param name_to_var_map An optional map.
*/
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
* \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are
* guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array
* of IRModules.
*/
TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);

/*!
* \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are
* guaranteed not to conflict with GlobalVars that belong to the modules. \param module The
* IRModule.
*/
TVM_DLL explicit GlobalVarSupply(const IRModule module);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_
17 changes: 9 additions & 8 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,6 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
*
* \param name The original name.
* \return Updated name which is unique.
*/
String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand Down Expand Up @@ -481,6 +473,15 @@ namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Name of the module
*
* Type: String
*
* \sa tvm::runtime::String
*/
constexpr const char* kModuleName = "mod_name";

/*!
* \brief Executor targeted by the module
*
Expand Down
123 changes: 123 additions & 0 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/name_supply.h
* \brief NameSupply that can be used to generate unique variable names.
*/
#ifndef TVM_IR_NAME_SUPPLY_H_
#define TVM_IR_NAME_SUPPLY_H_

#include <string>
#include <unordered_map>
#include <utility>

#include "tvm/ir/expr.h"

namespace tvm {

/*!
* \brief NameSupply can be used to generate unique names.
*/
class NameSupplyNode : public Object {
public:
/*!
* \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro.
*/
NameSupplyNode() = default;

/*!
* \brief Constructor.
* \param prefix The prefix to be used with this NameSupply.
* \param name_map The map used to guarantee uniqueness.
*/
NameSupplyNode(const String& prefix, std::unordered_map<std::string, int> name_map)
: prefix_(prefix), name_map(std::move(name_map)) {}

/*!
* \brief Generates a unique name from this NameSupply.
* \param name The name from which the generated name is derived.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name. \return A unique name.
*/
String FreshName(const String& name, bool add_prefix = true);

/*!
* \brief Reserves an existing name with this NameSupply.
* \param name The name to be reserved.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before reserving it. \return The name that was reserved with the NameSupply. It can be
* different if a prefix is added.
*/
String ReserveName(const String& name, bool add_prefix = true);

/*!
* \brief Checks if this NameSupply already generated a name.
* \param name The name to check.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before checking for it. \return True if the name has already been generated. False
* otherwise.
*/
bool ContainsName(const String& name, bool add_prefix = true);

void VisitAttrs(AttrVisitor* v) {}

// Prefix for all GlobalVar names. It can be empty.
std::string prefix_;

static constexpr const char* _type_key = "NameSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);

private:
/*! \brief Helper function to add the NameSupply prefix to the name. */
String add_prefix_to_name(const String& name);

/*!
* \brief Function that will generate a unique name.
* \param name The name to be used as a base.
* \return A unique name.
*/
std::string GetUniqueName(std::string name);

/*! \brief A map that is used to generate unique names. */
std::unordered_map<std::string, int> name_map;
};

/*!
* \brief Managed reference class to NameSupplyNode.
* \sa NameSupplyNode
*/
class NameSupply : public ObjectRef {
public:
/*!
* \brief Constructor.
* \param prefix The prefix to be used with this NameSupply.
* \param name_map An optional map.
*/
TVM_DLL explicit NameSupply(const String& prefix,
std::unordered_map<std::string, int> name_map = {});

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_NAME_SUPPLY_H_
Loading

0 comments on commit 98b0ee2

Please sign in to comment.