Skip to content

Commit

Permalink
bundled program alpha document (pytorch#3224)
Browse files Browse the repository at this point in the history
Summary:

as title

Reviewed By: Jack-Khuu

Differential Revision: D56446890
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed Apr 23, 2024
1 parent 6c36f10 commit 4c61b71
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 74 deletions.
139 changes: 66 additions & 73 deletions docs/source/sdk-bundled-io.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ ExecuTorch Program can be emitted from user's model by using ExecuTorch APIs. Fo

In `BundledProgram`, we create two new classes, `MethodTestCase` and `MethodTestSuite`, to hold essential info for ExecuTorch program verification.

`MethodTestCase` represents a single testcase. Each `MethodTestCase` contains inputs and expected outputs for a single execution.

:::{dropdown} `MethodTestCase`

```{eval-rst}
Expand All @@ -31,6 +33,8 @@ In `BundledProgram`, we create two new classes, `MethodTestCase` and `MethodTest
```
:::

`MethodTestSuite` contains all testing info for single method, including a str representing method name, and a `List[MethodTestCase]` for all testcases:

:::{dropdown} `MethodTestSuite`

```{eval-rst}
Expand All @@ -44,18 +48,18 @@ Since each model may have multiple inference methods, we need to generate `List[

### Step 3: Generate `BundledProgram`

We provide `create_bundled_program` API under `executorch/sdk/bundled_program/core.py` to generate `BundledProgram` by bundling the emitted ExecuTorch program with the `List[MethodTestSuite]`:
We provide `BundledProgram` class under `executorch/sdk/bundled_program/core.py` to bundled the `ExecutorchProgram`-like variable, including
`ExecutorchProgram`, `MultiMethodExecutorchProgram` or `ExecutorchProgramManager`, with the `List[MethodTestSuite]`:

:::{dropdown} `BundledProgram`

```{eval-rst}
.. currentmodule:: executorch.sdk.bundled_program.core
.. autofunction:: create_bundled_program
.. autofunction:: executorch.sdk.bundled_program.core.BundledProgram.__init__
:noindex:
```
:::

`create_bundled_program` will do sannity check internally to see if the given `List[MethodTestSuite]` matches the given Program's requirements. Specifically:
Construtor of `BundledProgram `will do sannity check internally to see if the given `List[MethodTestSuite]` matches the given Program's requirements. Specifically:
1. The method_names of each `MethodTestSuite` in `List[MethodTestSuite]` for should be also in program. Please notice that it is no need to set testcases for every method in the Program.
2. The metadata of each testcase should meet the requirement of the coresponding inference methods input.

Expand Down Expand Up @@ -83,20 +87,20 @@ To serialize `BundledProgram` to make runtime APIs use it, we provide two APIs,
Here is a flow highlighting how to generate a `BundledProgram` given a PyTorch model and the representative inputs we want to test it along with.

```python

import torch

from executorch.exir import to_edge
from executorch.sdk import BundledProgram

from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.sdk.bundled_program.core import create_bundled_program
from executorch.sdk.bundled_program.serialize import (
serialize_from_bundled_program_to_flatbuffer,
)

from executorch.exir import to_edge
from torch._export import capture_pre_autograd_graph
from torch.export import export

# Step 1: ExecuTorch Program Export

# Step 1: ExecuTorch Program Export
class SampleModel(torch.nn.Module):
"""An example model with multi-methods. Each method has multiple input and single output"""

Expand All @@ -105,82 +109,70 @@ class SampleModel(torch.nn.Module):
self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32)
self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32)

