Add model_alias option to override model_path in completions. Closes g…
abetlen committed May 16, 2023
1 parent 214589e commit a335292
Showing 2 changed files with 34 additions and 9 deletions.
19 changes: 14 additions & 5 deletions llama_cpp/llama.py
@@ -522,7 +522,7 @@ def generate(
if tokens_or_none is not None:
tokens.extend(tokens_or_none)

def create_embedding(self, input: str) -> Embedding:
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
"""Embed a string.
Args:
@@ -532,6 +532,7 @@ def create_embedding(self, input: str) -> Embedding:
An embedding object.
"""
assert self.ctx is not None
_model: str = model if model is not None else self.model_path

if self.params.embedding == False:
raise RuntimeError(
@@ -561,7 +562,7 @@ def create_embedding(self, input: str) -> Embedding:
"index": 0,
}
],
"model": self.model_path,
"model": _model,
"usage": {
"prompt_tokens": n_tokens,
"total_tokens": n_tokens,
@@ -598,6 +599,7 @@ def _create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ def _create_completion(
text: bytes = b""
returned_characters: int = 0
stop = stop if stop is not None else []
_model: str = model if model is not None else self.model_path

if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ def _create_completion(
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ def _create_completion(
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ def _create_completion(
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text_str,
@@ -842,6 +845,7 @@ def create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -883,6 +887,7 @@ def create_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ def __call__(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -950,6 +956,7 @@ def __call__(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)

def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ def create_chat_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages.
@@ -1064,6 +1072,7 @@ def create_chat_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
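A minimal usage sketch of the llama.py side of this change: the new `model` keyword only overrides the name reported in the response payload, while inference still uses the weights loaded from `model_path`. The model path and alias below are placeholders, not part of the commit.

```python
from llama_cpp import Llama

# Assumed local GGML model file; replace with a real path.
llm = Llama(model_path="./models/ggml-model-q4_0.bin")

# Without `model`, the response reports the on-disk path.
out = llm.create_completion("Q: Name the planets. A:", max_tokens=32)
print(out["model"])  # "./models/ggml-model-q4_0.bin"

# With `model`, the payload reports the alias instead; generation is unchanged.
out = llm.create_completion("Q: Name the planets. A:", max_tokens=32, model="my-alias")
print(out["model"])  # "my-alias"
```

The same optional `model` argument is threaded through `__call__`, `create_chat_completion`, and `create_embedding`.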
24 changes: 20 additions & 4 deletions llama_cpp/server/app.py
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
model: str = Field(
description="The path to the model to use for generating completions."
)
model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
)
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_gpu_layers: int = Field(
default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):

router = APIRouter()

settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None


@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
if settings.cache:
cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
llama.set_cache(cache)

def set_settings(_settings: Settings):
global settings
settings = _settings

set_settings(settings)
return app


@@ -112,6 +123,10 @@ def get_llama():
yield llama


def get_settings():
yield settings


model_field = Field(description="The model to use for generating completions.")

max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
completion_or_chunks = llama(
**request.dict(
exclude={
"model",
"n",
"best_of",
"logit_bias",
@@ -274,7 +288,7 @@ class Config:
def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
return llama.create_embedding(**request.dict(exclude={"user"}))


class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude={
"model",
"n",
"logit_bias",
"user",
@@ -378,13 +391,16 @@ class ModelList(TypedDict):

@router.get("/v1/models", response_model=GetModelResponse)
def get_models(
settings: Settings = Depends(get_settings),
llama: llama_cpp.Llama = Depends(get_llama),
) -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"id": settings.model_alias
if settings.model_alias is not None
else llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],
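On the server side, `model_alias` only changes the id surfaced through the OpenAI-compatible API. The sketch below assumes a real GGML model file at the given path (creating the app loads it) and uses FastAPI's TestClient purely for illustration; the alias value is an example, not part of the commit.

```python
from fastapi.testclient import TestClient

from llama_cpp.server.app import Settings, create_app

# `model` must point at an existing model file; `model_alias` is the new optional setting.
settings = Settings(model="./models/ggml-model-q4_0.bin", model_alias="gpt-3.5-turbo")
app = create_app(settings=settings)

client = TestClient(app)
models = client.get("/v1/models").json()
print(models["data"][0]["id"])  # "gpt-3.5-turbo" instead of the on-disk path
```

Since `Settings` is a pydantic `BaseSettings`, the alias can presumably also be supplied via a `MODEL_ALIAS` environment variable when launching the server module.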
