
Commit ba906df: Fixing graph capture for flash decoding. (#2163)
Narsil authored and ErikKaum committed Jul 26, 2024
1 parent bec6a17 commit ba906df
Showing 1 changed file with 3 additions and 2 deletions.
server/text_generation_server/models/flash_causal_lm.py (5 changes: 3 additions & 2 deletions)
@@ -926,7 +926,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
             "slots": slots,
             "input_lengths": input_lengths,
         }
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

@@ -939,14 +939,15 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths_,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
         )
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,

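For context, the change binds the Seqlen wrapper used for the warmup pass to a separate name (input_lengths_) so the raw input_lengths tensor is not shadowed, and then rebuilds the wrapper inside the torch.cuda.graph capture block. Below is a minimal, self-contained sketch of that warmup-then-capture pattern; the toy forward function and the stand-in Seqlen class are assumptions for illustration, not the project's code, and the real method also passes a shared memory pool (pool=MEM_POOL) to torch.cuda.graph. Running it requires a CUDA-capable GPU.

import torch


class Seqlen:
    # Stand-in for text_generation_server's Seqlen wrapper (an assumption for
    # this sketch): it simply holds a reference to the lengths tensor.
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths


def forward(x, input_lengths):
    # Toy computation standing in for self.model.forward.
    return x * input_lengths.input_lengths.unsqueeze(1).float()


device = "cuda"
x = torch.ones(4, 8, device=device)
input_lengths = torch.ones(4, dtype=torch.int32, device=device)

graph = torch.cuda.CUDAGraph()

# Warmup pass outside the graph, mirroring the diff: the wrapper gets a
# separate name so the raw `input_lengths` tensor stays available below.
input_lengths_ = Seqlen(input_lengths=input_lengths)
torch.cuda.synchronize()
forward(x, input_lengths_)
torch.cuda.synchronize()

with torch.cuda.graph(graph):
    # Rebuild the wrapper inside the capture so the recorded kernels read
    # from the persistent `input_lengths` tensor.
    input_lengths = Seqlen(input_lengths=input_lengths)
    out = forward(x, input_lengths)

# Replay with new data: copy fresh values into the captured tensor, replay,
# and `out` (allocated during capture) now reflects the new lengths.
input_lengths.input_lengths.copy_(
    torch.tensor([2, 3, 4, 5], dtype=torch.int32, device=device)
)
graph.replay()
print(out[:, 0])  # tensor([2., 3., 4., 5.], device='cuda:0')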