Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
Add methods to get and set late-bound constants. (apache#12664)
Browse files Browse the repository at this point in the history
* Add methods to read and restore late-bound constants on Executable.

* Add bindings for new functions

* Cleanup

* Fix function name

* Add tests for python API to access new load/save functions

* Add another tests for python API to access new load/save functions where there are no constants
  • Loading branch information
rkimball authored and xinetzone committed Nov 25, 2022
1 parent 11eba72 commit 3821c90
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 1 deletion.
13 changes: 13 additions & 0 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class TVM_DLL Executable : public ModuleNode {
*/
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);

/*!
* \brief Get a map of all constants with larger that byte_limit in size.
*/
Map<String, NDArray> GetLateBoundConstants(size_t byte_limit);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
*
Expand All @@ -134,6 +139,14 @@ class TVM_DLL Executable : public ModuleNode {
*/
void LoadLateBoundConstantsFromStream(dmlc::Stream* stream);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given map.
*
* Must be called after \p Load but before any other methods if \p MoveLateBoundConstantsToBinary
* was used when saving. Otherwise can be ignored.
*/
void LoadLateBoundConstantsFromMap(Map<String, NDArray> map);

/*!
* \brief As for \p LoadLateBoundConstantsFromStream, but load from file at \p path.
*/
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def __init__(self, mod):
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
self._move_late_bound_consts = self.mod["move_late_bound_consts"]
self._get_late_bound_consts = self.mod["get_late_bound_consts"]
self._load_late_bound_consts = self.mod["load_late_bound_consts"]
self._load_late_bound_consts_from_map = self.mod["load_late_bound_consts_from_map"]

def save(self):
"""Save the Relay VM Executable.
Expand Down Expand Up @@ -312,10 +314,18 @@ def move_late_bound_consts(self, path, byte_limit):
"""Move all constants of byte size greater or equal to byte_limit to file at path"""
return self._move_late_bound_consts(path, byte_limit)

def get_late_bound_consts(self, byte_limit):
"""Return all constants of byte size greater or equal to byte_limit"""
return self._get_late_bound_consts(byte_limit)

def load_late_bound_consts(self, path):
"""Re-load constants previously saved to file at path"""
return self._load_late_bound_consts(path)

def load_late_bound_consts_from_map(self, map):
"""Re-load constants supplied in map"""
return self._load_late_bound_consts_from_map(map)


class VirtualMachine(object):
"""Relay VM runtime.
Expand Down
24 changes: 23 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,25 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
uint64_t byte_limit = args[1];
MoveLateBoundConstantsToFile(path, static_cast<size_t>(byte_limit));
});
} else if (name == "get_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
uint64_t byte_limit = args[0];
Map<String, NDArray> consts = GetLateBoundConstants(static_cast<size_t>(byte_limit));
*rv = consts;
});
} else if (name == "load_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
std::string path = args[0];
LoadLateBoundConstantsFromFile(path);
});
} else if (name == "load_late_bound_consts_from_map") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
Map<String, NDArray> map = args[0];
LoadLateBoundConstantsFromMap(map);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc();
Expand Down Expand Up @@ -300,7 +313,7 @@ void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) {
strm->Write(host_device_index);
}

void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) {
Map<String, NDArray> Executable::GetLateBoundConstants(size_t byte_limit) {
ICHECK(late_bound_constant_names.empty());
late_bound_constant_names.reserve(constants.size());
Map<String, NDArray> map;
Expand All @@ -323,6 +336,11 @@ void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byt
}
VLOG(1) << "moved " << map.size() << " constants of " << total_late_bound_bytes
<< " bytes (out of " << constants.size() << " overall) to be late-bound";
return map;
}

void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) {
Map<String, NDArray> map = GetLateBoundConstants(byte_limit);
runtime::SaveParams(stream, map);
}

Expand All @@ -341,6 +359,10 @@ void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) {
ICHECK_EQ(late_bound_constant_names.size(), constants.size());
Map<String, NDArray> map = runtime::LoadParams(stream);
VLOG(1) << "loaded " << map.size() << " late-bound constants";
LoadLateBoundConstantsFromMap(map);
}

void Executable::LoadLateBoundConstantsFromMap(Map<String, NDArray> map) {
for (size_t const_index = 0; const_index < constants.size(); ++const_index) {
if (!late_bound_constant_names[const_index].defined()) {
ICHECK(constants[const_index].defined())
Expand Down
80 changes: 80 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,5 +1405,85 @@ def test_vm_save_and_load_without_designating_late_bound_consts():
tvm.testing.assert_allclose(expected, actual.numpy())


def test_load_and_save_constants_via_map():
"""Large constants can be serialized outside of executable"""
target = tvm.target.Target("llvm")
dev = tvm.cpu()

# fn(x) { add(x, <large constant>) }
x = relay.var("x", shape=(1000, 1000))
const_data = np.random.rand(1000, 1000).astype("float32")
const = relay.const(const_data, dtype="float32")
func = relay.Function([x], relay.op.add(x, const))
mod = tvm.IRModule.from_expr(func)

# Compile to executable.
vm_exec = vm.compile(mod, target=target)

consts_map = vm_exec.get_late_bound_consts(byte_limit=256)

# Save to constants and library files
temp = utils.tempdir()
path_dso = temp.relpath("lib.so")
vm_exec.mod.export_library(path_dso)

# Load library files and constants
mod = runtime.load_module(path_dso)
mod["load_late_bound_consts_from_map"](consts_map)

# Test main
x_data = np.random.rand(1000, 1000).astype("float32")
the_vm = runtime.vm.VirtualMachine(mod, dev)
actual = the_vm.invoke("main", x_data)
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())

# We load the mod again so it's missing the consts.
mod = runtime.load_module(path_dso)
exe = runtime.vm.Executable(mod)

# Also test loading consts via the VM's wrapper API.
exe.load_late_bound_consts_from_map(consts_map)

# Test main again with consts now loaded via the above API.
x_data = np.random.rand(1000, 1000).astype("float32")
the_vm = runtime.vm.VirtualMachine(exe, dev)
actual = the_vm.invoke("main", x_data)
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())


def test_load_late_bound_consts_via_map_with_no_late_bound_consts():
"""Check that load_late_bound_consts handles a model with no late bound consts."""
target = tvm.target.Target("llvm")
dev = tvm.cpu()

const_data = np.random.rand(1).astype("float64")
x = relay.var("x", shape=(1,), dtype="float64")
const = relay.const(const_data, dtype="float64")

func = relay.Function([x], relay.op.add(x, const))
mod = tvm.IRModule.from_expr(func)

vm_exec = vm.compile(mod, target=target)

temp = utils.tempdir()
path_dso = temp.relpath("lib.so")

# Ensure const_data is below the byte threshold for a late-bound const.
byte_limit = len(const_data.tobytes()) + 1
consts_map = vm_exec.get_late_bound_consts(byte_limit=byte_limit)
vm_exec.mod.export_library(path_dso)

mod = runtime.load_module(path_dso)
mod["load_late_bound_consts_from_map"](consts_map)

x_data = np.random.rand(1).astype("float64")
loaded_vm = runtime.vm.VirtualMachine(mod, dev)
actual = loaded_vm.invoke("main", x_data)
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 3821c90

Please sign in to comment.