Allowing dummy_input to be tuple / dict #56

Open
iamanigeeit opened this issue Mar 1, 2024 · 1 comment

Comments

@iamanigeeit
Contributor

I know I should open a pull request, but this is a quick edit:

https://github.com/tianyic/only_train_once/blob/7e930f6ae6cab71659fb921671f6f3921828d7c4/only_train_once/graph/graph.py#L416-L420

I modified it to work with a sequence of tensors or with a dict mapping argument names to tensors. Dict inputs are very common when calling `model(**batch)`.
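Tracing takes positional arguments, so a dict input has to be turned into a tuple, and the order of that tuple matters: a dict's key order need not match the `forward` signature. The sketch below uses a toy plain-Python `forward` and a made-up `batch` (both hypothetical, no PyTorch needed) to show that ordering by `inspect.signature` reproduces what `forward(**batch)` receives, while taking `batch.values()` directly does not.

```python
import inspect


def forward(input_ids, attention_mask=None, token_type_ids=None):
    # Toy stand-in for a model's forward(); it just echoes its arguments
    return input_ids, attention_mask, token_type_ids


# Hypothetical batch whose keys are not in signature order
batch = {"attention_mask": "mask", "input_ids": "ids"}

# Naively taking the dict values swaps the two arguments
naive = tuple(batch.values())

# Ordering by the signature matches forward(**batch);
# parameters missing from the batch are filled with None
ordered = tuple(batch.get(name) for name in inspect.signature(forward).parameters)

print(naive)    # ('mask', 'ids')
print(ordered)  # ('ids', 'mask', None)
```

Filling missing parameters with `None` only matches `forward(**batch)` here because those parameters default to `None`.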

```python
import inspect

import torch


# Method of the graph class in only_train_once/graph/graph.py
def _get_trace_graph(self, model, dummy_input, optimized_onnx=False):
    # Run the PyTorch model to get a trace and generate a graph from it
    trace_graph = None
    with torch.no_grad():
        if isinstance(dummy_input, dict):
            # Order the dict values by model.forward's signature so that
            # model(**batch)-style inputs become positional tracing args
            input_tensors = []
            for name, param in inspect.signature(model.forward).parameters.items():
                # Skip *args / **kwargs catch-alls, whatever they are named
                if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                    continue
                # Arguments missing from dummy_input are passed as None
                input_tensors.append(dummy_input.get(name))
            input_tensors = tuple(input_tensors)
        elif isinstance(dummy_input, torch.Tensor):
            # A single tensor becomes a 1-tuple
            input_tensors = (dummy_input,)
        else:
            # Any other sequence (list / tuple) of tensors
            input_tensors = tuple(dummy_input)
        trace_graph, _ = torch.jit._get_trace_graph(model, args=input_tensors)
    # ... rest of the original method unchanged
```
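The dict branch can be checked without PyTorch by pulling it out into a small helper. The sketch below is a hypothetical `order_dict_inputs` (not part of only_train_once) applied to a toy class standing in for an `nn.Module`. It filters on `Parameter.kind` rather than on the literal names `args` / `kwargs`, so catch-alls such as `*inputs` or `**options` are also skipped.

```python
import inspect


def order_dict_inputs(forward, dummy_input):
    """Hypothetical helper: map a {name: value} dict to a positional
    tuple in the order of forward's parameters."""
    ordered = []
    for name, param in inspect.signature(forward).parameters.items():
        # *args / **kwargs cannot be filled positionally from a dict, skip them
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # Missing entries become None
        ordered.append(dummy_input.get(name))
    return tuple(ordered)


class ToyModel:
    # Stand-in for an nn.Module; the bound method's signature excludes self
    def forward(self, x, mask=None, *inputs, **options):
        return x, mask


model = ToyModel()
positional = order_dict_inputs(model.forward, {"x": 1, "mask": 2, "extra": 3})
print(positional)  # (1, 2)
```

Keys that match no named parameter (`extra` above) are silently dropped, which is worth keeping in mind when a batch carries labels or metadata the model does not accept.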

@tianyic
Owner

tianyic commented Mar 1, 2024

@iamanigeeit

Thanks for the modifications, which look good to me. Yes, I recommend creating a pull request; we appreciate contributions from the community. If you are willing, you could open a pull request here, or wait until we finish migrating the repo to the Microsoft open-source affiliation (expected soon).
