Skip to content

Commit

Permalink
New Runtime pybind API (#6063)
Browse files Browse the repository at this point in the history
Summary:
Based on this proposal: https://docs.google.com/document/d/10Q4-pt97inQQtFf-FjjwhMaDXXCfk1zGy6V6EkygNUY/edit#heading=h.fcrpnrtb6cud
    
Historically our pybinding APIs have not followed the same C++ modeling
(Program, Method, etc.), and hence they are hard to use and easy to hit
footguns with - for example, if we load the program and return it from a
python method, it goes out of scope and releases the memory.

This effort is to create Pybind APIs that resemble the C++ objects so they
are less confusing to users.

Add the following python classes:
* `Runtime`: a singleton object hosting methods like `load_program`.
  Returns a `Program` object when calling `load_program`. Also exposes
  the operator registry
* `Program`: each pte file should have one `Program` object. Most
  important method is `load_method` which returns a `Method` object. It
  has a property `method_names` where we can inspect what methods are
  inside this .pte file.
* `Method`: one object per method name in a given `Program`. Exposes
  `execute`, which takes in pytree-flattened torch tensors as input and
  returns pytree-flattened output. It also exposes `MethodMeta` for users
  to inspect more information regarding the input/output of this method.


Reviewed By: dbort

Differential Revision: D64132360

Pulled By: larryliu0820
  • Loading branch information
larryliu0820 authored and facebook-github-bot committed Oct 10, 2024
1 parent 69c2c76 commit 3e854de
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 87 deletions.
1 change: 1 addition & 0 deletions extension/pybindings/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ runtime.python_library(
srcs = ["portable_lib.py"],
visibility = [
"//executorch/exir/...",
"//executorch/runtime/...",
"@EXECUTORCH_CLIENTS",
],
deps = [":_portable_lib"],
Expand Down
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_reset_profile_results, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
MethodMeta, # noqa: F401
Verification, # noqa: F401
)

Expand Down
14 changes: 14 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,15 @@ class Module final {
return *methods_[method_name].get();
}

/// Returns the names of all methods in the program.
std::vector<std::string> method_names() const {
std::vector<std::string> names;
for (const auto& method : methods_) {
names.push_back(method.first);
}
return names;
}

/// Returns true when an event tracer is attached to this Module
/// (i.e. `event_tracer_` converts to true), enabling ETDump collection.
bool has_etdump() {
return static_cast<bool>(event_tracer_);
}
Expand Down Expand Up @@ -905,6 +914,10 @@ struct PyModule final {
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
}

/// Returns the names of all methods in the loaded program by forwarding
/// to `Module::method_names()`.
std::vector<std::string> method_names() {
return module_->method_names();
}

private:
std::shared_ptr<Module> module_;
// Need to keep-alive output storages until they can be compared in case of
Expand Down Expand Up @@ -1043,6 +1056,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
&PyModule::method_meta,
py::arg("method_name"),
call_guard)
.def("method_names", &PyModule::method_names, call_guard)
.def(
"run_method",
&PyModule::run_method,
Expand Down
1 change: 1 addition & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ExecuTorchModule:
self, path: str, debug_buffer_path: Optional[str] = None
) -> None: ...
def method_meta(self, method_name: str) -> MethodMeta: ...
def method_names(self) -> List[str]: ...

