diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
index 5a432fb78b3d..d8df86b2aaa1 100644
--- a/tests/tpu/test_compilation.py
+++ b/tests/tpu/test_compilation.py
@@ -5,6 +5,10 @@
 
 import depyf
 
+# disable custom dispatcher, let Dynamo take over
+# all the control
+os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
+
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
     cur_dir = os.path.dirname(__file__)
@@ -16,19 +20,36 @@
 
 compiled_code = sorted(
     glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
-full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0]
+
 # we should only trigger Dynamo compilation three times:
-# one for the profiling phase (and the compiled artifact will be discarded)
+# one for the profiling phase without kv cache
 # one for the prefill phase with symbolic shapes
 # one for the decode phase with symbolic shapes
 # and later calls should not trigger Dynamo compilation again.
 # NOTE: it might still trigger XLA compilation.
 
 # check we have three compiled code
+# this is the assumption when we use the custom dispatcher
 assert len(compiled_code) == 3
 
-# check the first compilation is discarded
-with open(full_code) as f:
-    full_code_content = f.read()
-    profile_function = compiled_code[0].split(".")[0]
-    assert profile_function not in full_code_content
+# check all the compilations are as expected
+compiled_fn = sorted(
+    glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+
+# the first compilation is the profiling phase,
+# it should not have any kv cache
+with open(compiled_fn[0]) as f:
+    content = f.read()
+    assert "kv_caches" not in content
+
+# the second compilation is the prefill phase,
+# it should have kv cache and the flash_attention op
+with open(compiled_fn[1]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content
+
+# the third compilation is the decode phase,
+# it should have kv cache and the paged_attention op
+with open(compiled_fn[2]) as f:
+    content = f.read()
+    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 2b287a5d2715..de1a2e3235a8 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1123,10 +1123,6 @@ def profile_run(self) -> None:
                 device=self.device)
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
-
-        # reset and discard the guard and compiled bytecode for profiling runs
-        torch._dynamo.reset()
-
         return
 
     def remove_all_loras(self):
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 320b15d3604b..44fa3aed5816 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -143,10 +143,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                              block_size_bytes)
         num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
-
-        # reset and discard the guard and compiled bytecode for profiling runs
-        torch._dynamo.reset()
-
         return num_tpu_blocks, num_cpu_blocks
 
     def initialize_cache(
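
Context for the worker changes above: removing `torch._dynamo.reset()` means the bytecode and guards compiled during the profiling run stay in Dynamo's cache and are reused by later calls, instead of being discarded and recompiled. A minimal, self-contained sketch of that caching behavior (not vLLM code; `counting_backend` and `f` are made-up names for illustration):

```python
import torch
import torch._dynamo

compile_count = 0


def counting_backend(gm, example_inputs):
    """Toy Dynamo backend that only counts how many graphs get compiled."""
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured FX graph eagerly


@torch.compile(backend=counting_backend)
def f(x):
    return x * 2 + 1


f(torch.randn(4))
assert compile_count == 1  # first call triggers a Dynamo compilation

f(torch.randn(4))
assert compile_count == 1  # cached bytecode + guards are reused, no recompile

torch._dynamo.reset()      # what the removed lines used to do after profiling
f(torch.randn(4))
assert compile_count == 2  # the cache was discarded, so Dynamo compiles again
```

With the reset gone, the profiling-phase graph remains cached, which is why the test now inspects all three compiled artifacts instead of asserting that the first one was discarded.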
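
The test drives everything through `depyf.prepare_debug`, which dumps Dynamo's artifacts to disk so the compiled graphs can be inspected as plain Python files. A rough sketch of that workflow with a toy function standing in for the model; the dump-file patterns (`__transformed_code*.py`, `__compiled_fn*`) follow what the test globs for, but exact names can vary across depyf versions:

```python
import glob
import os
import tempfile

import depyf
import torch


@torch.compile(backend="eager")
def toy_fn(x, double: bool):
    # Python-level branch: Dynamo specializes on `double`, so each branch
    # it actually sees results in one compilation
    return x * 2 if double else x + 1


dump_dir = tempfile.mkdtemp()
with depyf.prepare_debug(dump_dir):
    toy_fn(torch.randn(4), True)   # first compilation
    toy_fn(torch.randn(4), False)  # second compilation (different guard set)
    toy_fn(torch.randn(4), True)   # reuses the first compiled bytecode

# one __transformed_code*.py dump per Dynamo compilation, mirroring the
# len(compiled_code) == 3 check in the test above (expect 2 for this toy)
transformed = sorted(
    glob.glob(os.path.join(dump_dir, "__transformed_code*.py")))
print(len(transformed), "Dynamo compilations dumped")

# the captured graphs are dumped as well and can be string-searched for ops,
# which is how the test checks for kv_caches / flash_attention / paged_attention
for path in sorted(glob.glob(os.path.join(dump_dir, "__compiled_fn*.py"))):
    print(path)
```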