diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index a331e9cf8cfc..9969c6786eab 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -298,8 +298,8 @@ def search_tokens(self, generation_config: GenerationConfig, logits): """ # do logit processor # NOTE: need to decide the granularity to process logits (sequence or batch) + config_dict = generation_config.to_dict() for type in ["top_k", "top_p", "min_p"]: - config_dict = generation_config.to_dict() if type in config_dict and config_dict[type] is not None: logits = logit_processor(type, logits, config_dict[type]) diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py index e13f14557c6a..557b3df653cc 100644 --- a/colossalai/inference/logit_processors.py +++ b/colossalai/inference/logit_processors.py @@ -36,21 +36,23 @@ def top_p_logit_processor(logits, top_p: float): cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + + sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1) sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) logits[indices_to_remove] = -float("inf") return logits -def logit_processor(processor:str, logits , attrs): + +def logit_processor(processor: str, logits, attrs): """ do logit process for given logits. Args: - processor(str): the type of logit processor + processor(str): the type of logit processor logits(torch.Tensor): input logits - attrs(dict): attrs of the logit processor + attrs(dict): attrs of the logit processor Returns: logits after process @@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs): func = _LOGIT_PROCESSOR_MAP[processor] try: logits = func(logits, attrs) - except Exception as e: + except Exception: return logits - return logits \ No newline at end of file + return logits