@experimental("This API is experimental and subject to change without notice.")
class BundledModule:
Expand Down
5 changes: 4 additions & 1 deletion extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ runtime.python_library(
srcs = [
"make_test.py",
],
visibility = ["//executorch/extension/pybindings/..."],
visibility = [
"//executorch/extension/pybindings/...",
"//executorch/runtime/...",
],
deps = [
"//caffe2:torch",
"//caffe2:torch_fx",
Expand Down
175 changes: 89 additions & 86 deletions extension/pybindings/test/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,118 +16,124 @@
from torch.export import export


def make_test( # noqa: C901
tester: unittest.TestCase,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.
class ModuleAdd(torch.nn.Module):
"""The module to serialize and execute."""

Used to allow the test code for pybindings to be shared across different pybinding libs
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer
def __init__(self):
super(ModuleAdd, self).__init__()

def wrapper(tester: unittest.TestCase) -> None:
class ModuleAdd(torch.nn.Module):
"""The module to serialize and execute."""
def forward(self, x, y):
return x + y

def __init__(self):
super(ModuleAdd, self).__init__()
def get_methods_to_export(self):
return ("forward",)

def forward(self, x, y):
return x + y
def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

def get_methods_to_export(self):
return ("forward",)

def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))
class ModuleMulti(torch.nn.Module):
"""The module to serialize and execute."""

class ModuleMulti(torch.nn.Module):
"""The module to serialize and execute."""
def __init__(self):
super(ModuleMulti, self).__init__()

def __init__(self):
super(ModuleMulti, self).__init__()
def forward(self, x, y):
return x + y

def forward(self, x, y):
return x + y
def forward2(self, x, y):
return x + y + 1

def forward2(self, x, y):
return x + y + 1
def get_methods_to_export(self):
return ("forward", "forward2")

def get_methods_to_export(self):
return ("forward", "forward2")
def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

class ModuleAddSingleInput(torch.nn.Module):
"""The module to serialize and execute."""
class ModuleAddSingleInput(torch.nn.Module):
"""The module to serialize and execute."""

def __init__(self):
super(ModuleAddSingleInput, self).__init__()
def __init__(self):
super(ModuleAddSingleInput, self).__init__()

def forward(self, x):
return x + x
def forward(self, x):
return x + x

def get_methods_to_export(self):
return ("forward",)
def get_methods_to_export(self):
return ("forward",)

def get_inputs(self):
return (torch.ones(2, 2),)
def get_inputs(self):
return (torch.ones(2, 2),)

class ModuleAddConstReturn(torch.nn.Module):
"""The module to serialize and execute."""

def __init__(self):
super(ModuleAddConstReturn, self).__init__()
self.state = torch.ones(2, 2)
class ModuleAddConstReturn(torch.nn.Module):
"""The module to serialize and execute."""

def forward(self, x):
return x + self.state, self.state
def __init__(self):
super(ModuleAddConstReturn, self).__init__()
self.state = torch.ones(2, 2)

def get_methods_to_export(self):
return ("forward",)
def forward(self, x):
return x + self.state, self.state

def get_inputs(self):
return (torch.ones(2, 2),)
def get_methods_to_export(self):
return ("forward",)

def create_program(
eager_module: torch.nn.Module,
et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
"""Returns an executorch program based on ModuleAdd, along with inputs."""
def get_inputs(self):
return (torch.ones(2, 2),)

# Trace the test module and create a serialized ExecuTorch program.
inputs = eager_module.get_inputs()
input_map = {}
for method in eager_module.get_methods_to_export():
input_map[method] = inputs

class WrapperModule(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def create_program(
eager_module: torch.nn.Module,
et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
"""Returns an executorch program based on ModuleAdd, along with inputs."""

def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)
# Trace the test module and create a serialized ExecuTorch program.
inputs = eager_module.get_inputs()
input_map = {}
for method in eager_module.get_methods_to_export():
input_map[method] = inputs

exported_methods = {}
# These cleanup passes are required to convert the `add` op to its out
# variant, along with some other transformations.
for method_name, method_input in input_map.items():
wrapped_mod = WrapperModule( # pyre-ignore[16]
getattr(eager_module, method_name)
)
exported_methods[method_name] = export(wrapped_mod, method_input)
class WrapperModule(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)

exported_methods = {}
# These cleanup passes are required to convert the `add` op to its out
# variant, along with some other transformations.
for method_name, method_input in input_map.items():
wrapped_mod = WrapperModule(
getattr(eager_module, method_name)
)
exported_methods[method_name] = export(wrapped_mod, method_input)

exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
# Create the ExecuTorch program from the graph.
exec_prog.dump_executorch_program(verbose=True)
return (exec_prog, inputs)

# Create the ExecuTorch program from the graph.
exec_prog.dump_executorch_program(verbose=True)
return (exec_prog, inputs)

def make_test( # noqa: C901
tester: unittest.TestCase,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.
Used to allow the test code for pybindings to be shared across different pybinding libs
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer

def wrapper(tester: unittest.TestCase) -> None:

######### TEST CASES #########

Expand Down Expand Up @@ -298,7 +304,6 @@ def test_constant_output_not_memory_planned(tester):
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

def test_method_meta(tester) -> None:
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load the program and query its metadata.
Expand Down Expand Up @@ -345,7 +350,6 @@ def test_method_meta(tester) -> None:

def test_bad_name(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load and execute the program.
Expand All @@ -356,7 +360,6 @@ def test_bad_name(tester) -> None:

def test_verification_config(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())
Verification = runtime.Verification

Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ addopts =
backends/xnnpack/test
# extension/
extension/pybindings/test
# Runtime
runtime
# test
test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
Expand Down
14 changes: 14 additions & 0 deletions runtime/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Build target for the python-facing Runtime/Program/Method API wrappers.
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "runtime",
srcs = ["__init__.py"],
# Wraps the C++ pybindings exposed by portable_lib.
deps = [
"//executorch/extension/pybindings:portable_lib",
],
# NOTE(review): visibility is currently restricted to this package;
# widen deliberately if other packages need the python Runtime API.
visibility = [
"//executorch/runtime/...",
],
)
Loading

0 comments on commit 3e854de

Please sign in to comment.