Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Support large constants saved/loaded outside of VM executable #9734

Merged
merged 2 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,31 @@ class Executable : public ModuleNode {

/*!
* \brief Write the Executable to the binary stream in serialized form.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToStream or \p MoveLateBoundConstantsToFile.
*
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

/*!
* \brief Write the Executable to the provided path as a file contianing its serialized content.
* \brief Write the Executable to the provided path as a file containing its serialized content.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToStream or \p MoveLateBoundConstantsToFile.
*
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
* code section. This object must outlive the returned byte array.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToStream or \p MoveLateBoundConstantsToFile.
*
* \return The binary representation of the VM.
*/
Expand All @@ -90,13 +101,44 @@ class Executable : public ModuleNode {
/*!
* \brief Load the saved VM executable.
*
* Late-bound constants (if any) must then be loaded by \p LoadLateBoundConstantsFromStream
* or \p LoadLateBoundConstantsFromFile.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);

/*!
* \brief Returns the late-bound constants for the executable (if any) as a byte-stream.
* Leaves the executable's late-bound constants map empty. Only constants whose byte
* tensor size is greater than or equal to \p byte_limit are marked as late-bound. \p byte_limit
* may be zero.
*
* Must be called before \p SaveToBinary and friends if late-bound constants are
* desired. Otherwise can be ignored.
*/
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit);

/*!
* \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path.
*/
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
*
* Must be called after \p Load but before any other methods if \p MoveLateBoundConstantsToStream
* or \p MoveLateBoundConstantsToFile was used when saving. Otherwise can be ignored.
*/
void LoadLateBoundConstantsFromStream(dmlc::Stream* stream);

/*!
* \brief As for \p LoadLateBoundConstantsFromStream, but load from file at \p path.
*/
void LoadLateBoundConstantsFromFile(const std::string& path);

/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
Expand Down Expand Up @@ -125,7 +167,7 @@ class Executable : public ModuleNode {
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokePacked` instruction, etc.

*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doesn't need to handle it.
*/
Expand Down Expand Up @@ -205,8 +247,19 @@ class Executable : public ModuleNode {
* shape-related data and code.
*/
int host_device_index = -1;
/*! \brief The global constant pool. */
/*!
* \brief The global constant array.
*
* LoadConst instruction indices are relative to this vector. Late-bound constants are removed
* from this table after saving late-bound constants.
*/
std::vector<ObjectRef> constants;
/*!
* \brief For each constant index the name of the late-bound constant, or null if constant is
* immediate. Only populated after loading executable but before loading late-bound constants.
*/
std::vector<String> late_bound_constant_names;

/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function's global name (as string) to the index that
Expand Down Expand Up @@ -238,9 +291,16 @@ class Executable : public ModuleNode {
/*!
* \brief Save the constant pool.
*
* \param strm The output stream.
* \param stream The output stream.
*/
void SaveConstantSection(dmlc::Stream* stream);

/*!
* \brief Load the constant pool.
*
* \param stream The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);
void LoadConstantSection(dmlc::Stream* stream);

/*!
* \brief Save primitive op names.
Expand Down Expand Up @@ -270,13 +330,6 @@ class Executable : public ModuleNode {
*/
void LoadGlobalSection(dmlc::Stream* strm);

/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);

/*!
* \brief Load primitive op names.
*
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief load the executable for the virtual machine.
* \param exec The executable.
*/
virtual void LoadExecutable(const Executable* exec);
virtual void LoadExecutable(Executable* exec);

protected:
/*! \brief Push a call frame on to the call stack. */
Expand Down Expand Up @@ -300,7 +300,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */
ObjectRef return_register_;
/*! \brief The executable the VM will operate on. */
const Executable* exec_;
Executable* exec_;
/*! \brief The function name to inputs mapping. */
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
/*!
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(self, mod):
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
self._move_late_bound_consts = self.mod["move_late_bound_consts"]
self._load_late_bound_consts = self.mod["load_late_bound_consts"]

def save(self):
"""Save the Relay VM Executable.
Expand Down Expand Up @@ -162,11 +164,11 @@ def load_exec(bytecode, lib):
An executable constructed using the provided artifacts.
"""
if isinstance(bytecode, (bytes, str)):
code = bytearray(bytecode)
bytecode = bytearray(bytecode)
elif not isinstance(bytecode, (bytearray, TVMByteArray)):
raise TypeError(
"bytecode is expected to be the type of bytearray "
+ "or TVMByteArray, but received {}".format(type(code))
+ "or TVMByteArray, but received {}".format(type(bytecode))
)

if lib is not None and not isinstance(lib, tvm.runtime.Module):
Expand Down Expand Up @@ -298,6 +300,14 @@ def get_function_params(self, func_name):
self._function_params[func_name] = params
return params

def move_late_bound_consts(self, path, byte_limit):
    """Move every constant whose byte size is at least byte_limit to the file at path.

    Parameters
    ----------
    path : str
        Destination file for the moved constants.
    byte_limit : int
        Size threshold in bytes; constants of this size or larger are moved.

    Returns
    -------
    Whatever the wrapped "move_late_bound_consts" module function returns.
    """
    mover = self._move_late_bound_consts
    return mover(path, byte_limit)

def load_late_bound_consts(self, path):
    """Re-load constants previously saved to file at path.

    Parameters
    ----------
    path : str
        Location of the file the late-bound constants were saved to.

    Returns
    -------
    Whatever the wrapped "load_late_bound_consts" module function returns.
    """
    # Forward the path only. The builtin ``bytes`` type is not an argument of
    # this method and must not be passed on to the loader.
    return self._load_late_bound_consts(path)


class VirtualMachine(object):
"""Relay VM runtime.
Expand Down
Loading