Skip to content

Commit

Permalink
Revert "record_function: remove legacy internal operators (pytorch#72303)"
Browse files — Browse the repository at this point in the history

This reverts commit 0be84bb.

Reverted pytorch#72303 on behalf of https://github.com/izaitsevfb due to Apparently _record_function_enter is still used internally at Meta in several places and in lots of internal tests. ([comment](pytorch#72303 (comment)))
  • Loading branch information
pytorchmergebot committed Oct 24, 2023
1 parent e72fcd3 commit b0087b4
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 15 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/dispatch/ObservedOperators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ std::unordered_set<std::string>& ObservedOperators::getUnobservedOperatorList()
"aten::output_nr",
"aten::_version",
"aten::is_complex",
"profiler::_record_function_enter",
"profiler::_record_function_enter_new",
"profiler::_record_function_exit",
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@
("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)),
("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)),
("profiler::_record_function_exit", datetime.date(2023, 12, 31)),
]

ALLOW_LIST_COMPILED = [
Expand Down
15 changes: 15 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4257,6 +4257,21 @@ def test_record_function_callbacks(self):
foo_event = [event for event in function_events if "foo" in event.name][0]
self.assertEqual(foo_event.count, 1)

def test_record_function_legacy(self):
    # Test that the legacy Tensor-handle based _record_function_enter /
    # _record_function_exit ops still produce a profiler event.
    # Note: Remove once record_function no longer needs these for
    # backward compatibility.
    x = torch.randn(10, 10)
    with profile(use_kineto=kineto_available()) as p:
        handle = torch.ops.profiler._record_function_enter("bar", None)
        try:
            y = x * 2 + 4
        finally:
            # Always close the scope, even if the profiled work raises.
            torch.ops.profiler._record_function_exit(handle)

    function_events = p.function_events
    # Exactly one event named "bar" should have been recorded.
    bar_event = [event for event in function_events if "bar" in event.name][0]
    self.assertEqual(bar_event.count, 1)

def test_profiler_aggregation_fake(self):
events = EventList()
id = [0]
Expand Down
1 change: 1 addition & 0 deletions torch/autograd/profiler_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ def _filter_name(name):
filtered_out_names = [
MEMORY_EVENT_NAME, # used only for the top-level memory events
OUT_OF_MEMORY_EVENT_NAME,
"profiler::_record_function_enter",
"profiler::_record_function_enter_new",
"profiler::_record_function_exit",
"aten::is_leaf",
Expand Down
107 changes: 93 additions & 14 deletions torch/csrc/autograd/record_function_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
#include <ATen/ThreadLocalState.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/record_function.h>
#include <torch/csrc/autograd/record_function_ops.h>

#include <torch/csrc/jit/runtime/operator.h>
#include <torch/library.h>

namespace caffe2 {
// Required for cpp_custom_type_hack to work
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
} // namespace caffe2

namespace torch {
namespace autograd {
namespace profiler {
Expand All @@ -25,6 +32,16 @@ static void record_function_enter(
}
}

// Legacy entry point: starts a USER_SCOPE RecordFunction and smuggles its
// ownership into an at::Tensor handle via cpp_custom_type_hack so that
// TorchScript / older callers can hold on to it.
static at::Tensor record_function_enter_legacy(
    const std::string& name,
    const c10::optional<std::string>& args) {
  auto record_fn =
      std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
  record_function_enter(name, args, *record_fn);
  return at::cpp_custom_type_hack::create(
      std::move(record_fn), at::TensorOptions());
}

// New signature using custom_class
c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
const std::string& name,
const c10::optional<std::string>& args) {
Expand All @@ -34,43 +51,105 @@ c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
return rec;
}

// Recovers the RecordFunction stashed inside a legacy Tensor handle
// (the inverse of record_function_enter_legacy's cpp_custom_type_hack).
static at::RecordFunction& getRecordFunctionFromTensor(
    const at::Tensor& handle) {
  return at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
}

// Ends the profiling scope created with record_function_enter.
static void record_function_exit(
static void record_function_exit(at::RecordFunction& rec) {
rec.end();
}

// Legacy variant: the Tensor handle's only job is to keep the
// RecordFunction alive until this point; here we look it up and close
// the profiling scope.
static void record_function_exit_legacy(const at::Tensor& handle) {
  record_function_exit(getRecordFunctionFromTensor(handle));
}

// New signature using custom_class. The scraped diff left both the old
// body (`record->record.end();`) and the new one in place, which would
// end the RecordFunction twice; keep only the post-revert body.
static void record_function_exit_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record) {
  record_function_exit(record->record);
}

// Shared implementation for both the legacy and the custom_class
// _call_end_callbacks_on_fut entry points. The scraped diff interleaved
// the pre- and post-revert versions of this function (two signatures,
// two lambdas); this is the post-revert template form, where the caller
// supplies `get_record` to resolve its handle type to a RecordFunction.
template <typename Func>
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
    Func get_record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  // Profiling callback that ends the associated record_function
  // and returns the value of the passed in future.
  auto futureProfilingFunc =
      [get_record = std::move(get_record)](c10::ivalue::Future& fut) {
        auto& rec = get_record();
        rec.end();
        // Note: this future is returned to the user to ensure that a call to
        // wait() ensures that profiling callbacks have ran. To ensure that this
        // is transparent, we must make this future propagate the value of the
        // RPC future. Use value() here instead of constValue() to ensure we
        // propagate errors.
        return fut.value();
      };
  // Define a future that completes after the profiling callbacks are run.
  auto profiledFut = fut->then(
      at::wrapPropagateTLSState(std::move(futureProfilingFunc)),
      fut->elementType());
  return profiledFut;
}

// Legacy signature using cpp_custom_type_hack: resolve the Tensor handle
// to its RecordFunction lazily, when the future's profiling callback runs.
static c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
    const at::Tensor& handle,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  auto lookup_record = [handle]() -> at::RecordFunction& {
    TORCH_INTERNAL_ASSERT(
        handle.defined(),
        "Undefined RecordFunction handle. This can happen if the handle is "
        "not correctly persisted and is destroyed before the future is "
        "realized.");

    return getRecordFunctionFromTensor(handle);
  };
  return _call_end_callbacks_on_fut(std::move(lookup_record), fut);
}

// New signature using custom_class: the intrusive_ptr captured in the
// lambda keeps the RecordFunction alive until the callback fires.
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  auto get_record = [record]() -> at::RecordFunction& {
    return record->record;
  };
  return _call_end_callbacks_on_fut(std::move(get_record), fut);
}

// Internal only, do not use directly, use Python's record_function()
TORCH_LIBRARY_FRAGMENT(profiler, m) {
m.class_<PythonRecordFunction>("_RecordFunction");

m.def(
"_record_function_enter(str name, str? args=None) -> Tensor",
&record_function_enter_legacy);
m.def(
"_record_function_enter_new(str name, str? args=None) -> "
"__torch__.torch.classes.profiler._RecordFunction",
&record_function_enter_new);
m.def("_record_function_exit._RecordFunction", &record_function_exit);
m.def("_record_function_exit", &record_function_exit_legacy);
m.def("_record_function_exit._RecordFunction", &record_function_exit_new);

torch::jit::registerOperator(torch::jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
[](jit::Stack& stack) {
// Pop inputs, which should be a future and a tensor
auto fut = jit::pop(stack).toFuture();
auto tensor = jit::pop(stack).toTensor();
auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut);
// return future that completes when profiling callbacks have run.
jit::push(stack, std::move(profiledFut));
},
c10::AliasAnalysisKind::FROM_SCHEMA));
torch::jit::registerOperator(torch::jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
"__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
Expand Down
1 change: 1 addition & 0 deletions torch/fx/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_ops.aten.copy_.default,
_ops.aten.sym_constrain_range.default,
_ops.aten.sym_constrain_range_for_size.default,
_ops.profiler._record_function_enter,
_ops.profiler._record_function_enter_new,
_ops.profiler._record_function_exit,
_ops.inductor.accumulate_grad_.default,
Expand Down
18 changes: 18 additions & 0 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,24 @@ def test_async_record_function_double_end_callbacks(self):
rf._call_end_callbacks_on_future(fut)
fut.wait()

@dist_init
def test_async_record_function_legacy(self):
    # Test the legacy _record_function ops work
    # Note: These exist for backward compatibility with TorchScript
    num_sleep_seconds = 1
    if self.rank == 1:
        with _profile() as pf:
            # Acquire the handle BEFORE the try block: if the enter op
            # raised inside try, the finally clause would hit an unbound
            # `handle` (NameError) and mask the original error.
            handle = torch.ops.profiler._record_function_enter("foo", None)
            try:
                fut = rpc.rpc_async(
                    worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                )
                torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
            finally:
                torch.ops.profiler._record_function_exit(handle)

        fut.wait()

@dist_init
def test_async_record_function_cbs_jit_call(self):
if self.rank == 1:
Expand Down

0 comments on commit b0087b4

Please sign in to comment.