Change aot_compile callsites (#2207)
Summary:
Pull Request resolved: #2207

X-link: pytorch/pytorch#122225

Replacing `torch._export.aot_compile` callsites with
```
ep = torch.export._export(.., predispatch=True)   # Traces the given program into predispatch IR
so_path = torch._inductor.aot_compile_ep(ep, ...)  # Takes an exported program and compiles it into a .so
```

This allows us to explicitly split up the export step from AOTInductor. We can later modify tests to do `export + serialize + deserialize + inductor` to mimic internal production use cases better.
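The value of the split is that extra steps can be inserted between tracing and compilation. A minimal plain-Python sketch of that idea (the names `export`, `serialize`, `deserialize`, and `aot_compile` here are illustrative stand-ins, not the real torch APIs):

```python
import json
from dataclasses import dataclass


@dataclass
class ExportedProgram:
    # Stand-in for an exported IR artifact: just an op list here.
    ops: list


def export(ops):
    # Step 1: "trace" the program into an IR artifact.
    return ExportedProgram(ops=list(ops))


def serialize(ep):
    return json.dumps({"ops": ep.ops})


def deserialize(blob):
    return ExportedProgram(ops=json.loads(blob)["ops"])


def aot_compile(ep):
    # Step 2: compile the IR; here we just render it as a string "artifact".
    return ";".join(ep.ops)


# Because the two steps are decoupled, a serialize/deserialize round-trip
# can be inserted between them, mimicking the production flow.
ep = export(["matmul", "relu"])
ep_roundtrip = deserialize(serialize(ep))
assert aot_compile(ep_roundtrip) == aot_compile(ep) == "matmul;relu"
```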

This PR also enables export's predispatch IR for most of the existing use cases. Previously these used export to torch IR, which produces a different graph. This may result in some performance regressions, since some of Inductor's passes will no longer run -- if so, please let me know.

This PR changes the seemingly low-impact files to the new calling convention, and a followup PR will change the high-impact site.

Reviewed By: SherlockNoMad

Differential Revision: D54808612

fbshipit-source-id: 4cd287d5af0475630b327d78a4582eca7d9f78f5
angelayi authored and facebook-github-bot committed Mar 29, 2024
1 parent f26ff42 commit 756ea35
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion userbenchmark/dynamo/dynamobench/common.py
@@ -1122,6 +1122,9 @@ class AOTInductorModelCache:
 
     @classmethod
     def load(cls, model, example_inputs, device):
+        import torch._inductor
+        import torch.export._trace
+
         key = weakref.ref(model)
         if key not in cls.cache:
             # Register the output dataclass to pytree
@@ -1132,7 +1135,17 @@ def load(cls, model, example_inputs, device):
             example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
             _register_dataclass_output_as_pytree(example_outputs)
 
-            so_path = torch._export.aot_compile(model, example_args, example_kwargs)
+            # TODO(angelayi): change this to predispatch
+            gm = torch.export._trace._export_to_torch_ir(
+                model,
+                example_args,
+                example_kwargs,
+            )
+            with torch.no_grad():
+                so_path = torch._inductor.aot_compile(
+                    gm, example_args, example_kwargs
+                )  # type: ignore[arg-type]
 
             cls.cache[key] = torch._export.aot_load(so_path, device)
 
         return cls.cache[key]
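The diff touches `AOTInductorModelCache`, which keys compiled artifacts by a weak reference to the model, so cache entries do not keep models alive and each live model is compiled only once. A standalone sketch of that caching pattern (torch-free; `compile_fn` is a hypothetical stand-in for the export-and-compile step):

```python
import weakref


class ModelCache:
    # Maps weakref.ref(model) -> compiled artifact.
    # Refs to the same live object hash and compare equal, so fresh
    # weakref.ref(...) calls still hit the same cache slot.
    cache = {}

    @classmethod
    def load(cls, model, compile_fn):
        key = weakref.ref(model)
        if key not in cls.cache:
            cls.cache[key] = compile_fn(model)  # compile only once per model
        return cls.cache[key]


class Net:
    pass


net = Net()
calls = []


def fake_compile(m):
    # Stand-in for the real export + aot_compile pipeline.
    calls.append(m)
    return object()  # pretend this is a handle to a compiled .so


a = ModelCache.load(net, fake_compile)
b = ModelCache.load(net, fake_compile)
assert a is b and len(calls) == 1  # second load came from the cache
```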
