This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Fixes Stable Beluga generation stopping midway #141

Merged
22 changes: 11 additions & 11 deletions cht-petals/models.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -42,12 +42,12 @@ class PetalsBasedModel(ChatModel):
     def generate(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]] = None,
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         **kwargs,
     ) -> List:
         prompt = cls.stitch_prompt(messages, cls.PROMPT_TEMPLATE)
@@ -61,19 +61,19 @@ def generate(
             top_p=top_p,
             max_new_tokens=max_tokens,
         )
-        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], stop_tokens=stop)
         return [outputs]
 
     @classmethod
     def generate_streaming(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]] = None,
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         session=None,
         inputs=None,
         **kwargs,
@@ -95,7 +95,7 @@ def generate_streaming(
         )
         delta = outputs[0, n_input_tokens:].tolist()
         token_count = len(delta)  # noqa
-        outputs = cls.safe_decode(cls.tokenizer, delta, streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, delta, stop_tokens=stop)
         if not outputs:
             return None  # end
         outputs = outputs.lstrip() if inputs is not None else outputs
@@ -153,13 +153,13 @@ def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
     def safe_decode(
         tokenizer: PreTrainedTokenizer,
         outputs: Union[torch.Tensor, List[int]],
-        streaming: bool = False,
-        stop: str = "/s>",
+        stop_tokens: Optional[Union[str, List[str]]] = None,
     ) -> str:
         # Workaround to make SentencePiece .decode() keep leading spaces in a token
         fake_token = tokenizer("^")["input_ids"][0]
         outputs = outputs.tolist() if isinstance(outputs, torch.Tensor) else outputs
-        result = tokenizer.decode([fake_token] + outputs)
-        if streaming:
-            return result.lstrip("<s>").lstrip(stop)
-        return result.lstrip("<s>").rsplit("</s>", 1)[0].rsplit(stop, 1)[0].strip()
+        result = tokenizer.decode([fake_token] + outputs).replace("<s>", "")
+
+        for stop_token in stop_tokens:
+            result = result.split(stop_token)[0]
+        return result
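For context on the behavior change: the removed `safe_decode` treated `stop` as a single string defaulting to `"/s>"` and, on the streaming path, used `str.lstrip(stop)`, which strips any of the characters `/`, `s`, `>` rather than the literal marker. The new version threads an optional list of stop strings through `generate`/`generate_streaming` and cuts the decoded text at the first occurrence of each. Below is a minimal, runnable sketch of that stop-token handling in isolation; the function name is illustrative, and the `None`/`str` guards are added only so the sketch stands alone (the diff itself expects a list to be passed in, e.g. the new routes.py default).

```python
from typing import List, Optional, Union


def truncate_at_stop_tokens(text: str, stop_tokens: Optional[Union[str, List[str]]] = None) -> str:
    """Cut text at the first occurrence of any stop token, mirroring the new
    `for stop_token in stop_tokens: result = result.split(stop_token)[0]` loop in safe_decode."""
    if stop_tokens is None:  # guard added for this sketch; the PR forwards the route's default list
        return text
    if isinstance(stop_tokens, str):  # treat a bare string as one token, not as single characters
        stop_tokens = [stop_tokens]
    for stop_token in stop_tokens:
        text = text.split(stop_token)[0]
    return text


print(truncate_at_stop_tokens("Sure, here is a story.</s> ### User:", ["</s>", "/s>"]))
# -> "Sure, here is a story."
```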
2 changes: 1 addition & 1 deletion cht-petals/routes.py
@@ -13,7 +13,7 @@
 class ChatCompletionInput(BaseModel):
     model: str
     messages: List[dict]
-    stop: Optional[Union[str, List[str]]] = "/s>"
+    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
     temperature: float = 1.0
     top_p: float = 1.0
     n: int = 1
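A trimmed reproduction of the changed request model, showing what clients that omit `stop` now send downstream; only the fields relevant to this PR are kept, and the model name in the example is illustrative.

```python
from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    # New default: the real end-of-sequence marker plus the old literal, so both get trimmed.
    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]


req = ChatCompletionInput(model="stabilityai/StableBeluga2", messages=[{"role": "user", "content": "Hi"}])
print(req.stop)  # ['</s>', '/s>'], forwarded as the stop argument to generate()/generate_streaming()
```

With this default, a request that leaves `stop` unset still has `</s>` stripped during decoding, which is the cut-off behavior the PR title refers to.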