[dynamo] delete dynamo cache entry when guard function is invalidated #117875

Closed · wants to merge 5 commits
88 changes: 88 additions & 0 deletions test/dynamo/test_misc.py
@@ -9249,6 +9249,94 @@ def fn():
self.assertIn(0, result)
self.assertTrue(same(result[0], torch.tensor(3)))

@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_module_free(self):
"""Test that CUDA memory is freed when a model goes out of scope"""

class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.fc = torch.nn.Linear(10000, 10000)

def forward(self, out):
return self.fc(out)

def run(compile):
mod = Mod().cuda()
if compile:
mod = torch.compile(mod, backend="eager")
inp = torch.rand(10000, 10000).cuda()
mod(inp)

def clean_and_report_memory():
import gc

gc.collect()
return torch.cuda.memory_allocated()

run(False)
# mem1 = clean_and_report_memory()
run(True)
mem2 = clean_and_report_memory()
torch._dynamo.reset_code_caches()
mem3 = clean_and_report_memory()

# it's possible for dynamo to hold on to more memory
# even after a _dynamo.reset[_code_caches], so we omit the following check.
# self.assertEqual(mem1, mem2)

self.assertEqual(mem2, mem3)

def test_dynamo_cache_invalidate(self):
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.fc = torch.nn.Linear(3, 3)

def forward(self, out):
return self.fc(out)

def fn(x, mod):
return mod(x)

opt_fn = torch.compile(fn, backend="eager")

m1 = Mod()
m2 = Mod()
m3 = Mod()
inp = torch.randn(3, 3)

# NOTE: assumes that each cache entry is guarded
# on unique Mod instance
opt_fn(inp, m1)
opt_fn(inp, m2)
opt_fn(inp, m3)

c1 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c1), 3)

# move cache entry to front
opt_fn(inp, m2)
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertIs(c1[1], c2[0])

# delete center of cache
del m3
c3 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c3), 2)
self.assertIs(c3[0], c2[0])
self.assertIs(c3[1], c2[2])

# delete end of cache
del m1
c4 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c4), 1)
self.assertIs(c4[0], c3[0])

del m2
c5 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c5), 0)


class TestTracer(JitTestCase):
def test_jit_save(self):
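The tests above hinge on one mechanism: once the last strong reference to a guarded module dies, a weakref callback fires and the matching cache entry is dropped. A toy, self-contained analogue of that pattern (hypothetical classes, not Dynamo's real data structures; relies on CPython's immediate refcount collection):

import weakref

class ToyCache:
    """Hypothetical stand-in for Dynamo's per-code cache; illustration only."""

    def __init__(self):
        self.entries = []

    def add(self, guarded_obj, artifact):
        entry = {"artifact": artifact}
        # Keep the weakref alive inside the entry; its callback evicts the
        # entry once guarded_obj is garbage collected.
        entry["ref"] = weakref.ref(
            guarded_obj, lambda _ref, e=entry: self.entries.remove(e)
        )
        self.entries.append(entry)

class Model:
    pass

cache = ToyCache()
m = Model()
cache.add(m, artifact="compiled code")
assert len(cache.entries) == 1

del m                            # last strong reference gone
assert len(cache.entries) == 0   # the callback evicted the entry, as in the test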
18 changes: 12 additions & 6 deletions torch/_dynamo/__init__.py
@@ -67,12 +67,7 @@
def reset() -> None:
"""Clear all compile caches and restore initial state"""
with eval_frame.compile_lock:
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):
code = weak_code()
if code:
reset_code(code)
reset_code_caches()
convert_frame.input_codes.clear()
convert_frame.output_codes.clear()
orig_code_map.clear()
@@ -82,4 +77,15 @@ def reset() -> None:
_reset_guarded_backend_cache()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()


def reset_code_caches() -> None:
"""Clear compile caches that are keyed by code objects"""
with eval_frame.compile_lock:
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):
code = weak_code()
if code:
reset_code(code)
code_context.clear()
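A minimal usage sketch of the new helper next to the existing full reset (the compiled function here is a hypothetical toy; `reset_code_caches` is the helper added in this diff):

import torch

def fn(x):
    return x * 2

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(3))

# Drop only the caches keyed by code objects (the cache entries hanging off
# fn.__code__ plus code_context), leaving the rest of dynamo's state alone...
torch._dynamo.reset_code_caches()

# ...whereas reset() clears everything and restores the initial state.
torch._dynamo.reset()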
5 changes: 4 additions & 1 deletion torch/_dynamo/eval_frame.py
@@ -14,6 +14,7 @@
import traceback
import types
import warnings
import weakref
from enum import Enum
from os.path import dirname, join
from typing import (
@@ -346,7 +347,9 @@ def get_compiler_config():
# Assume that the underlying node metadata of `fn`,
# a GraphModule instance, accurately represents
# all instances of type(fn).
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = fn
code_context.get_context(fn.forward.__code__)[
"orig_graphmodule"
] = weakref.ref(fn)

# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
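Since the context now stores `weakref.ref(fn)` rather than `fn` itself, readers must call whatever they get back; the `lambda: None` default used by the callers below (output_graph.py and symbolic_convert.py) makes the missing-key case and the already-collected case both come out as None. A small stdlib-only sketch of that pattern, with a hypothetical stand-in class:

import weakref

class GM:  # hypothetical stand-in for a torch.fx.GraphModule
    pass

context = {}
gm = GM()
context["orig_graphmodule"] = weakref.ref(gm)

orig = context.get("orig_graphmodule", lambda: None)()
assert orig is gm                                                # referent still alive

del orig, gm                                                     # drop the strong references
assert context.get("orig_graphmodule", lambda: None)() is None   # ref now dangles
assert {}.get("orig_graphmodule", lambda: None)() is None        # key missing entirely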
22 changes: 18 additions & 4 deletions torch/_dynamo/guards.py
@@ -55,7 +55,7 @@
from . import config, convert_frame, exc, mutation_guard
from .eval_frame import set_guard_error_hook
from .source import DefaultsSource, LocalSource, TypeSource
from .types import GuardedCode, GuardFail, GuardFn # noqa: F401
from .types import CacheEntry, GuardedCode, GuardFail, GuardFn # noqa: F401
from .utils import (
common_constant_types,
dict_keys_getitem,
@@ -932,6 +932,10 @@ def must_add_nn_module_guards(guard):
)


class DeletedGuardFn:
pass


# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@@ -1191,13 +1195,23 @@ def convert(size_or_stride):
"G": builder.scope["G"],
}
guard_fn.guard_fail_fn = guard_fail_fn
# will be populated by a weakref to a CacheEntry by eval_frame.c,
# when the CacheEntry is constructed
guard_fn.cache_entry = None
return guard_fn

def invalidate(self):
# A weakref is no longer valid, self.check_fn should return false
# TODO(janimesh) - Free up cache entry after the cache entry formation
# is in python, and the underlying data structure is a doubly linked
# list.
if (
    self.valid
    and hasattr(self, "check_fn")
    and self.check_fn is not DeletedGuardFn
    and (cache_entry := self.check_fn.cache_entry()) is not None
):
    assert isinstance(cache_entry, CacheEntry)
    cache_entry.invalidate()  # type: ignore[attr-defined]
    # to make sure we don't try using check_fn again
    self.check_fn = DeletedGuardFn
    self.valid = False

Review discussion on this hunk:

Contributor, on `and hasattr(self, "check_fn")`: When will this be false?

Contributor, on `and self.check_fn is not DeletedGuardFn`: if self.valid is already checked, why do we need DeletedGuardFn?

williamwen42 (Member, Author), Jan 22, 2024: We want to clean up check_fn and prevent it from being used again. DeletedGuardFn could be replaced with None - I think that using a class here makes it more clear (for debugging/error logging purposes) that we intentionally removed the reference to the actual guard function.

def id_ref(self, obj):
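Taken together, the changes in this file work as follows: the guard function carries a weakref to its CacheEntry (populated from eval_frame.c), invalidation dereferences that weakref, asks the entry to remove itself, and then poisons check_fn with the sentinel class so accidental reuse is easy to spot, which is the rationale given in the review thread above for preferring a sentinel over None. A rough, self-contained sketch of that flow (toy classes, not the real ones):

import weakref

class DeletedGuardFn:  # sentinel: easy to spot in logs if a dead guard is reused
    pass

class FakeCacheEntry:  # hypothetical; the real CacheEntry lives on the C side
    def invalidate(self):
        print("entry unlinked from the per-code cache")

def make_guard(cache_entry):
    def check_fn(f_locals):
        return True
    # mirrors guard_fn.cache_entry being populated with a weakref to the entry
    check_fn.cache_entry = weakref.ref(cache_entry)
    return check_fn

entry = FakeCacheEntry()
guard = make_guard(entry)

# the invalidation path: deref the weakref, unlink the entry, poison the guard
if (ce := guard.cache_entry()) is not None:
    ce.invalidate()
guard = DeletedGuardFn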
4 changes: 2 additions & 2 deletions torch/_dynamo/output_graph.py
@@ -1612,8 +1612,8 @@ def get_trace_call_log_str():
is_retracing = False
if tx.f_code is not self._cur_code:
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
"orig_graphmodule", None
)
"orig_graphmodule", lambda: None
)()
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
is_retracing = True
self._orig_gm_meta = [
12 changes: 6 additions & 6 deletions torch/_dynamo/symbolic_convert.py
@@ -2195,12 +2195,12 @@ def create_call_resume_at(self, inst):
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", None
)
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)[
"orig_graphmodule"
] = orig_graphmodule_maybe
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)

if new_code.co_freevars:
cg.make_function_with_closure(name, new_code, True, stack_len)
@@ -2346,7 +2346,7 @@ def get_trace_call_log_str():
# but it is enough to add a context for `forward` in case it is called.
code_context.get_context(module.forward.__code__)[
"orig_graphmodule"
] = module
] = weakref.ref(module)

tracer: InliningInstructionTranslator
if is_generator(code):
2 changes: 2 additions & 0 deletions torch/_dynamo/types.py
@@ -1,6 +1,7 @@
import dataclasses
import sys
import types
import weakref
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union

from typing_extensions import TypeAlias
@@ -37,6 +38,7 @@ class GuardFn(Protocol):
verbose_code_parts: List[str]
global_scope: Dict[str, object]
guard_fail_fn: Optional[Callable[[GuardFail], None]]
cache_entry: Optional[weakref.ref] # type: ignore[type-arg]

# maps locals of user function to bool
def __call__(self, f_locals: Dict[str, object]) -> bool: