Skip to content

Commit

Permalink
[hotfix] fix inference typo (#5438)
Browse files Browse the repository at this point in the history
  • Loading branch information
hugo-syn authored May 13, 2024
1 parent 785cd9a commit 393c8f5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions colossalai/legacy/inference/async_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def _step(self):
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

else:
Expand All @@ -78,7 +78,7 @@ def _step(self):
else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

if has_new_finished:
Expand Down
8 changes: 4 additions & 4 deletions colossalai/legacy/inference/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def _step(self):
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0
return

if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
return
else:
Expand All @@ -154,7 +154,7 @@ def _step(self):
else:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1

return
Expand Down Expand Up @@ -243,7 +243,7 @@ def _handle_finish_req(self, batch: Batch, has_new_finished_req):
self._filter_batch(batch)
yield from self._output_process(finished_reqs)

def _filter_runing_batch(self):
def _filter_running_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None

Expand Down

0 comments on commit 393c8f5

Please sign in to comment.