diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 0aabde6aa8c5..9572588ce6e5 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -587,7 +587,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): generation_config = GenerationConfig(top_k=top_k, top_p=top_p, do_sample=True) - warpers = generation_model._get_logits_warper(generation_config) + warpers = generation_model._get_logits_warper(generation_config, device) assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list: List[SequenceGroupMetadata] = []