Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pybinding] Add mapping from C++ program::verification to Python #5915

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_reset_profile_results, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
Verification, # noqa: F401
)

# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
Expand Down
69 changes: 51 additions & 18 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,15 @@ class Module final {
explicit Module(
std::unique_ptr<DataLoader> loader,
std::unique_ptr<ETDumpGen> tracer = nullptr,
size_t debug_buffer_size = 0)
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency)
: loader_(std::move(loader)),
event_tracer_(std::move(tracer)),
debug_buffer_size_(debug_buffer_size) {
::executorch::runtime::runtime_init();
Result<Program> program = Program::load(
loader_.get(), Program::Verification::InternalConsistency);
Result<Program> program =
Program::load(loader_.get(), program_verification);
THROW_IF_ERROR(
program.error(),
"loading program failed with error: 0x%" PRIx32,
Expand Down Expand Up @@ -388,19 +390,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
const void* ptr,
size_t ptr_len,
bool enable_etdump,
size_t debug_buffer_size) {
size_t debug_buffer_size,
Program::Verification program_verification) {
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
return std::make_unique<Module>(
std::move(loader),
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
debug_buffer_size);
debug_buffer_size,
program_verification);
}

inline std::unique_ptr<Module> load_module_from_file(
const std::string& path,
bool enable_etdump,
size_t debug_buffer_size) {
size_t debug_buffer_size,
Program::Verification program_verification) {
EXECUTORCH_SCOPE_PROF("load_module_from_file");

Result<MmapDataLoader> res = MmapDataLoader::from(
Expand All @@ -415,7 +420,8 @@ inline std::unique_ptr<Module> load_module_from_file(
return std::make_unique<Module>(
std::move(loader),
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
debug_buffer_size);
debug_buffer_size,
program_verification);
}

static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
Expand Down Expand Up @@ -578,30 +584,41 @@ struct PyModule final {
/// Constructs a PyModule from a serialized program held in a python `bytes`
/// object.
///
/// @param buffer Serialized ExecuTorch program bytes.
/// @param enable_etdump When true, an ETDumpGen tracer is attached by the
///     underlying loader (see load_module_from_buffer).
/// @param debug_buffer_size Size of the debug output buffer; 0 disables it.
/// @param program_verification How thoroughly Program::load should validate
///     the data; defaults to full internal-consistency checks.
///
/// NOTE(review): the Module keeps a pointer into `buffer`'s storage, so the
/// py::bytes object is assumed to outlive this PyModule — confirm.
explicit PyModule(
    const py::bytes& buffer,
    bool enable_etdump,
    size_t debug_buffer_size = 0,
    Program::Verification program_verification =
        Program::Verification::InternalConsistency)
    : module_(load_module_from_buffer(
          buffer.cast<std::string_view>().data(),
          py::len(buffer),
          enable_etdump,
          debug_buffer_size,
          program_verification)) {}

/// Constructs a PyModule from a raw pointer/length pair.
///
/// @param ptr Pointer to the serialized program; not owned. Must remain
///     valid for the lifetime of this PyModule (the loader reads from it).
/// @param ptr_len Size of the buffer at `ptr`, in bytes.
/// @param enable_etdump When true, an ETDumpGen tracer is attached by the
///     underlying loader.
/// @param debug_buffer_size Size of the debug output buffer; 0 disables it.
/// @param program_verification How thoroughly Program::load should validate
///     the data; defaults to full internal-consistency checks.
explicit PyModule(
    const void* ptr,
    size_t ptr_len,
    bool enable_etdump,
    size_t debug_buffer_size = 0,
    Program::Verification program_verification =
        Program::Verification::InternalConsistency)
    : module_(load_module_from_buffer(
          ptr,
          ptr_len,
          enable_etdump,
          debug_buffer_size,
          program_verification)) {}

/// Constructs a PyModule by loading a serialized program from a file.
///
/// @param path Filesystem path to the .pte program file.
/// @param enable_etdump When true, an ETDumpGen tracer is attached by the
///     underlying loader.
/// @param debug_buffer_size Size of the debug output buffer; 0 disables it.
/// @param program_verification How thoroughly Program::load should validate
///     the file contents; defaults to full internal-consistency checks.
explicit PyModule(
    const std::string& path,
    bool enable_etdump,
    size_t debug_buffer_size = 0,
    Program::Verification program_verification =
        Program::Verification::InternalConsistency)
    : module_(load_module_from_file(
          path,
          enable_etdump,
          debug_buffer_size,
          program_verification)) {}

PyModule(const PyModule&) = delete;
PyModule& operator=(const PyModule&) = delete;
Expand All @@ -612,14 +629,20 @@ struct PyModule final {
/// Factory: creates a PyModule backed by a program parsed from `buffer`.
///
/// Thin wrapper over the PyModule(py::bytes, ...) constructor; this is the
/// target bound to the python `_load_for_executorch_from_buffer` entry point.
/// `program_verification` selects the Program::load validation level and
/// defaults to full internal-consistency checking.
static std::unique_ptr<PyModule> load_from_buffer(
    const py::bytes& buffer,
    bool enable_etdump,
    size_t debug_buffer_size = 0,
    Program::Verification program_verification =
        Program::Verification::InternalConsistency) {
  auto loaded = std::make_unique<PyModule>(
      buffer, enable_etdump, debug_buffer_size, program_verification);
  return loaded;
}
static std::unique_ptr<PyModule> load_from_file(
const std::string& path,
bool enable_etdump,
size_t debug_buffer_size = 0) {
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency) {
return std::make_unique<PyModule>(
path, enable_etdump, debug_buffer_size, program_verification);
}

static std::unique_ptr<PyModule> load_from_bundled_program(
Expand Down Expand Up @@ -944,19 +967,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
// Redirects cout and cerr for function calls this guards to the python env.
auto call_guard = py::
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();

// Bind the verification enum to python.
py::enum_<Program::Verification>(m, "Verification")
larryliu0820 marked this conversation as resolved.
Show resolved Hide resolved
.value("Minimal", Program::Verification::Minimal)
.value("InternalConsistency", Program::Verification::InternalConsistency);

m.def(
"_load_for_executorch",
PyModule::load_from_file,
py::arg("path"),
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
py::arg("program_verification") =
Program::Verification::InternalConsistency,
call_guard);
m.def(
"_load_for_executorch_from_buffer",
&PyModule::load_from_buffer,
py::arg("buffer"),
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
py::arg("program_verification") =
Program::Verification::InternalConsistency,
call_guard);
m.def(
"_load_for_executorch_from_bundled_program",
Expand Down
24 changes: 21 additions & 3 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,22 @@
# pyre-strict
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple
# Fix: `Enum` is not exported by `typing` — `from typing import Enum` raises
# ImportError and is rejected by type checkers. It lives in the stdlib `enum`
# module.
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental

@experimental("This API is experimental and subject to change without notice.")
class Verification(Enum):
    """Verification maps C++ Program::Verification to Python.

    Members mirror the C++ enum bound by pybindings: ``Minimal`` skips the
    more expensive checks, ``InternalConsistency`` validates the full program.

    .. warning::

        This API is experimental and subject to change without notice.
    """

    Minimal: ...
    InternalConsistency: ...

@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchModule:
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
Expand Down Expand Up @@ -125,7 +137,10 @@ class MethodMeta:

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch(
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
path: str,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
program_verification: Verification = Verification.InternalConsistency,
) -> ExecuTorchModule:
"""Load an ExecuTorch Program from a file.

Expand All @@ -148,7 +163,10 @@ def _load_for_executorch(

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_from_buffer(
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
buffer: bytes,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
program_verification: Verification = Verification.InternalConsistency,
) -> ExecuTorchModule:
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.

Expand Down
28 changes: 27 additions & 1 deletion extension/pybindings/test/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import unittest
from types import ModuleType
from typing import Any, Callable, Optional, Tuple

import torch
Expand All @@ -17,7 +18,7 @@

def make_test( # noqa: C901
tester: unittest.TestCase,
load_fn: Callable,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.
Expand All @@ -26,6 +27,7 @@ def make_test( # noqa: C901
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer

def wrapper(tester: unittest.TestCase) -> None:
class ModuleAdd(torch.nn.Module):
Expand Down Expand Up @@ -352,6 +354,29 @@ def test_bad_name(tester) -> None:
with tester.assertRaises(RuntimeError):
executorch_module.run_method("not_a_real_method", inputs)

def test_verification_config(tester) -> None:
    """Loads and runs the program under each Verification level.

    Exercises the new `program_verification` kwarg end-to-end: loading with
    both Verification.Minimal and Verification.InternalConsistency must
    succeed and produce the correct forward() output.
    """
    # Create an ExecuTorch program from ModuleAdd.
    # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
    exported_program, inputs = create_program(ModuleAdd())
    # `runtime` (captured from the enclosing make_test scope) is the
    # pybindings module under test; it exposes the Verification enum.
    Verification = runtime.Verification

    # Use pybindings to load and execute the program.
    for config in [Verification.Minimal, Verification.InternalConsistency]:
        executorch_module = load_fn(
            exported_program.buffer,
            enable_etdump=False,
            debug_buffer_size=0,
            program_verification=config,
        )

        executorch_output = executorch_module.forward(inputs)[0]

        # The test module adds the two inputs, so its output should be the same
        # as adding them directly.
        expected = inputs[0] + inputs[1]

        tester.assertEqual(str(expected), str(executorch_output))

######### RUN TEST CASES #########
test_e2e(tester)
test_multiple_entry(tester)
Expand All @@ -363,5 +388,6 @@ def test_bad_name(tester) -> None:
test_constant_output_not_memory_planned(tester)
test_method_meta(tester)
test_bad_name(tester)
test_verification_config(tester)

return wrapper
21 changes: 8 additions & 13 deletions extension/pybindings/test/test_pybindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,19 @@

kernel_mode = None # either aten mode or portable mode
try:
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)
from executorch.extension.pybindings import portable_lib as runtime

kernel_mode = "portable"
except Exception:
print("can't load portable lib")

try:
from executorch.extension.pybindings.aten_lib import ( # noqa: F811
_load_for_executorch_from_buffer,
)

assert kernel_mode is None
if kernel_mode is None:
try:
from executorch.extension.pybindings import aten_lib as runtime # noqa: F811

kernel_mode = "aten"
except Exception:
print("can't load aten lib")
kernel_mode = "aten"
except Exception:
print("can't load aten lib")

assert kernel_mode is not None

Expand All @@ -37,4 +32,4 @@

class PybindingsTest(unittest.TestCase):
    """Runs the shared pybindings suite against whichever kernel lib loaded.

    `runtime` is the module-level pybindings module selected above
    (portable_lib or aten_lib); make_test builds the full test closure for it.
    """

    def test(self):
        # Build the suite for the detected runtime, then execute it against
        # this TestCase so its assertions report through unittest.
        suite = make_test(self, runtime)
        suite(self)
Loading