Skip to content

Commit

Permalink
[Relay] Support large constants saved/loaded outside of VM executable (
Browse files Browse the repository at this point in the history
…apache#9734)

* [Relay] Support large constants.

This allows constant tensors at or above a given byte limit to be marked as
'late bound' and saved/reloaded to a file independently of the overall
executable. Since the executable is often embedded in the data segment of
generated runtime Modules this avoids problems with external tools which can't
handle multi-gigabyte data segments.

[ACE-466 in OctoML JIRA]

* [checkpoint] fix latent bytecode/code bug
  • Loading branch information
mbs-octoml authored and ylc committed Jan 7, 2022
1 parent 3b9db46 commit 534b25a
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 56 deletions.
79 changes: 66 additions & 13 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,31 @@ class Executable : public ModuleNode {

/*!
* \brief Write the Executable to the binary stream in serialized form.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToBinary.
*
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

/*!
* \brief Write the Executable to the provided path as a file contianing its serialized content.
* \brief Write the Executable to the provided path as a file containing its serialized content.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToBinary.
*
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
* code section. This object must outlive the returned byte array.
*
* Late-bound constants (if any) must have already been saved by \p
* MoveLateBoundConstantsToBinary.
*
* \return The binary representation of the VM.
*/
Expand All @@ -90,13 +101,44 @@ class Executable : public ModuleNode {
/*!
* \brief Load the saved VM executable.
*
* Late-bound constants (if any) must then be loaded by \p LoadLateBoundConstantsFromBinary.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);

/*!
* \brief Returns the late-bound constants for the executable (if any) as a byte-stream.
* Leaves the executable's late-bound constants map empty. Only constants who's byte
* tensor size is greater than or equal to \p byte_limit are marked as late-bound. \p byte_limit
* may be zero.
*
* Must be called before \p SaveToBinary and friends if late-bound constants are
* desired. Otherwise can be ignore.
*/
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit);

/*!
* \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path.
*/
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
*
* Must be called after \p Load but before any other methods if \p MoveLateBoundConstantsToBinary
* was used when saving. Otherwise can be ignored.
*/
void LoadLateBoundConstantsFromStream(dmlc::Stream* stream);

/*!
* \brief As for \p LoadLateBoundConstantsFromStream, but load from file at \p path.
*/
void LoadLateBoundConstantsFromFile(const std::string& path);

/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
Expand Down Expand Up @@ -125,7 +167,7 @@ class Executable : public ModuleNode {
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
Expand Down Expand Up @@ -205,8 +247,19 @@ class Executable : public ModuleNode {
* shape-related data and code.
*/
int host_device_index = -1;
/*! \brief The global constant pool. */
/*!
* \brief The global constant array.
*
* LoadConst instructions indexes are w.r.t. this vector. Late-bound constants are removed
* from this table after saving late-bound constants.
*/
std::vector<ObjectRef> constants;
/*!
* \brief For each constant index the name of the late-bound constant, or null if constant is
* immediate. Only populated after loading executable but before loading late-bound constants.
*/
std::vector<String> late_bound_constant_names;

/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function's global name (as string) to the index that
Expand Down Expand Up @@ -238,9 +291,16 @@ class Executable : public ModuleNode {
/*!
* \brief Save the constant pool.
*
* \param strm The output stream.
* \param stream The output stream.
*/
void SaveConstantSection(dmlc::Stream* stream);

/*!
* \brief Load the constant pool.
*
* \param stream The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);
void LoadConstantSection(dmlc::Stream* stream);

/*!
* \brief Save primitive op names.
Expand Down Expand Up @@ -270,13 +330,6 @@ class Executable : public ModuleNode {
*/
void LoadGlobalSection(dmlc::Stream* strm);

/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);

/*!
* \brief Load primitive op names.
*
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief load the executable for the virtual machine.
* \param exec The executable.
*/
virtual void LoadExecutable(const Executable* exec);
virtual void LoadExecutable(Executable* exec);

protected:
/*! \brief Push a call frame on to the call stack. */
Expand Down Expand Up @@ -300,7 +300,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */
ObjectRef return_register_;
/*! \brief The executable the VM will operate on. */
const Executable* exec_;
Executable* exec_;
/*! \brief The function name to inputs mapping. */
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
/*!
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(self, mod):
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
self._move_late_bound_consts = self.mod["move_late_bound_consts"]
self._load_late_bound_consts = self.mod["load_late_bound_consts"]

def save(self):
"""Save the Relay VM Executable.
Expand Down Expand Up @@ -162,11 +164,11 @@ def load_exec(bytecode, lib):
An executable constructed using the provided artifacts.
"""
if isinstance(bytecode, (bytes, str)):
code = bytearray(bytecode)
bytecode = bytearray(bytecode)
elif not isinstance(bytecode, (bytearray, TVMByteArray)):
raise TypeError(
"bytecode is expected to be the type of bytearray "
+ "or TVMByteArray, but received {}".format(type(code))
+ "or TVMByteArray, but received {}".format(type(bytecode))
)

if lib is not None and not isinstance(lib, tvm.runtime.Module):
Expand Down Expand Up @@ -298,6 +300,14 @@ def get_function_params(self, func_name):
self._function_params[func_name] = params
return params

def move_late_bound_consts(self, path, byte_limit):
"""Move all constants of byte size greater or equal to byte_limit to file at path"""
return self._move_late_bound_consts(path, byte_limit)

def load_late_bound_consts(self, path):
"""Re-load constants previously saved to file at path"""
return self._load_late_bound_consts(path, bytes)


class VirtualMachine(object):
"""Relay VM runtime.
Expand Down
Loading

0 comments on commit 534b25a

Please sign in to comment.