from typing import List
from unittest.mock import MagicMock

import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler
from vllm.sequence import Logprob, SequenceGroup

from .utils import create_dummy_prompt


def get_sequence_groups(scheduler_output):
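    """Return the SequenceGroups scheduled in this step's output."""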
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(seq_group, token_id: int):
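    """Append token_id to each sequence, simulating one decode step."""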
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})


def schedule_and_update_computed_tokens(scheduler):
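    """Run one scheduler step and mark the scheduled chunks as computed."""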
    metas, out = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out


def test_simple():
    """Verify basic scheduling works."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       num_seq_group,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq group prompts.
    num_tokens = block_size * num_seq_group
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_tokens
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
    for s in running:
        append_new_token(s, 1)

    # Schedule seq group generation.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_seq_group
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group


def test_chunk():
    """Verify prefills are chunked properly."""
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
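
    # Token budget is 64: the first 60-token prompt fits whole, leaving 4
    # tokens for the first chunk of the second prompt.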
    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # One chunked prefill, and one decoding.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # The first one is the remaining prefill. The scheduler guarantees ordering.
    assert seq_group_meta[0].token_chunk_size == 56
    # The second one is the decode of the first request.
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 57


def test_complex():
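    """Verify decode, running prefill, and new prefill are batched together."""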
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Add 2 more requests.
    for i in range(2, 4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
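
    # The 64-token budget splits as 56 to finish request 1's prefill, 7 for
    # the first chunk of the newly added 3rd request, and 1 for the decode.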
    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 3
    # The first one is the first chunked prefill.
    assert seq_group_meta[0].token_chunk_size == 7
    # The second one is the second new chunked prefill.
    assert seq_group_meta[1].token_chunk_size == 56
    # The last one is decode.
    assert seq_group_meta[2].token_chunk_size == 1
    # Two of them are in chunked prefill.
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # The first 2 requests are now in the decoding phase.
    append_new_token(running[0], 1)
    assert not running[0].is_prefill()
    append_new_token(running[1], 1)
    assert not running[1].is_prefill()
    # The third request is still in the prefill stage.
    assert running[2].is_prefill()


def test_maximal_decoding():
    """Verify decoding requests are prioritized."""
    block_size = 4
    max_seqs = 2
    max_model_len = 2
    max_num_batched_tokens = 2
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=2)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()
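
    # With a 2-token budget, only one 2-token prompt fits per scheduling step.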
    # The first prefill is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Create one more seq_group.
    _, seq_group = create_dummy_prompt("3", prompt_length=2)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # The first request's decode + the second request's first prefill chunk
    # is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)

    # Decoding + running prefill is prioritized.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # Only decoding requests are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 0
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # After the decoding request is aborted, the new prefill is scheduled in
    # FCFS order.
    scheduler.abort_seq_group(running[0].request_id)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2


def test_prompt_limit():
    """Verify max_num_batched_tokens < max_model_len is possible."""
    block_size = 4
    max_seqs = 32
    max_model_len = 64
    max_num_batched_tokens = 32
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # A prompt longer than max_num_batched_tokens should still be scheduled,
    # just chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 32
    assert running[0].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 32


def test_prompt_limit_exceed():
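    """Verify prompts longer than max_model_len are ignored."""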
    block_size = 4
    max_seqs = 64
    max_model_len = 32
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("2", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.ignored_seq_groups) == 1
    assert out.ignored_seq_groups[0] == seq_group


def test_swap():
    """Verify swapping works with chunked prefill requests."""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only the first chunk of the prefill is
    # scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap should be prioritized over new prefill.
    _, seq_group = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The swapped request is brought back in and resumes its prefill
    # (the remaining 30 tokens).
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_running_prefill_prioritized_over_swap():
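    """Verify a running prefill is prioritized over swapping a request in."""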
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only the first chunk of the prefill is
    # scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap is not possible, so the new prefill runs instead.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
    _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The new request runs its first 30-token prefill chunk; nothing is
    # swapped in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert out.scheduled_seq_groups[0].seq_group == seq_group2

    # Now although swap is possible, the running prefill is prioritized.
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The running prefill finishes its remaining 30 tokens; swap-in is still
    # deferred.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Decoding is prioritized.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # seq_group2 decodes one token; the swapped request stays swapped out.
    assert out.num_batched_tokens == 1
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Since we abort the sequence group, we can finally swap.
    scheduler.abort_seq_group(seq_group2.request_id)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_chunked_prefill_preempt():
    """Verify preemption works with chunked prefill requests."""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked; only the first chunk of the prefill is
    # scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be preempted.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now preempted.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out == []
    assert out.blocks_to_swap_in == []

    # Make sure we can reschedule the preempted request.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens
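    # Preemption by recomputation restarts the prefill, so after rescheduling
    # one 30-token chunk, 30 of the 60 prompt tokens remain uncomputed.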
    assert seq_group.get_num_uncomputed_tokens() == 30

    # We should be able to run prefill twice as it is chunked.
    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return True

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert not seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens


def test_chunked_prefill_max_seqs():
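    """Verify max_num_seqs caps the batch even when token budget remains."""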
    block_size = 4
    max_seqs = 2
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1", prompt_length=65)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)

    # The first prefill is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 1

    # Add new requests.
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=65)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Make sure only 2 requests are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_batched_tokens == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    append_new_token(running[0], 1)

    # Although we have enough token budget, we can only schedule max_seqs.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 2
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_batched_tokens == 3
    assert len(get_sequence_groups(out)) == max_seqs
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()