Skip to content

Commit

Permalink
Fix 'get_attr' call in dynamo 'run_node' (pytorch#127696)
Browse files Browse the repository at this point in the history
Fixes pytorch#124858

Pull Request resolved: pytorch#127696
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#127695
  • Loading branch information
BowenBao authored and TharinduRusira committed Jun 14, 2024
1 parent 3e3621a commit 53cec9b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions test/dynamo/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,28 @@ def fn(a, b, c):

self.assertEqual(cnt.frame_count, 1)

def test_assume_constant_result_on_user_defined_fn(self):
@torch._dynamo.assume_constant_result
def const_fn(n, s):
return torch.full([n], s)

def fn(B):
B = const_fn(B.size(0), 13)
X = B * 2
return X.tolist()

B_list = [8] * 32

B = torch.tensor(B_list, dtype=torch.int32)
torch._dynamo.decorators.mark_static(B, 0)

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

self.assertEqual(
fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,7 @@ def make_error_message(e):
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return tracer.get_submodule(node.target)
return tracer.output_graph.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
Expand Down

0 comments on commit 53cec9b

Please sign in to comment.