TVM Vertical Integration with PyTorch #11911

Merged: 28 commits into apache:main on Jul 26, 2022

Conversation

@juda (Contributor) commented on Jun 27, 2022

This pull request adds two functions:

  1. optimize_torch, a function similar to torch.jit.trace: it optimizes a torch.nn.Module with TVM MetaSchedule and returns a custom TorchScript operator.
  2. as_torch, a decorator that wraps TVMScript code as a torch.nn.Module.

The files consist of:

  • two Python files implementing the functions, plus a C++ backend
  • two test files covering both functions

A minimal usage sketch of both is shown below.

@yelite @junrushao1994 @masahi
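
A minimal sketch of the two entry points described above. The exact signatures (for example, whether optimize_torch takes an example input) are assumptions for illustration, not details copied from this PR:

import torch
import torchvision
from tvm.contrib.torch import optimize_torch

model = torchvision.models.resnet18().eval()
example_input = torch.rand(1, 3, 224, 224)

# Tune the module with MetaSchedule and get back something that behaves like a
# TorchScript module backed by a custom TVM operator.
optimized_model = optimize_torch(model, example_input)
output = optimized_model(example_input)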

@juda changed the title from "Totorch" to "[PyTorch Integration] Torch optimization and wrapper" on Jun 27, 2022
@junrushao changed the title from "[PyTorch Integration] Torch optimization and wrapper" to "TVM Vertical Integration with PyTorch" on Jun 27, 2022
@masahi (Member) commented on Jun 27, 2022

I suggest moving the tutorials to a separate PR. Ideally, tutorials should demonstrate more realistic examples than vector add or matmul, i.e. something for which PyTorch users would reach for "custom op" authoring.

For example, I think demonstrating the equivalent of the Triton fused softmax tutorial https://triton-lang.org/master/getting-started/tutorials/02-fused-softmax.html in this workflow would be very interesting.

@juda (Contributor, Author) commented on Jun 28, 2022


After discussing with @yelite, we will drop the how-to guides from this PR and resubmit them in a separate PR afterward.

* The basic forward function calling TVM's runtime is provided.
* The TVM module can be serialized/deserialized as a Torch module.
*/
class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
@masahi (Member) commented:

This looks similar to TvmGraphModulePack:

class TvmGraphModulePack {

Why do we need this?

@juda (Contributor, Author) replied:

One reason is that we don't want to use temp files to transmit data, as in ByteDance's approach, but to use TVM's FFI instead. @yelite

@juda (Contributor, Author) commented on Jul 13, 2022:

Hi @masahi, there are several reasons we don't plan to reuse the code from tvm_class.cc:

  1. tvm_class.cc is complex, while our code is more natural. For example, it maintains its own Torch-tensor-to-DLPack conversion, while we use Torch's built-in library.
  2. Our code is more readable. We have fewer functions but still cover tvm_class.cc's functionality. For example, we don't need an extra initialization function such as init or loadTVMmodule.
  3. tvm_class.cc uses temp files and absolute paths to transmit the TVM module, while we use TVM's FFI, which I believe is a better practice.
  4. The most significant difference is the save/load functions. I tested that if we save a Torch model via tvm_class.cc and then restart the Python kernel, we cannot load the model back successfully because of (3). Our code can save/load models anywhere, anytime, because we serialize/deserialize the whole runtime module.

A project member replied:

If GraphExecutorFactoryWrapper is strictly better than the existing one, I want to see the existing one removed or reimplemented in terms of GraphExecutorFactoryWrapper. But this can be done in a follow-up.


save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod")
save_runtime_mod(executor_factory.module)

return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.GraphExecutorFactoryWrapper())
A project member commented:

This looks strange... why doesn't torch.classes.tvm_tuning.GraphExecutorFactoryWrapper() take any arguments?

@juda (Contributor, Author) replied:

The class GraphExecutorFactoryWrapper is a subclass of Torch's module, and Torch's FFI cannot recognize TVM's data structures, so we transmit the runtime module through TVM's FFI.
Concretely, in line 185, we store the module in memory.
When the constructor of GraphExecutorFactoryWrapper is called, it fetches the TVM runtime module from memory.
The Python class GraphExecutorFactoryWrapper is just a wrapper around the output, because C++ doesn't support tuple unpacking but we need that behavior in Python.

A project member replied:

I see now that the compiled module is passed between Python and the C++ PyTorch extension via thread-local storage (stored by save_runtime_mod).
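
A condensed, hypothetical sketch of that handoff. It assumes the pt_tvmdsoop extension library has been loaded so that tvmtorch.save_runtime_mod and torch.classes.tvm_tuning.GraphExecutorFactoryWrapper are registered, and it uses a tiny stand-in Relay module to make the flow concrete; only those two names are taken from the diff context above.

import torch
import tvm
from tvm import relay
import tvm.contrib.torch  # assumed to load the C++ extension that registers the custom class

# Stand-in Relay module so there is something to hand across.
x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0)))
executor_factory = relay.build(mod, target="llvm")

# Python side: stash the built runtime module via TVM's FFI
# (thread-local storage on the C++ side).
save_runtime_mod = tvm.get_global_func("tvmtorch.save_runtime_mod")
save_runtime_mod(executor_factory.module)

# C++ side: the zero-argument constructor picks the stored module back up,
# which is why no argument is passed here.
wrapped = torch.classes.tvm_tuning.GraphExecutorFactoryWrapper()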

return self.rt_module.forward(torch_inputs)


def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]):
@masahi (Member) commented:

So as_torch doesn't provide tuning facilities? I noticed that all tuning tests in this PR are done via optimize_torch, which involves Relay. If a user wants to tune a TVMScript-written op and use the @as_torch decorator, how can tuning be done?

@juda (Contributor, Author) replied:

My understanding is that as_torch is just used to convert TVMScript to Torch.
I need to confirm with @yelite whether we need to do more.

@masahi (Member) commented on Jul 13, 2022:

It's still possible for PT users to write TVMScript, tune it with tune_tir, and use as_torch to convert the tuned PrimFunc to PT. We are offering optimize_torch to wrap tune_relay, so it would be nice if as_torch also wrapped tune_tir and did tuning automatically.

The current examples only show the usage of as_torch as a decorator on top of a manually written TVMScript, without tuning (as in the sketch below).
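
A rough sketch of that decorator-only usage, based on the PR description rather than the test files themselves; the buffer sizes and the destination-passing call convention (the output tensor is passed in) are assumptions for illustration.

import torch
import tvm
from tvm.script import tir as T
from tvm.contrib.torch import as_torch

@as_torch
@T.prim_func
def plus_one(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
    for i in range(8):
        with T.block("plus_one"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi] + 1.0

# The decorated PrimFunc behaves like a torch.nn.Module; TIR is
# destination-passing, so the output tensor is handed in alongside the input.
a = torch.rand(8)
b = torch.zeros(8)
plus_one(a, b)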

@juda (Contributor, Author) replied:

I have added tune_tir to as_torch.

@masahi (Member) commented on Jul 14, 2022:

How about this: we add a tune(config) method (with an optional config parameter, like optimize_torch) on OperatorModuleWrapper, which does the tuning and rebuilds the mod, and remove tune_tir from build(...). So by default tuning won't happen, but the user can explicitly ask for it.
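
Under that proposal, the opt-in flow might look like the following, continuing the plus_one sketch above; the tune() method and its optional config argument are the suggestion here, not necessarily the merged interface.

# Opt-in tuning as proposed: nothing is tuned until the user asks for it.
plus_one.tune()                      # hypothetical method: run MetaSchedule tuning and rebuild the mod
a, b = torch.rand(8), torch.zeros(8)
plus_one(a, b)                       # subsequent calls use the tuned build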

return sch

def build(self, target=None):
tuned_module = self.tune_tir_auto(self.ir_module)
A project member commented:

If the input TVMScript module doesn't have any tunable knobs, does this tune_tir_auto finish instantly? Tuning should be an opt-in feature.

mod = default_config.mod(mod)
target = default_config.target(target)

extracted_task = ExtractedTask(
A project member commented:

I think there is always only one task, since it is tuning a single op

"For optimal performance, it is recommended to provide",
"the `tuning_config` argument with a bigger number of trials.",
)
warnings.warn(" ".join(warning_msg), stacklevel=2)
A project member commented:

It seems the default tuning config is dropped? @juda

@juda (Contributor, Author) replied:

It was moved to line 111 because we need to get the extracted_tasks in advance.
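
For reference, supplying an explicit tuning budget instead of relying on the default config that triggers the warning above might look like this; TuneConfig and its field names are assumptions about the MetaSchedule API of that time, not taken from this PR.

import torch
import torchvision
from tvm import meta_schedule as ms
from tvm.contrib.torch import optimize_torch

model = torchvision.models.resnet18().eval()
example_input = torch.rand(1, 3, 224, 224)

# A larger, explicit trial budget addresses the "bigger number of trials" warning.
config = ms.TuneConfig(
    strategy="evolutionary",
    num_trials_per_iter=64,
    max_trials_per_task=2000,
    max_trials_global=2000,
)
optimized_model = optimize_torch(model, example_input, tuning_config=config)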

@masahi masahi merged commit ea6ea42 into apache:main Jul 26, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* optimize_torch & as_torch

* split files

* code formatting

* optimizing optimized_torch

* scrap your boilerplate

* as_torch polished

* configuration fixed

* Apply suggestions from code review

Co-authored-by: Lite Ye <liteye859@gmail.com>

* more document

* file deleter

* optimize deleter

* drop how-to guides

* clang-format-10

* formatter changes

* reformat

* reformat

* reformat

* reformatting

* fixed

* auto setting

* fixed

* split long string

* tune_tir

* upgrade as_torch

* optimize as_torch

* as_torch

* fixed typo

Co-authored-by: juda <yzhou@octoml.ai>
Co-authored-by: Lite Ye <liteye859@gmail.com>
mikeseven pushed a commit to mikeseven/tvm that referenced this pull request Sep 27, 2023
* optimize_torch & as_torch

* split files

* code formatting

* optimizing optimized_torch

* scrap your boilerplate

* as_torch polished

* configuration fixed

* Apply suggestions from code review

Co-authored-by: Lite Ye <liteye859@gmail.com>

* more document

* file deleter

* optimize deleter

* drop how-to guides

* clang-format-10

* formatter changes

* reformat

* reformat

* reformat

* reformatting

* fixed

* auto setting

* fixed

* split long string

* tune_tir

* upgrade as_torch

* optimize as_torch

* as_torch

* fixed typo

Co-authored-by: juda <yzhou@octoml.ai>
Co-authored-by: Lite Ye <liteye859@gmail.com>