From 63fd6883fd3e845521c01daff822c125552cd4d2 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 2 Feb 2024 23:49:14 +0000 Subject: [PATCH] [c10d] logging utility for cpp-python stacktrace (#118924) user may not know which line of code called collectives in a big code base. When debugging, we can print python-cpp stacktrace in case user call ``ProcessGroup.reduce`` instead of ``torch.distributed.reduce`` ``` LOG(INFO) << "ProcessGroupNCCL::_allgather_base stacktrace: " << get_python_cpp_trace(); ``` output (using _allgather_base as an example): one example python-part trace is ``all_gather_into_tensor from /data/users/weif/pytorch/torch/distributed/distributed_c10d.py:2838`` ``` ProcessGroupNCCL::_allgather_base stacktrace: #0 torch::unwind::unwind() from ??:0 #1 torch::CapturedTraceback::gather(bool, bool, bool) from ??:0 #2 c10d::get_python_cpp_trace[abi:cxx11]() from :0 #3 c10d::ProcessGroupNCCL::_allgather_base(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&) from ??:0 #4 c10d::ops::(anonymous namespace)::_allgather_base_CUDA(at::Tensor&, at::Tensor&, c10::intrusive_ptr > const&, bool, long) from Ops.cpp:0 #5 c10::impl::make_boxed_from_unboxed_functor > > (*)(at::Tensor&, at::Tensor&, c10::intrusive_ptr > const&, bool, long), std::tuple > >, c10::guts::typelist::typelist > const&, bool, long> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector >*) from :0 #6 torch::autograd::basicAutogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector >*) from autograd_not_implemented_fallback.cpp:0 #7 c10d::ProcessGroup::_allgather_base(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&) from :0 #8 pybind11::cpp_function::initialize >, c10d::ProcessGroup, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard >(c10::intrusive_ptr > (c10d::ProcessGroup::*)(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard const&)::{lambda(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&)#1}, c10::intrusive_ptr >, c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard >(pybind11::cpp_function::initialize >, c10d::ProcessGroup, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard >(c10::intrusive_ptr > (c10d::ProcessGroup::*)(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard const&)::{lambda(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&)#1}&&, c10::intrusive_ptr > (*)(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from :0 #9 pybind11::cpp_function::dispatcher(_object*, _object*, _object*) from :0 #10 cfunction_call from /usr/local/src/conda/python-3.10.12/Objects/methodobject.c:543 #11 _PyObject_MakeTpCall from /usr/local/src/conda/python-3.10.12/Objects/call.c:215 #12 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112 #13 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #14 all_gather_into_tensor from /data/users/weif/pytorch/torch/distributed/distributed_c10d.py:2838 #15 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #16 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945 #17 wrapper from /data/users/weif/pytorch/torch/distributed/c10d_logger.py:75 #18 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #19 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #20 _all_gather_flat_param from /data/users/weif/pytorch/torch/distributed/fsdp/_flat_param.py:1399 #21 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #22 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #23 unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_flat_param.py:1308 #24 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #25 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #26 _unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:332 #27 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #28 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #29 _pre_forward_unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:448 #30 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #31 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #32 _pre_forward from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:413 #33 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #34 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #35 forward from /data/users/weif/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py:839 #36 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #37 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945 #38 _call_impl from /data/users/weif/pytorch/torch/nn/modules/module.py:1520 #39 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #40 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945 #41 _wrapped_call_impl from /data/users/weif/pytorch/torch/nn/modules/module.py:1511 #42 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #43 _PyObject_Call_Prepend from /usr/local/src/conda/python-3.10.12/Objects/call.c:431 #44 slot_tp_call from /usr/local/src/conda/python-3.10.12/Objects/typeobject.c:7494 #45 _PyObject_MakeTpCall from /usr/local/src/conda/python-3.10.12/Objects/call.c:215 #46 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112 #47 inner from /data/users/weif/pytorch/run_fsdp.py:72 #48 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #49 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #50 run from /data/users/weif/pytorch/run_fsdp.py:76 #51 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #52 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #53 main from /data/users/weif/pytorch/run_fsdp.py:133 #54 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #55 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114 #56 from /data/users/weif/pytorch/run_fsdp.py:137 #57 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46 #58 PyEval_EvalCode from /usr/local/src/conda/python-3.10.12/Python/ceval.c:1134 #59 run_eval_code_obj from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1291 #60 run_mod from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1312 #61 pyrun_file from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1208 #62 _PyRun_SimpleFileObject from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:456 #63 _PyRun_AnyFileObject from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:90 #64 pymain_run_file_obj from /usr/local/src/conda/python-3.10.12/Modules/main.c:357 #65 Py_BytesMain from /usr/local/src/conda/python-3.10.12/Modules/main.c:1090 #66 __libc_start_call_main from ??:0 #67 from ??:0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118924 Approved by: https://github.com/kwen2501 --- torch/csrc/distributed/c10d/TraceUtils.h | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index a7eef0650574f..746104ec861f0 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -343,6 +343,28 @@ inline std::string pickle_str(const c10::IValue& v) { return std::string(result.begin(), result.end()); } +inline std::string get_python_cpp_trace() { + // usage: + // LOG(INFO) << "stacktrace: " + // << get_python_cpp_trace(); + // warn: might be slow in getting cpp traces + // because of slow/broken addr2line + // in different system libs + std::shared_ptr tb = + torch::CapturedTraceback::gather( + /*python=*/true, /*script=*/true, /*cpp=*/true); + torch::SymbolizedTracebacks s_tbs = torch::symbolize({tb.get()}); + const auto& s_tb = s_tbs.tracebacks.at(0); + std::stringstream oss; + for (auto idx : c10::irange(s_tb.size())) { + auto frame_id = s_tb[idx]; + const auto& frame = s_tbs.all_frames.at(frame_id); + oss << "#" << idx << " " << frame.funcname << " from " << frame.filename + << ":" << frame.lineno << std::endl; + } + return oss.str(); +} + inline c10::Dict new_dict() { return c10::Dict( c10::AnyType::get(), c10::AnyType::get());