
feat: Prototype Module-Acceleration in Dynamo #1921

Closed
wants to merge 3 commits from the dynamo_module_level_exclusion branch

Conversation

gs-olive (Collaborator) commented May 15, 2023

Description

  • Add support for excluding entire Torch modules from tracing in Dynamo using Torch custom operators
  • Develop new dataclass to store required replacement functions and operators in a streamlined way
  • Add new registry to store mapping between replacement operators and their corresponding dataclass
  • Add a detailed tutorial to the sample module substitution, with step-by-step instructions for creating a new substitution

Fixes #1894

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • [x] My code follows the style guidelines of this project (You can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified

@gs-olive gs-olive added the component: dynamo label May 15, 2023
@gs-olive gs-olive self-assigned this May 15, 2023
@github-actions github-actions bot added the component: api [Python] label May 15, 2023
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch 3 times, most recently from 880d3e2 to 7e0f440, May 16, 2023 00:54
@gs-olive gs-olive added the WIP label May 17, 2023
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch 9 times, most recently from f34c25c to 4acaae2, May 19, 2023 20:11
@gs-olive gs-olive added the Story: Dynamo Compile Improvements label May 22, 2023
@github-actions github-actions bot requested a review from narendasan May 22, 2023 20:00
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch 4 times, most recently from c3c266d to 613b773, May 25, 2023 05:11
- Add support for excluding entire Torch modules from tracing in Dynamo
using Torch custom operators
- Develop new dataclass to store required replacement functions and
operators in a streamlined way
- Add new registry to store mapping between replacement operators and
their corresponding dataclass
- Add documentation for easy additions of new module-level exclusion
operators
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch from 613b773 to 47a7f69, May 26, 2023 17:06
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch from 47a7f69 to 60df50e, May 26, 2023 17:08
@gs-olive gs-olive removed the WIP label May 26, 2023
@gs-olive gs-olive marked this pull request as ready for review May 26, 2023 17:08
.circleci/config.yml (review thread marked outdated; resolved)
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
    "torch.ops." + str(module.new_operator)
    for module in MODULE_SUBSTITUTION_REGISTRY.values()
)

Collaborator (reviewer):

Do you know if there's a better type than string for this registry? Like, is there an op type?

gs-olive (Collaborator, Author):

I will look into this more - there is the torch._ops.OpOverload type which could be a substitute here.

gs-olive (Collaborator, Author):

After looking more into this, str is likely the best choice for this registry for now. The issue is that node targets can be a mix of functions and torch._ops.OpOverload objects. For example, torch.ops.aten.add.Tensor is an overload object representing a Tensor addition op, whereas the operator for get is an actual Python function. The one framework that unifies all of these types is the _get_qualified_name function, which accepts each of them and returns a string. I have updated the implementation here to use that function.
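
For illustration, a minimal sketch of the distinction, assuming torch.fx's internal _get_qualified_name helper and the standard-library operator module:

import operator

import torch
from torch.fx.node import _get_qualified_name

# Node targets may be OpOverload objects or plain Python functions;
# _get_qualified_name normalizes both to strings
print(_get_qualified_name(torch.ops.aten.add.Tensor))  # expected: "torch.ops.aten.add.Tensor"
print(_get_qualified_name(operator.getitem))           # expected: "_operator.getitem"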

MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()


def module_substitution(

Collaborator (reviewer):

This is sick 😄
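
For context, a rough sketch of the registration pattern under discussion (field and argument names here are illustrative, not the PR's exact definitions):

from dataclasses import dataclass
from typing import Any, Callable, Dict, Type

import torch


@dataclass(frozen=True)
class ModuleReplacement:
    # The custom operator (an OpOverload) that stands in for the module
    new_operator: Any
    # Function that rewrites the graph to call new_operator for the module
    subgraph_insertion_fn: Callable


MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = {}


def module_substitution(
    module_to_replace: Type[torch.nn.Module], new_operator: Any
) -> Callable:
    """Decorator that registers a substitution for a torch.nn.Module type."""

    def register(subgraph_insertion_fn: Callable) -> Callable:
        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = ModuleReplacement(
            new_operator=new_operator,
            subgraph_insertion_fn=subgraph_insertion_fn,
        )
        return subgraph_insertion_fn

    return register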

@@ -0,0 +1,75 @@
from typing import Dict, Tuple

Collaborator (reviewer):

It would be cool to have a Sphinx tutorial on how to do this from an external user's perspective. It could be as easy as removing maxpool1d from the registry and then walking through all the parts.

gs-olive (Collaborator, Author):

Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered (__init__.py imports).

gs-olive (Collaborator, Author) commented Jun 1, 2023:

To avoid clashing with or making dependencies on #1966 and #1967, I'll add more commenting to this file and then port it over to a formatted .py file once those PRs are merged.



@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(

Collaborator (reviewer):

For module substitution "in library", do we want to put the converter here, or do we want to put the converter in the registry with the rest?

Collaborator (reviewer):

For external users they'd probably put it here.

gs-olive (Collaborator, Author):

I thought it could be cleaner to have the converter implementation here, so all of the code relating to that module and its replacement is centralized.

The requirement, however, is that for every new module replacement file, the user will have to add from .my_module_replacement import * to module_substitutions/__init__.py to ensure the registrations occur.
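
For instance, a hypothetical module_substitutions/__init__.py would read (file names illustrative):

# module_substitutions/__init__.py
# Importing each substitution file executes its decorators, which perform
# the actual registration into MODULE_SUBSTITUTION_REGISTRY
from .maxpool1d import *
from .my_module_replacement import *  # hypothetical new substitution file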

@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
torch_executed_ops=torch_executed_ops,
)

gm = pre_aot_module_replacement(gm)

Collaborator (reviewer):

Why is there a separate testing backend? Would we need to continue to make changes to this in step with the actual one?

gs-olive (Collaborator, Author):

The main reason for the separate testing backend is the argument store_intermediate_graphs. This is needed to track intermediate partitioned modules at different points in the compilation, to ensure decompositions, fusions, etc. are being utilized.

As changes are made to the main backend, yes, those changes would need to be reflected here and in the compile_module_testing function below; however, these are higher-level functions which do not change often.
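
Schematically, the testing backend wraps the same replacement step while recording intermediate graphs (a sketch only; the import path and signature are assumptions, not the PR's exact code):

from typing import Any, List, Sequence

import torch

# Assumed import path for the PR's replacement pass (illustrative)
from torch_tensorrt.dynamo.backend.lowering import pre_aot_module_replacement


def fx_dynamo_testing_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List[torch.fx.GraphModule],
) -> Any:
    # Run module replacement, then record the resulting graph so tests can
    # assert that decompositions, fusions, and substitutions occurred
    gm = pre_aot_module_replacement(gm)
    store_intermediate_graphs.append(gm)
    return gm.forward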

- Fix typing issues, add dependencies to `setup.py`, add qualified name
checking for module registry
- Add a detailed tutorial to the sample module substitution, with
step-by-step instructions for creating a new module substitution
@gs-olive gs-olive force-pushed the dynamo_module_level_exclusion branch from 46cbe6e to e4c6ba5 Compare June 1, 2023 19:53
Comment on lines +18 to +21
DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
_get_qualified_name(module.new_operator)
for module in MODULE_SUBSTITUTION_REGISTRY.values()
)

gs-olive (Collaborator, Author):

Updated the registry to use strings formatted in a uniform way via _get_qualified_name, as is already done in partitioning; this addresses the issue raised here: #1921 (comment)

@gs-olive gs-olive requested a review from narendasan June 1, 2023 19:56
Comment on lines +13 to +24
# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and output
# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
# Then, create a placeholder function with no operations, but having the same schema and naming as that
# used in the decorator

gs-olive (Collaborator, Author) commented Jun 1, 2023:

Added detailed instructions for creating a module substitution, to later be turned into a formatted .py file
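
As a minimal sketch of step 1 (the schema string and dispatch key below are illustrative assumptions, not necessarily the PR's exact registration):

import torch

# Define the op schema under the "tensorrt" namespace; the op then becomes
# addressable as torch.ops.tensorrt.maxpool1d
tensorrt_lib = torch.library.Library("tensorrt", "DEF")
tensorrt_lib.define(
    "maxpool1d(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, "
    "int[1] dilation, bool ceil_mode) -> Tensor"
)


def maxpool1d_placeholder(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Placeholder body matching the schema: it only needs to produce
    # correctly-shaped outputs for tracing, since the TensorRT converter
    # supplies the real implementation
    return torch.nn.functional.max_pool1d(
        x, kernel_size, stride, padding, dilation, ceil_mode
    )


tensorrt_lib.impl("maxpool1d", maxpool1d_placeholder, "CompositeExplicitAutograd")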

# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(

gs-olive (Collaborator, Author):

aten_ops --> tensorrt
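
For clarity, the renamed converter's expected call pattern (a sketch; the positional unpacking relies on kwargs being compressed into args in schema order, as noted above):

import torch
from torch_tensorrt.fx.converter_registry import tensorrt_converter


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_ops_maxpool1d(network, target, args, kwargs, name):
    # torch.compile packs kwargs into args following the schema from step 1,
    # so the six inputs can be unpacked positionally
    x, kernel_size, stride, padding, dilation, ceil_mode = args
    ...  # build and return the TensorRT pooling layer output here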

gs-olive (Collaborator, Author) commented Jun 5, 2023

Closed in favor of #1979.

@gs-olive gs-olive closed this Jun 5, 2023
@gs-olive gs-olive deleted the dynamo_module_level_exclusion branch June 5, 2023 16:32
Labels
cla signed · component: api [Python] · component: dynamo · Story: Dynamo Compile Improvements
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Upgrade Prototype System for Module-Level Acceleration in Dynamo Path
3 participants