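I modified `_get_trace_graph` in `only_train_once/graph/graph.py` so that it also works when the dummy input is a sequence of tensors or a dict mapping argument names to tensors. The dict case is very common when running `model(**batch)`.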
```python
import inspect

import torch


def _get_trace_graph(self, model, dummy_input, optimized_onnx=False):
    # Run the PyTorch graph to get a trace and generate a graph from it
    trace_graph = None
    with torch.no_grad():
        if isinstance(dummy_input, dict):
            # Reorder {argname: tensor} into the positional order of
            # model.forward; arguments missing from the dict become None
            forward_args = inspect.signature(model.forward).parameters.keys()
            input_tensors = []
            for argname in forward_args:
                if argname not in ['args', 'kwargs']:
                    if argname in dummy_input:
                        input_tensor = dummy_input[argname]
                        input_tensors.append(input_tensor)
                        print(argname, input_tensor.shape)  # debug output
                    else:
                        input_tensors.append(None)
            input_tensors = tuple(input_tensors)
        elif isinstance(dummy_input, torch.Tensor):
            input_tensors = (dummy_input,)
        else:
            # Any other sequence of tensors (list, tuple, ...)
            input_tensors = tuple(dummy_input)
        trace_graph, _ = torch.jit._get_trace_graph(model, args=input_tensors)
```
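The dict branch relies on `inspect.signature` to reorder the entries of `batch` into the positional order that `forward` expects. Below is a minimal, torch-free sketch of just that step, under stated assumptions: `dict_to_positional_args` and `ToyModel` are hypothetical names, and strings stand in for tensors so it runs without PyTorch. One deliberate change from the snippet: it skips catch-all parameters by their kind (`*args` / `**kwargs`) rather than by the literal names `'args'` and `'kwargs'`, so a catch-all named differently (e.g. `**extra`) is skipped too.

```python
import inspect
from typing import Any, Callable, Dict, Optional, Tuple


def dict_to_positional_args(
    forward: Callable[..., Any], inputs: Dict[str, Any]
) -> Tuple[Optional[Any], ...]:
    """Hypothetical helper: order a {name: value} dict by forward's parameters.

    Parameters missing from ``inputs`` become None placeholders so that later
    arguments keep their positions. *args/**kwargs catch-alls are skipped by
    parameter kind, whatever they are named.
    """
    ordered = []
    for name, param in inspect.signature(forward).parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL,
                          inspect.Parameter.VAR_KEYWORD):
            continue
        ordered.append(inputs.get(name))
    return tuple(ordered)


class ToyModel:
    # Hypothetical stand-in for a model whose forward takes named inputs
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        return input_ids


# A bound method's signature excludes `self`, just like model.forward
batch = {"attention_mask": "mask", "input_ids": "ids"}
print(dict_to_positional_args(ToyModel().forward, batch))
# ('ids', 'mask', None)
```

Note that a `None` placeholder is passed explicitly, so it overrides the parameter's default value rather than leaving it unset; that is harmless here only because the defaults are already `None`.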
Thanks for the modifications, which look good to me. Yes, I recommend creating a pull request; we appreciate contributions from the community. If you are willing, you could create a pull request here now, or wait until we finish migrating the repo to the Microsoft open-source affiliation (expected soon).
I know I should open a pull request, but this is a quick edit:
https://github.com/tianyic/only_train_once/blob/7e930f6ae6cab71659fb921671f6f3921828d7c4/only_train_once/graph/graph.py#L416-L420