def encode(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
z = x.clone()
torch.mul(self.a, x, out=z)
y = x.clone()
torch.add(z, self.b, out=y)
torch.add(y, q, out=y)
return y

def decode(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
y = x * q
torch.add(y, self.b, out=y)
return y

# Inference method names of SampleModel we want to bundle testcases to.
# Inference method name of SampleModel we want to bundle testcases to.
# Notices that we do not need to bundle testcases for every inference methods.
method_names = ["encode", "decode"]
method_name = "forward"
model = SampleModel()

capture_inputs = {
m_name: (
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
)
for m_name in method_names
}
# Inputs for graph capture.
capture_input = (
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
)

# Find each method of model needs to be traced my its name, export its FX Graph.
method_graphs = {
m_name: export(getattr(model, m_name), capture_inputs[m_name])
for m_name in method_names
}
# Export method's FX Graph.
method_graph = export(
capture_pre_autograd_graph(model, capture_input),
capture_input,
)

# Emit the traced methods into ET Program.
program = to_edge(method_graphs).to_executorch().executorch_program

# Emit the traced method into ET Program.
et_program = to_edge(method_graph).to_executorch()

# Step 2: Construct MethodTestSuite for Each Method

# Prepare the Test Inputs.

# number of input sets to be verified
# Number of input sets to be verified
n_input = 10

# Input sets to be verified for each inference methods.
# To simplify, here we create same inputs for all methods.
inputs = {
# Inference method name corresponding to its test cases.
m_name: [
# Each list below is a individual input set.
# The number of inputs, dtype and size of each input follow Program's spec.
[
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
]
for _ in range(n_input)
# Input sets to be verified.
inputs = [
# Each list below is a individual input set.
# The number of inputs, dtype and size of each input follow Program's spec.
[
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
]
for m_name in method_names
}
for _ in range(n_input)
]

# Generate Test Suites
method_test_suites = [
MethodTestSuite(
method_name=m_name,
method_name=method_name,
test_cases=[
MethodTestCase(
inputs=input,
expected_outputs=getattr(model, m_name)(*input),
expected_outputs=(getattr(model, method_name)(*input), ),
)
for input in inputs[m_name]
for input in inputs
],
)
for m_name in method_names
),
]

# Step 3: Generate BundledProgram

bundled_program = create_bundled_program(program, method_test_suites)
bundled_program = BundledProgram(et_program, method_test_suites)

# Step 4: Serialize BundledProgram to flatbuffer.
serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(
Expand Down Expand Up @@ -320,10 +312,10 @@ Here's the example of the dtype of test input not meet model's requirement:
```python
import torch
from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.sdk.bundled_program.core import create_bundled_program
from executorch.exir import to_edge
from executorch.sdk import BundledProgram
from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
from torch.export import export
Expand All @@ -344,15 +336,16 @@ class Module(torch.nn.Module):
model = Module()
method_names = ["forward"]
inputs = torch.ones(2, 2, dtype=torch.float)
inputs = (torch.ones(2, 2, dtype=torch.float), )
# Find each method of model needs to be traced my its name, export its FX Graph.
method_graphs = {
m_name: export(getattr(model, m_name), (inputs,)) for m_name in method_names
}
method_graph = export(
capture_pre_autograd_graph(model, inputs),
inputs,
)
# Emit the traced methods into ET Program.
program = to_edge(method_graphs).to_executorch().executorch_program
et_program = to_edge(method_graph).to_executorch()
# number of input sets to be verified
n_input = 10
Expand All @@ -378,7 +371,7 @@ method_test_suites = [
test_cases=[
MethodTestCase(
inputs=input,
expected_outputs=getattr(model, m_name)(*input),
expected_outputs=(getattr(model, m_name)(*input),),
)
for input in inputs[m_name]
],
Expand All @@ -388,7 +381,7 @@ method_test_suites = [
# Generate BundledProgram
bundled_program = create_bundled_program(program, method_test_suites)
bundled_program = BundledProgram(et_program, method_test_suites)
```

:::{dropdown} Raised Error
Expand Down Expand Up @@ -455,10 +448,10 @@ Another common error would be the method name in any `MethodTestSuite` does not
```python
import torch

from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.sdk.bundled_program.core import create_bundled_program

from executorch.exir import to_edge
from executorch.sdk import BundledProgram

from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
from torch.export import export


Expand All @@ -477,18 +470,18 @@ class Module(torch.nn.Module):


model = Module()

method_names = ["forward"]

inputs = torch.ones(2, 2, dtype=torch.float)
inputs = (torch.ones(2, 2, dtype=torch.float),)

# Find each method of model needs to be traced my its name, export its FX Graph.
method_graphs = {
m_name: export(getattr(model, m_name), (inputs,)) for m_name in method_names
}
method_graph = export(
capture_pre_autograd_graph(model, inputs),
inputs,
)

# Emit the traced methods into ET Program.
program = to_edge(method_graphs).to_executorch().executorch_program
et_program = to_edge(method_graph).to_executorch()

# number of input sets to be verified
n_input = 10
Expand All @@ -513,7 +506,7 @@ method_test_suites = [
test_cases=[
MethodTestCase(
inputs=input,
expected_outputs=getattr(model, m_name)(*input),
expected_outputs=(getattr(model, m_name)(*input),),
)
for input in inputs[m_name]
],
Expand All @@ -525,7 +518,7 @@ method_test_suites = [
method_test_suites[0].method_name = "MISSING_METHOD_NAME"

# Generate BundledProgram
bundled_program = create_bundled_program(program, method_test_suites)
bundled_program = BundledProgram(et_program, method_test_suites)

```

Expand Down
2 changes: 1 addition & 1 deletion sdk/bundled_program/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
input: All inputs required by eager_model with specific inference method for one-time execution.
It is worth mentioning that, although both bundled program and ET runtime apis support setting input
other than torch.tensor type, only the input in torch.tensor type will be actually updated in
other than `torch.tensor` type, only the input in `torch.tensor` type will be actually updated in
the method, and the rest of the inputs will just do a sanity check if they match the default value in method.
expected_output: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.
Expand Down

0 comments on commit 4c61b71

Please sign in to comment.