
Commit

fix: stable beluga generation issues (stopping midway)
biswaroop1547 committed Nov 6, 2023
1 parent 4ea6a43 commit 3a055e3
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions cht-petals/models.py
@@ -47,7 +47,7 @@ def generate(
n: int = 1,
stream: bool = False,
max_tokens: int = 128,
- stop: str = "/s>",
+ stop: List[str] = ["</s>", "/s>"],
**kwargs,
) -> List:
prompt = cls.stitch_prompt(messages, cls.PROMPT_TEMPLATE)
@@ -61,7 +61,7 @@
top_p=top_p,
max_new_tokens=max_tokens,
)
- outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], streaming=stream, stop=stop)
+ outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], stop_tokens=stop)
return [outputs]

@classmethod
@@ -73,7 +73,7 @@ def generate_streaming(
n: int = 1,
stream: bool = False,
max_tokens: int = 128,
- stop: str = "/s>",
+ stop: List[str] = ["</s>", "/s>"],
session=None,
inputs=None,
**kwargs,
@@ -95,7 +95,7 @@ def generate_streaming(
)
delta = outputs[0, n_input_tokens:].tolist()
token_count = len(delta) # noqa
- outputs = cls.safe_decode(cls.tokenizer, delta, streaming=stream, stop=stop)
+ outputs = cls.safe_decode(cls.tokenizer, delta, stop_tokens=stop)
if not outputs:
return None # end
outputs = outputs.lstrip() if inputs is not None else outputs
@@ -153,13 +153,13 @@ def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
def safe_decode(
tokenizer: PreTrainedTokenizer,
outputs: Union[torch.Tensor, List[int]],
- streaming: bool = False,
- stop: str = "/s>",
+ stop_tokens: List[str] = ["</s>", "/s>"],
) -> str:
# Workaround to make SentencePiece .decode() keep leading spaces in a token
fake_token = tokenizer("^")["input_ids"][0]
outputs = outputs.tolist() if isinstance(outputs, torch.Tensor) else outputs
- result = tokenizer.decode([fake_token] + outputs)
- if streaming:
-     return result.lstrip("<s>").lstrip(stop)
- return result.lstrip("<s>").rsplit("</s>", 1)[0].rsplit(stop, 1)[0].strip()
+ result = tokenizer.decode([fake_token] + outputs).replace("<s>", "")
+
+ for stop_token in stop_tokens:
+     result = result.split(stop_token)[0]
+ return result
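
For reference, the heart of the fix is the tail of safe_decode: instead of trimming the decoded text against a single stop string, it now cuts at the first occurrence of any stop token in a list. Below is a minimal standalone sketch of that truncation step (the helper name and the plain-string input are illustrative; the real code applies this to the output of tokenizer.decode):

from typing import List


def truncate_at_stop_tokens(text: str, stop_tokens: List[str] = ["</s>", "/s>"]) -> str:
    """Illustrative re-implementation of the truncation loop added to safe_decode."""
    result = text.replace("<s>", "")  # drop the BOS marker, as the patched code does
    for stop_token in stop_tokens:
        # keep only the text before the first occurrence of this stop token
        result = result.split(stop_token)[0]
    return result


print(truncate_at_stop_tokens("<s>Stable Beluga says hi</s> trailing text"))
# -> "Stable Beluga says hi"

Because both "</s>" and the previously used "/s>" are in the default list, output is cut at whichever end-of-sequence marker the model emits.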
2 changes: 1 addition & 1 deletion cht-petals/routes.py
@@ -13,7 +13,7 @@
class ChatCompletionInput(BaseModel):
model: str
messages: List[dict]
- stop: Optional[Union[str, List[str]]] = "/s>"
+ stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
temperature: float = 1.0
top_p: float = 1.0
n: int = 1
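
On the API side, only the default for stop changes. A short sketch of how the request model behaves with the new default (trimmed to the fields visible in this hunk; the real ChatCompletionInput in routes.py has more fields, and the model id below is illustrative):

from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1


# Requests that omit `stop` now pick up both stop tokens by default;
# a single string or an explicit list is still accepted via the Union type.
req = ChatCompletionInput(model="stable-beluga", messages=[{"role": "user", "content": "Hi"}])
print(req.stop)  # ['</s>', '/s>']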
