
Allow Custom Classes to register a handler for .to operations #51994

Open
narendasan opened this issue Feb 9, 2021 · 0 comments
Labels
feature (a request for a proper, new feature) · module: custom-operators · oncall: jit (JIT oncall triage queue) · weeks

narendasan commented Feb 9, 2021

🚀 Feature

We would like to be able to register a handler for .to via torchbind, the same way we register one for .def_pickle, so that a custom class can define its own behavior for moving from one device to another (or raise an error if such a move is impossible). Ideally, this .to function would be called either by the user directly on an instance of the class, or recursively when the user calls .to on the module owning the instance, similar to how tensors owned by modules are handled today.
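To make the desired dispatch concrete, here is a minimal plain-Python sketch (no TorchScript or torchbind involved; DeviceBound and Module are illustrative stand-ins, not real PyTorch APIs) of a module whose .to call recurses into any attribute that exposes its own handler:

```python
class DeviceBound:
    """Stand-in for a custom class (e.g. a TensorRT engine) tied to a device."""

    def __init__(self, device):
        self.device = device

    def to(self, device):
        # Custom handler: move the underlying resource, or error out
        # if such a move is impossible for this class.
        if device == "cpu":
            raise RuntimeError("engine cannot be moved to CPU")
        self.device = device
        return self


class Module:
    """Minimal module whose .to forwards to attributes that have a handler."""

    def to(self, device):
        for name, value in list(vars(self).items()):
            if hasattr(value, "to"):
                setattr(self, name, value.to(device))
        return self


m = Module()
m.engine = DeviceBound("cuda:0")
m.to("cuda:1")
print(m.engine.device)  # cuda:1
```

In the real feature, the handler registration would presumably sit next to .def_pickle in the torchbind class definition, and the module-level .to resolution would walk attributes much like this sketch does.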

Motivation

In TRTorch, we store a custom class managing a TensorRT engine as an attribute of a ScriptModule. However, TRT engines, once initialized, are device-specific, so we would like to open up the possibility for users to move these engines between devices using a standard PyTorch convention.

Pitch

Ideally, we would like something like this to be possible:

import torch
import trtorch

# Create model on device 0
model = MyModel()
ts_model = torch.jit.script(model).to("cuda:0")
trt_model = trtorch.compile(ts_model, {...}) # or trt_model = torch._C._jit_to_backend("tensorrt", ts_model, ...)

# Move module (internal tensors and attributes that have a .to registration) to device 1 
trt_model.to("cuda:1")

Alternatives

We could write an independent .to method that works when invoked on an instance directly, but if that instance is owned by a module, I am not sure how users would dig out a reference to the attribute to call the method on.
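A minimal plain-Python sketch of this alternative (Engine, set_device, and the attribute name are all hypothetical, not a real API): it works when the user holds the instance, but it requires knowing exactly where the instance lives inside the owning module:

```python
class Engine:
    """Stand-in for the custom class with an independent move method."""

    def __init__(self, device):
        self.device = device

    def set_device(self, device):
        # Hypothetical standalone move method, not tied to module .to dispatch.
        self.device = device


class Owner:
    """Stand-in for a module that owns the engine as an attribute."""

    def __init__(self):
        self.engine = Engine("cuda:0")


m = Owner()
# Works, but only because the user knows the engine lives at `m.engine`:
m.engine.set_device("cuda:1")
```

This is exactly the drawback described above: without a standard recursive .to convention, nothing moves the engine unless the user can reach it by name.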

There might also be a better way to store the custom class so that this would not affect all attributes. I am not too familiar with the common use cases for attributes.

Additional context

This is our current custom class: https://github.com/NVIDIA/TRTorch/blob/master/core/runtime/TRTEngine.cpp

and how we register it as an attribute in the module: https://github.com/NVIDIA/TRTorch/blob/6442fce997e1506d859fab789527fe1e282f683f/core/compiler.cpp#L57

Some additional context can be found here as well for a tangentially related feature in TRTorch: pytorch/TensorRT#311

It was pointed out in the PyTorch Slack that this is where the .to call on the module is resolved; presumably this would need to be modified to call .to on attributes of script modules as well:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp#L78-L102

cc @gmagogsfm

@mrshenli mrshenli added feature A request for a proper, new feature. module: custom-operators custom operators, custom ops, custom-operators, custom-ops oncall: jit Add this issue/PR to JIT oncall triage queue labels Feb 9, 2021
@wanchaol wanchaol added the weeks label Feb 16, 2021
@SplitInfinity SplitInfinity removed their assignment Apr 20, 2021