torch._dynamo.assume_constant_result does not work outside nn.Module #124858

Closed
IvanKobzarev opened this issue Apr 24, 2024 · 3 comments
Labels
high priority, module: dynamo, oncall: pt2, triaged


@IvanKobzarev
Contributor

IvanKobzarev commented Apr 24, 2024

🐛 Describe the bug

Using torch._dynamo.assume_constant_result outside of an nn.Module fails when Dynamo tries to get the submodule (to look up the real value):

        import torch

        @torch._dynamo.assume_constant_result
        def const_fn(n, s):
            return torch.full([n], s)

        def fn(B):
            B = const_fn(B.size(0), 13)
            X = B * 2
            return X.tolist()

        B_list = [8] * 32

        B = torch.tensor(B_list, dtype=torch.int32)
        torch._dynamo.decorators.mark_static(B, 0)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

        print(fn(B))
        torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
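For context on what the decorator is meant to do: judging from the invoke_and_store_as_constant frame in the traceback below, Dynamo does not trace into const_fn. It converts the arguments to real (concrete) values, calls the function once at trace time, and records only the result as a constant that the graph reads back through a get_attr node (the op that fails in the error). The pure-Python toy below models that idea. ToyGraph and its method are hypothetical names, not Dynamo's implementation, and a Python list stands in for torch.full.

```python
# Toy model of "assume constant result" semantics (hypothetical names,
# not Dynamo's actual implementation).
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class ToyGraph:
    constants: Dict[str, Any] = field(default_factory=dict)  # name -> baked value
    nodes: List[Tuple[str, str]] = field(default_factory=list)  # recorded ops

    def invoke_and_store_as_constant(
        self, name: str, fn: Callable[..., Any], *real_args: Any
    ) -> Any:
        # The function body runs eagerly on concrete values and is never traced.
        result = fn(*real_args)
        # Only the result is kept; later code reads it back via a get_attr-style node.
        self.constants[name] = result
        self.nodes.append(("get_attr", name))
        return result


def const_fn(n: int, s: int) -> List[int]:
    # Stand-in for torch.full([n], s).
    return [s] * n


graph = ToyGraph()
value = graph.invoke_and_store_as_constant("const_fn", const_fn, 32, 13)
doubled = [x * 2 for x in value]  # plays the role of B * 2
print(graph.nodes)  # [('get_attr', 'const_fn')]
print(doubled[:3])  # [26, 26, 26]
```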

Full error:

  1) torchrec.distributed.tests.test_test.TestTest: test_dynamo_constant_tensor
    1) TorchRuntimeError: Failed running get_attr const_fn(*(), **{}):
    'SubgraphTracer' object has no attribute 'get_submodule'
    
    from user code:
       File "/data/users/ivankobzarev/fbsource/buck-out/v2/gen/fbcode/680651077c79ba5d/torchrec/distributed/tests/__test_test__/test_test#link-tree/torchrec/distributed/tests/test_test.py", line 30, in forward
        X = B * 2
    
    Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
    
    
    You can suppress this exception and fall back to eager by setting:
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True
    
      File "torchrec/distributed/tests/test_test.py", line 46, in test_dynamo_constant_tensor
        torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)
      File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 977, in catch_errors
        return callback(frame, cache_entry, hooks, frame_state, skip=1)
      File "torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
        return _compile(
      File "torch/_utils_internal.py", line 279, in wrapper_function
        return StrobelightCompileTimeProfiler.profile_compile_time(
      File "caffe2/fb/strobelight/compile_time_profiler.py", line 96, in profile_compile_time
        return func(*args, **kwargs)
      File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "torch/_dynamo/convert_frame.py", line 700, in _compile
        guarded_code = compile_inner(code, one_graph, hooks, transform)
      File "torch/_dynamo/utils.py", line 268, in time_wrapper
        r = func(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 568, in compile_inner
        out_code = transform_code_object(code, transform)
      File "torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
        transformations(instructions, code_options)
      File "torch/_dynamo/convert_frame.py", line 173, in _fn
        return fn(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 515, in transform
        tracer.run()
      File "torch/_dynamo/symbolic_convert.py", line 2237, in run
        super().run()
      File "torch/_dynamo/symbolic_convert.py", line 875, in run
        while self.step():
      File "torch/_dynamo/symbolic_convert.py", line 790, in step
        self.dispatch_table[inst.opcode](self, inst)
      File "torch/_dynamo/symbolic_convert.py", line 229, in impl
        self.push(fn_var.call_function(self, self.popn(nargs), {}))
      File "torch/_dynamo/variables/builtin.py", line 946, in call_function
        return handler(tx, args, kwargs)
      File "torch/_dynamo/variables/builtin.py", line 850, in _handle_insert_op_in_graph
        return invoke_and_store_as_constant(
      File "torch/_dynamo/variables/functions.py", line 388, in invoke_and_store_as_constant
        args = [convert(x) for x in args]
      File "torch/_dynamo/variables/functions.py", line 388, in <listcomp>
        args = [convert(x) for x in args]
      File "torch/_dynamo/variables/functions.py", line 384, in convert
        return x.get_real_value()
      File "torch/_dynamo/variables/tensor.py", line 113, in get_real_value
        return get_real_value(self.proxy.node, self.proxy.tracer)
      File "torch/_dynamo/utils.py", line 1924, in get_real_value
        raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
      File "torch/_dynamo/utils.py", line 1921, in get_real_value
        real_value = run_node(tracer, node, args, kwargs, nn_module)
      File "torch/_dynamo/utils.py", line 1885, in run_node
        raise RuntimeError(make_error_message(e)).with_traceback(
      File "torch/_dynamo/utils.py", line 1874, in run_node
        return tracer.get_submodule(node.target)
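The last frames show where this breaks: while run_node handles the get_attr node it calls tracer.get_submodule(node.target), but the object it receives is a SubgraphTracer, which has no such method. The toy below reproduces that shape of failure in plain Python. ToyOwner, ToyTracer, and both lookup functions are hypothetical. The second function only illustrates one way to route the lookup to the object that owns the attributes and is not claimed to be what #127696 does.

```python
# Toy reproduction of the failure shape in the traceback (hypothetical names).
from typing import Any, Dict


class ToyOwner:
    """Owns registered constants and can resolve them by name."""

    def __init__(self) -> None:
        self.attrs: Dict[str, Any] = {}

    def get_submodule(self, target: str) -> Any:
        return self.attrs[target]


class ToyTracer:
    """Points back at its owner, but has no get_submodule of its own."""

    def __init__(self, owner: ToyOwner) -> None:
        self.owner = owner


def run_get_attr(tracer: Any, target: str) -> Any:
    # Mirrors the failing frame: assumes the tracer itself resolves attributes.
    return tracer.get_submodule(target)


def run_get_attr_via_owner(tracer: Any, target: str) -> Any:
    # Illustrative alternative: delegate to the owner when the tracer cannot resolve.
    resolver = tracer if hasattr(tracer, "get_submodule") else tracer.owner
    return resolver.get_submodule(target)


owner = ToyOwner()
owner.attrs["const_fn"] = [13] * 32
tracer = ToyTracer(owner)

try:
    run_get_attr(tracer, "const_fn")
except AttributeError as e:
    print(e)  # 'ToyTracer' object has no attribute 'get_submodule'

print(run_get_attr_via_owner(tracer, "const_fn")[:2])  # [13, 13]
```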

When the decorated function is a method of an nn.Module, it works as expected:

        import torch

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            @torch._dynamo.assume_constant_result
            def const_fn(self, n, s):
                return torch.full([n], s)

            def forward(self, B):
                B = self.const_fn(B.size(0), 13)
                X = B * 2
                return X.tolist()

        B_list = [8] * 32
        B = torch.tensor(B_list, dtype=torch.int32)
        torch._dynamo.decorators.mark_static(B, 0)
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        m = M()
        torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)

Versions

fbcode/warm

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78

@eellison
Contributor

cc @anijain2305, this is blocking an internal model. @IvanKobzarev is looking for pointers.

@eellison
Contributor

Also cc @angelayi (export).

jbschlosser added the high priority and triaged labels Apr 25, 2024
@jbschlosser
Contributor

Added high priority because this is blocking an internal model.

jbschlosser removed the triage review and triaged labels Apr 30, 2024
xmfan added the triaged label May 16, 2024
pytorchmergebot pushed a commit that referenced this issue Jun 11, 2024
Fixes #124858

Pull Request resolved: #127696
Approved by: https://github.com/jansel
ghstack dependencies: #127695
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this issue Jun 14, 2024
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this issue Jun 14, 2024
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this issue Jun 14, 2024
7 participants