Skip to content

Commit

Permalink
[c10d] logging utility for cpp-python stacktrace (pytorch#118924)
Browse files Browse the repository at this point in the history
user may not know which line of code called collectives in a big code base. When debugging, we can print python-cpp stacktrace in case user call ``ProcessGroup.reduce`` instead of ``torch.distributed.reduce``

```
LOG(INFO) << "ProcessGroupNCCL::_allgather_base stacktrace: "
                       << get_python_cpp_trace();
```

output (using _allgather_base as an example): one example python-part trace is ``all_gather_into_tensor from /data/users/weif/pytorch/torch/distributed/distributed_c10d.py:2838``
```
ProcessGroupNCCL::_allgather_base stacktrace: #0 torch::unwind::unwind() from ??:0
#1 torch::CapturedTraceback::gather(bool, bool, bool) from ??:0
#2 c10d::get_python_cpp_trace[abi:cxx11]() from :0
pytorch#3 c10d::ProcessGroupNCCL::_allgather_base(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&) from ??:0
pytorch#4 c10d::ops::(anonymous namespace)::_allgather_base_CUDA(at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, bool, long) from Ops.cpp:0
pytorch#5 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<std::tuple<at::Tensor, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > > (*)(at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, bool, long), std::tuple<at::Tensor, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > >, c10::guts::typelist::typelist<at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, bool, long> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from :0
pytorch#6 torch::autograd::basicAutogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from autograd_not_implemented_fallback.cpp:0
pytorch#7 c10d::ProcessGroup::_allgather_base(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&) from :0
pytorch#8 pybind11::cpp_function::initialize<pybind11::cpp_function::initialize<c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> >, c10d::ProcessGroup, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (c10d::ProcessGroup::*)(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&)#1}, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> >, c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(pybind11::cpp_function::initialize<c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> >, c10d::ProcessGroup, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (c10d::ProcessGroup::*)(at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&)#1}&&, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (*)(c10d::ProcessGroup*, at::Tensor&, at::Tensor&, c10d::AllgatherOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)pytorch#3}::_FUN(pybind11::detail::function_call&) from :0
pytorch#9 pybind11::cpp_function::dispatcher(_object*, _object*, _object*) from :0
pytorch#10 cfunction_call from /usr/local/src/conda/python-3.10.12/Objects/methodobject.c:543
pytorch#11 _PyObject_MakeTpCall from /usr/local/src/conda/python-3.10.12/Objects/call.c:215
pytorch#12 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112
pytorch#13 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#14 all_gather_into_tensor from /data/users/weif/pytorch/torch/distributed/distributed_c10d.py:2838
pytorch#15 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#16 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945
pytorch#17 wrapper from /data/users/weif/pytorch/torch/distributed/c10d_logger.py:75
pytorch#18 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#19 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#20 _all_gather_flat_param from /data/users/weif/pytorch/torch/distributed/fsdp/_flat_param.py:1399
pytorch#21 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#22 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#23 unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_flat_param.py:1308
pytorch#24 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#25 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#26 _unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:332
pytorch#27 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#28 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#29 _pre_forward_unshard from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:448
pytorch#30 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#31 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#32 _pre_forward from /data/users/weif/pytorch/torch/distributed/fsdp/_runtime_utils.py:413
pytorch#33 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#34 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#35 forward from /data/users/weif/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py:839
pytorch#36 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#37 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945
pytorch#38 _call_impl from /data/users/weif/pytorch/torch/nn/modules/module.py:1520
pytorch#39 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#40 do_call_core from /usr/local/src/conda/python-3.10.12/Python/ceval.c:5945
pytorch#41 _wrapped_call_impl from /data/users/weif/pytorch/torch/nn/modules/module.py:1511
pytorch#42 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#43 _PyObject_Call_Prepend from /usr/local/src/conda/python-3.10.12/Objects/call.c:431
pytorch#44 slot_tp_call from /usr/local/src/conda/python-3.10.12/Objects/typeobject.c:7494
pytorch#45 _PyObject_MakeTpCall from /usr/local/src/conda/python-3.10.12/Objects/call.c:215
pytorch#46 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112
pytorch#47 inner from /data/users/weif/pytorch/run_fsdp.py:72
pytorch#48 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#49 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#50 run from /data/users/weif/pytorch/run_fsdp.py:76
pytorch#51 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#52 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#53 main from /data/users/weif/pytorch/run_fsdp.py:133
pytorch#54 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#55 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:114
pytorch#56 <module> from /data/users/weif/pytorch/run_fsdp.py:137
pytorch#57 _PyEval_EvalFrame from /usr/local/src/conda/python-3.10.12/Include/internal/pycore_ceval.h:46
pytorch#58 PyEval_EvalCode from /usr/local/src/conda/python-3.10.12/Python/ceval.c:1134
pytorch#59 run_eval_code_obj from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1291
pytorch#60 run_mod from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1312
pytorch#61 pyrun_file from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:1208
pytorch#62 _PyRun_SimpleFileObject from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:456
pytorch#63 _PyRun_AnyFileObject from /usr/local/src/conda/python-3.10.12/Python/pythonrun.c:90
pytorch#64 pymain_run_file_obj from /usr/local/src/conda/python-3.10.12/Modules/main.c:357
pytorch#65 Py_BytesMain from /usr/local/src/conda/python-3.10.12/Modules/main.c:1090
pytorch#66 __libc_start_call_main from ??:0
pytorch#67 <unwind unsupported> from ??:0
```

Pull Request resolved: pytorch#118924
Approved by: https://github.com/kwen2501
  • Loading branch information
weifengpy authored and pytorchmergebot committed Feb 2, 2024
1 parent a3cec6a commit 63fd688
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions torch/csrc/distributed/c10d/TraceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,28 @@ inline std::string pickle_str(const c10::IValue& v) {
return std::string(result.begin(), result.end());
}

inline std::string get_python_cpp_trace() {
// usage:
// LOG(INFO) << "stacktrace: "
// << get_python_cpp_trace();
// warn: might be slow in getting cpp traces
// because of slow/broken addr2line
// in different system libs
std::shared_ptr<torch::CapturedTraceback> tb =
torch::CapturedTraceback::gather(
/*python=*/true, /*script=*/true, /*cpp=*/true);
torch::SymbolizedTracebacks s_tbs = torch::symbolize({tb.get()});
const auto& s_tb = s_tbs.tracebacks.at(0);
std::stringstream oss;
for (auto idx : c10::irange(s_tb.size())) {
auto frame_id = s_tb[idx];
const auto& frame = s_tbs.all_frames.at(frame_id);
oss << "#" << idx << " " << frame.funcname << " from " << frame.filename
<< ":" << frame.lineno << std::endl;
}
return oss.str();
}

inline c10::Dict<c10::IValue, c10::IValue> new_dict() {
return c10::Dict<c10::IValue, c10::IValue>(
c10::AnyType::get(), c10::AnyType::get());
Expand Down

0 comments on commit 63fd688

Please sign in to comment.