Allow Custom Classes to register a handler for .to operations #51994
Labels: feature, module: custom-operators, oncall: jit
🚀 Feature
We would like to be able to register a handler for `.to` using torchbind, like we do for `.def_pickle`. This would let us define custom behavior for moving a custom class from one device to another, or for erroring out if such a move is impossible. Ideally, this `.to` function would be called either directly by the user on an instance of the class, or recursively when the user calls `.to` on the module that owns the instance, similar to how tensors owned by modules are handled today.
Motivation
In TRTorch we store a custom class that manages a TensorRT engine as an attribute of a ScriptModule. However, once initialized, TRT engines are device-specific, so we would like to give users a way to move these engines between devices using a standard PyTorch convention.
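The registration idea from the feature request can be sketched in plain Python, without torch. This is a minimal sketch assuming a simple type-to-handler registry, loosely in the spirit of how torchbind classes register `def_pickle` hooks. `register_to_handler`, `move_to`, `FakeEngine` and `engine_to` are hypothetical names, not torchbind APIs and not TRTorch's `TRTEngine`; the handler either returns a moved object or raises when the move is impossible.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: custom class -> function that moves an instance to a device.
# Loosely modelled on how torchbind classes register def_pickle hooks.
_TO_HANDLERS: Dict[type, Callable[[Any, str], Any]] = {}


def register_to_handler(cls: type, handler: Callable[[Any, str], Any]) -> None:
    """Register the function used to move instances of `cls` between devices."""
    _TO_HANDLERS[cls] = handler


def move_to(obj: Any, device: str) -> Any:
    """Dispatch a .to request to the registered handler, or raise if there is none."""
    handler = _TO_HANDLERS.get(type(obj))
    if handler is None:
        raise RuntimeError(f"{type(obj).__name__} has no registered .to handler")
    return handler(obj, device)


class FakeEngine:
    """Hypothetical stand-in for a device-specific engine; it only tracks its device."""

    def __init__(self, device: str) -> None:
        self.device = device


def engine_to(engine: FakeEngine, device: str) -> FakeEngine:
    """Move an engine to another GPU, refusing targets it cannot run on."""
    if not device.startswith("cuda"):
        # Error out when the move is impossible, as the request suggests.
        raise RuntimeError(f"cannot move engine from {engine.device} to {device}")
    if device == engine.device:
        return engine  # already on the requested device
    # A real handler would rebuild or deserialize the engine on the target GPU;
    # this sketch just returns a new instance tagged with the new device.
    return FakeEngine(device)


register_to_handler(FakeEngine, engine_to)

if __name__ == "__main__":
    moved = move_to(FakeEngine("cuda:0"), "cuda:1")
    print(moved.device)  # cuda:1
```

Calling `move_to(FakeEngine("cuda:0"), "cpu")` raises `RuntimeError`, which is the "error out if such a move is impossible" behavior described above; objects with no registered handler fail the same way rather than being silently skipped.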
Pitch
Ideally, we would like something like this to be possible.
Alternatives
We could write an independent `.to` method that works when invoked on an instance, but when that instance is owned by a module, I am not sure how users would get a reference to the attribute in order to call the method on it. There may also be a better way to store the custom class so that such a change would not affect all attributes. I am not very familiar with the common use cases for attributes.
Additional context
This is our current custom class: https://github.com/NVIDIA/TRTorch/blob/master/core/runtime/TRTEngine.cpp
and how we register it as an attribute in the module: https://github.com/NVIDIA/TRTorch/blob/6442fce997e1506d859fab789527fe1e282f683f/core/compiler.cpp#L57
Some additional context can be found here as well for a tangentially related feature in TRTorch: pytorch/TensorRT#311
It was pointed out in the PyTorch Slack that this is where the `.to` call on a module is resolved; presumably it would need to be modified to also call `.to` on attributes of script modules: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp#L78-L102
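The change suggested for the module-level `.to` can be illustrated with a torch-free Python sketch. It assumes the existing code roughly walks a module's tensors and moves each one; the proposed addition is a second pass over attributes whose class provides its own `.to`, plus recursion into submodules. `FakeTensor`, `FakeEngine` and `FakeModule` are hypothetical stand-ins, not PyTorch or TRTorch classes, and duck typing (`getattr(value, "to", None)`) stands in for whatever registration mechanism torchbind would actually use.

```python
from typing import Any, Dict


class FakeTensor:
    """Hypothetical tensor stand-in that only records its device."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


class FakeEngine:
    """Hypothetical custom-class attribute that knows how to move itself."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "FakeEngine":
        if not device.startswith("cuda"):
            raise RuntimeError(f"cannot move engine to {device}")
        return FakeEngine(device)


class FakeModule:
    """Hypothetical module holding parameters, arbitrary attributes and submodules."""

    def __init__(self) -> None:
        self.parameters: Dict[str, FakeTensor] = {}
        self.attributes: Dict[str, Any] = {}
        self.submodules: Dict[str, "FakeModule"] = {}

    def to(self, device: str) -> "FakeModule":
        # Existing behaviour (roughly): move every tensor the module owns.
        for name in list(self.parameters):
            self.parameters[name] = self.parameters[name].to(device)
        # Proposed addition: also move attributes whose class provides a .to;
        # plain values such as ints or strings are left untouched.
        for name, value in list(self.attributes.items()):
            if callable(getattr(value, "to", None)):
                self.attributes[name] = value.to(device)
        # Recurse so that attributes owned by submodules are moved as well.
        for child in self.submodules.values():
            child.to(device)
        return self
```

For example, a root module with a `FakeEngine` attribute and a child module holding another engine ends up with both engines on `"cuda:1"` after `root.to("cuda:1")`, while a plain integer attribute is unchanged. In this sketch a failing handler (such as a move to `"cpu"`) raises after the parameters have already been moved, so a real implementation would need to decide whether partially moved modules are acceptable.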
cc @gmagogsfm