[torch.compile] remove reset #7975

Merged 2 commits on Aug 29, 2024
35 changes: 28 additions & 7 deletions tests/tpu/test_compilation.py
@@ -5,6 +5,10 @@

import depyf

+# disable the custom dispatcher so that Dynamo takes over
+# all the control
+os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
+
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
cur_dir = os.path.dirname(__file__)
@@ -16,19 +20,36 @@

compiled_code = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
-full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0]

# we should only trigger Dynamo compilation three times:
-# one for the profiling phase (and the compiled artifact will be discarded)
+# one for the profiling phase without kv cache
# one for the prefill phase with symbolic shapes
# one for the decode phase with symbolic shapes
# and later calls should not trigger Dynamo compilation again.
# NOTE: it might still trigger XLA compilation.

# check we have three compiled code
+# this is the assumption when we use the custom dispatcher
assert len(compiled_code) == 3

-# check the first compilation is discarded
-with open(full_code) as f:
-    full_code_content = f.read()
-    profile_function = compiled_code[0].split(".")[0]
-    assert profile_function not in full_code_content
+# check all the compilations are as expected
+compiled_fn = sorted(
+    glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+
+# the first compilation is the profiling phase,
+# it should not have any kv cache
+with open(compiled_fn[0]) as f:
+    content = f.read()
+    assert "kv_caches" not in content
+
+# the second compilation is the prefill phase,
+# it should have kv cache and the flash_attention op
+with open(compiled_fn[1]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content
+
+# the third compilation is the decode phase,
+# it should have kv cache and the paged_attention op
+with open(compiled_fn[2]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
4 changes: 0 additions & 4 deletions vllm/worker/model_runner.py
@@ -1123,10 +1123,6 @@ def profile_run(self) -> None:
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
-
-# reset and discard the guard and compiled bytecode for profiling runs
-torch._dynamo.reset()
-
return

def remove_all_loras(self):
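For context: torch._dynamo.reset() clears Dynamo's compiled bytecode and guards, so keeping it here meant the compilation done during the profiling run was thrown away and redone on a later call, as the removed comment notes. A minimal sketch of that effect, not from this PR; the counting backend and f below are hypothetical, written against recent PyTorch.

import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # called once per Dynamo graph compilation; run the captured graph eagerly
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend)
def f(x):
    return x + 1


f(torch.ones(4))
assert compile_count == 1  # first call compiles

f(torch.ones(4))
assert compile_count == 1  # cached bytecode is reused

torch._dynamo.reset()      # discard guards and compiled bytecode
f(torch.ones(4))
assert compile_count == 2  # same shapes, but the reset forces a recompile

With the reset removed, the profiling-phase compilation is kept, which matches the updated test above: it now inspects that artifact directly (the first __compiled_fn dump, without kv_caches) instead of asserting that it is discarded from the dumped full code.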
4 changes: 0 additions & 4 deletions vllm/worker/tpu_worker.py
@@ -143,10 +143,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
-
-# reset and discard the guard and compiled bytecode for profiling runs
-torch._dynamo.reset()
-
return num_tpu_blocks, num_cpu_blocks

def initialize_cache(