
[Feat] Add torch.compile support #1791

Merged: 9 commits merged into mindee:main from pt-compile on Dec 4, 2024

Conversation

felixdittrich92
Contributor

@felixdittrich92 felixdittrich92 commented Nov 21, 2024

This PR:

  • Add basic torch.compile support (choosing a different backend or fullgraph=True is up to the user; we guarantee compatibility only with the default inductor backend)
  • Add corresponding tests and a documentation section
  • No benchmarks are added here because the results scale with the user's hardware

Any feedback is welcome 🤗

Closes: #1684 #1690
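The basic usage the PR enables can be sketched with a plain `torch.nn.Module` stand-in (the tiny model below is hypothetical, not a doctr architecture; doctr's predictor API is not reproduced here). `backend="eager"` is passed only to keep the sketch runnable without a C++ toolchain — the backend the PR actually tests is the default, inductor:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a doctr recognition model; any nn.Module works.
class TinyRecognizer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        feats = self.conv(x).mean(dim=(2, 3))  # global average pool
        return self.head(feats)

model = TinyRecognizer().eval()
# torch.compile wraps the module; backend / fullgraph are left to the user,
# exactly as the PR does. "eager" skips code generation and is used here only
# to keep the example cheap; drop it to get the default inductor backend.
compiled = torch.compile(model, backend="eager")

x = torch.rand(2, 3, 32, 32)
with torch.inference_mode():
    out = compiled(x)
print(type(compiled).__name__)  # OptimizedModule, as in the log further below
print(out.shape)                # torch.Size([2, 10])
```

The wrapped object is a `torch._dynamo.eval_frame.OptimizedModule`, which matches the `<class 'torch._dynamo.eval_frame.OptimizedModule'>` line in the debug log pasted later in this thread.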

@felixdittrich92 felixdittrich92 added the following labels on Nov 21, 2024: topic: documentation, module: models, framework: pytorch, topic: text detection, topic: text recognition, topic: character classification, type: new feature, ext: docs
@felixdittrich92 felixdittrich92 added this to the 0.11.0 milestone Nov 21, 2024
@felixdittrich92 felixdittrich92 self-assigned this Nov 21, 2024

codecov bot commented Nov 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.64%. Comparing base (05d2fb6) to head (9e07318).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1791      +/-   ##
==========================================
+ Coverage   96.54%   96.64%   +0.10%     
==========================================
  Files         165      165              
  Lines        7892     7929      +37     
==========================================
+ Hits         7619     7663      +44     
+ Misses        273      266       -7     
Flag Coverage Δ
unittests 96.64% <100.00%> (+0.10%) ⬆️

Flags with carried forward coverage won't be shown.


@felixdittrich92
Contributor Author

@odulcy-mindee This slows down our PyTorch test CI extremely (from ~9 min to ~20 min) because the compilation takes some time :/

@felixdittrich92
Contributor Author

felixdittrich92 commented Nov 21, 2024

Then I think it's ready for review :)

I would say we shouldn't add a benchmark to the docs here because this really depends on the user's hardware - or how do you see it?

In the end, the ONNX route is still preferable ... compile brings a ~5-10% boost, while ONNX boosts by ~50% (on CPU) ^^
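Since the numbers above are hardware-dependent, a minimal harness for measuring the effect yourself might look like the sketch below (toy model, not a doctr one). `backend="eager"` is again only a stand-in so the sketch runs without a compiler toolchain; leave the backend at its default (inductor) to measure a real speedup, and note that the first call pays the compilation cost:

```python
import time
import torch
import torch.nn as nn

# Hypothetical toy model; substitute a doctr predictor's model in practice.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).eval()
# "eager" keeps this sketch dependency-free; drop it for the inductor default.
compiled = torch.compile(model, backend="eager")
x = torch.rand(32, 64)

with torch.inference_mode():
    compiled(x)  # warm-up: the first call triggers (and caches) compilation
    t0 = time.perf_counter()
    for _ in range(10):
        out_compiled = compiled(x)
    dt = time.perf_counter() - t0
    out_eager = model(x)

print(f"10 compiled iterations: {dt:.4f}s")
# Sanity check: compiled and eager outputs must agree.
assert torch.allclose(out_compiled, out_eager, atol=1e-6)
```

Timing the eager model with the same loop and comparing gives the per-machine speedup, which is why the PR deliberately leaves benchmarks out of the docs.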

@felixdittrich92 felixdittrich92 marked this pull request as ready for review November 21, 2024 11:13
@felixdittrich92
Contributor Author

@odulcy-mindee Anything to add? Otherwise ready for review ^^

@felixdittrich92
Contributor Author

@odulcy-mindee Anything to add? Otherwise ready for review ^^

@odulcy-mindee ? 🤗

odulcy-mindee previously approved these changes Dec 3, 2024
"crnn_mobilenet_v3_small",
"crnn_mobilenet_v3_large",
"sar_resnet31",
# "master", NOTE: MASTER model isn't compilable yet
Collaborator

What's wrong with it?

Contributor Author

Ok, that's strange 😅 - sometimes it works and sometimes it doesn't xD. So it's better to mark it as not supported yet; we can check this again with the next torch releases.

os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
(doctr-dev) felix@felix-Z790-AORUS-MASTER:~/Desktop/doctr$ USE_TORCH=1 python3 /home/felix/Desktop/doctr/test_scripts/test_compile.py
/home/felix/Desktop/doctr/doctr/models/utils/pytorch.py:62: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(archive_path, map_location="cpu")
<class 'torch._dynamo.eval_frame.OptimizedModule'>
[('I', 0.7236536741256714)]
I1204 08:38:36.791000 56803 site-packages/torch/_dynamo/utils.py:399] TorchDynamo compilation metrics:
I1204 08:38:36.791000 56803 site-packages/torch/_dynamo/utils.py:399] Function, Runtimes (s)
V1204 08:38:36.791000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.791000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1204 08:38:36.792000 56803 site-packages/torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)

Sometimes it fails with a compilation error:

(doctr-dev) felix@felix-Z790-AORUS-MASTER:~/Desktop/doctr$ USE_TORCH=1 python3 /home/felix/Desktop/doctr/test_scripts/test_compile.py
/home/felix/Desktop/doctr/doctr/models/utils/pytorch.py:62: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(archive_path, map_location="cpu")
Traceback (most recent call last):
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py", line 331, in _run_compile_cmd
    status = subprocess.check_output(args=cmd, cwd=cwd, stderr=subprocess.STDOUT)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['g++', '/tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.cpp', '-D', 'TORCH_INDUCTOR_CPP_WRAPPER', '-D', 'C10_USING_CUSTOM_GENERATED_MACROS', '-D', 'CPU_CAPABILITY_AVX2', '-shared', '-fPIC', '-O3', '-DNDEBUG', '-ffast-math', '-fno-finite-math-only', '-fno-unsafe-math-optimizations', '-ffp-contract=off', '-march=native', '-Wall', '-std=c++17', '-Wno-unused-variable', '-Wno-unknown-pragmas', '-fopenmp', '-I/home/felix/anaconda3/envs/doctr-dev/include/python3.10', '-I/home/felix/anaconda3/envs/doctr-dev/include/python3.10', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH', '-I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC', '-mavx2', '-mfma', '-mf16c', '-D_GLIBCXX_USE_CXX11_ABI=0', '-ltorch', '-ltorch_cpu', '-ltorch_python', '-lc10', '-lgomp', '-L/home/felix/anaconda3/envs/doctr-dev/lib', '-L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib', '-L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib', '-L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib', '-o', '/tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.so']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1334, in load
    compiled_graph = compile_fx_fn(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 878, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1913, in compile_to_fn
    return self.compile_to_module().call
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1839, in compile_to_module
    return self._compile_to_module()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1867, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2876, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_felix/xb/cxb7w3qrt7osob523qtzdkby4ofjwvb5vrj7i5sgbqyl6wykw76s.py", line 98816, in <module>
    async_compile.wait(globals())
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/async_compile.py", line 276, in wait
    scope[key] = result.result()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 3353, in result
    return self.result_fn()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2377, in future
    result = get_result()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2177, in load_fn
    future.result()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2218, in _worker_compile_cpp
    cpp_builder.build()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py", line 1508, in build
    status = run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py", line 352, in run_compile_cmd
    return _run_compile_cmd(cmd_line, cwd)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py", line 346, in _run_compile_cmd
    raise exc.CppCompileError(cmd, output) from e
torch._inductor.exc.CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX2 -shared -fPIC -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/home/felix/anaconda3/envs/doctr-dev/include/python3.10 -I/home/felix/anaconda3/envs/doctr-dev/include/python3.10 -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC -mavx2 -mfma -mf16c -D_GLIBCXX_USE_CXX11_ABI=0 -ltorch -ltorch_cpu -ltorch_python -lc10 -lgomp -L/home/felix/anaconda3/envs/doctr-dev/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -o /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.so

Output:
In file included from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half.h:535,
                 from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Float8_e5m2.h:17,
                 from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/ATen/NumericUtils.h:11,
                 from /tmp/torchinductor_felix/vu/cvuvp4i7roujum4xemrfwnb3t4c5t3r3mihr4b7iegh6tcqvdg43.h:19,
                 from /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.cpp:2:
/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half-inl.h: In function ‘c10::Half c10::operator*(Half, int64_t)’:
/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half-inl.h:265:33: internal compiler error: Segmentation fault
  265 |   return a * static_cast<Half>(b);
      |                                 ^
0x7b813ea4531f ???
        ./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c:0
0x7b813ea2a1c9 __libc_start_call_main
        ../sysdeps/nptl/libc_start_call_main.h:58
0x7b813ea2a28a __libc_start_main_impl
        ../csu/libc-start.c:360
Please submit a full bug report, with preprocessed source (by using -freport-bug).
Please include the complete backtrace with any bug report.
See <file:///usr/share/doc/gcc-13/README.Bugs> for instructions.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/felix/Desktop/doctr/test_scripts/test_compile.py", line 12, in <module>
    compiled_out = compiled_predictor(doc)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/felix/Desktop/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in forward
    raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
  File "/home/felix/Desktop/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in <listcomp>
    raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 580, in wrapper
    return handle_graph_break(self, inst, speculation.reason)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 649, in handle_graph_break
    self.output.compile_subgraph(self, reason=reason)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX2 -shared -fPIC -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/home/felix/anaconda3/envs/doctr-dev/include/python3.10 -I/home/felix/anaconda3/envs/doctr-dev/include/python3.10 -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/TH -I/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/THC -mavx2 -mfma -mf16c -D_GLIBCXX_USE_CXX11_ABI=0 -ltorch -ltorch_cpu -ltorch_python -lc10 -lgomp -L/home/felix/anaconda3/envs/doctr-dev/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -L/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/lib -o /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.so

Output:
In file included from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half.h:535,
                 from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Float8_e5m2.h:17,
                 from /home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/ATen/NumericUtils.h:11,
                 from /tmp/torchinductor_felix/vu/cvuvp4i7roujum4xemrfwnb3t4c5t3r3mihr4b7iegh6tcqvdg43.h:19,
                 from /tmp/torchinductor_felix/34/c34pw76n7sgkuqnafsk3e5yb6e64rjkes6xcwzar2r7jaszd7uqt.cpp:2:
/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half-inl.h: In function ‘c10::Half c10::operator*(Half, int64_t)’:
/home/felix/anaconda3/envs/doctr-dev/lib/python3.10/site-packages/torch/include/c10/util/Half-inl.h:265:33: internal compiler error: Segmentation fault
  265 |   return a * static_cast<Half>(b);
      |                                 ^
0x7b813ea4531f ???
        ./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c:0
0x7b813ea2a1c9 __libc_start_call_main
        ../sysdeps/nptl/libc_start_call_main.h:58
0x7b813ea2a28a __libc_start_main_impl
        ../csu/libc-start.c:360
Please submit a full bug report, with preprocessed source (by using -freport-bug).
Please include the complete backtrace with any bug report.
See <file:///usr/share/doc/gcc-13/README.Bugs> for instructions.


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Contributor Author

@odulcy-mindee I have a guess:

  • the MASTER architecture is really large, and the created graph has some complexity (a partially plugged-in attention layer in the backbone, etc.)
  • torch.compile breaks the whole graph into subgraphs which are compiled step by step - already compiled subgraphs are cached
  • so a logical explanation would be: the compilation runs into a timeout, for example - and with repeated script executions we compile more and more subgraphs until the whole model is compiled, because it reuses the already compiled and cached subgraphs
  • this matches the fact that once it has compiled successfully the first time, it works again and again without issues ^^

Collaborator

Alright, I see. Thank you for this investigation!
Is there something we can do to mark this architecture as not supported by torch.compile?

Contributor

Mh, currently it's mentioned in the corresponding docs section. From the code side I don't think so, because we leave the torch.compile step flexible for users - which is good, since it's dynamic and people can test different backends on their own if they want / we only support the torch default backend (inductor) as tested.

tests/pytorch/test_models_recognition_pt.py (review thread outdated, resolved)
Collaborator

@odulcy-mindee odulcy-mindee left a comment


LGTM

@felixdittrich92 felixdittrich92 merged commit 9b6dea3 into mindee:main Dec 4, 2024
70 checks passed
@felixdittrich92 felixdittrich92 deleted the pt-compile branch December 4, 2024 10:33
sarjil77 pushed a commit to sarjil77/doctr that referenced this pull request Dec 9, 2024
Development

Successfully merging this pull request may close these issues.

Add support for torch.compile
3 participants