From f636fd6080d723d860fbc03c5774dc243f237881 Mon Sep 17 00:00:00 2001
From: Sam Julien
Date: Thu, 19 Sep 2024 11:09:02 -0700
Subject: [PATCH] Add support for Writer models

---
 fastchat/model/model_registry.py |  7 +++++
 fastchat/serve/api_provider.py   | 48 ++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 2eed9649e..280ae1ecf 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -1000,3 +1000,10 @@ def get_model_info(name: str) -> ModelInfo:
     "https://huggingface.co/cllm",
     "consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.",
 )
+
+register_model_info(
+    ["palmyra-x-004"],
+    "Palmyra X 004",
+    "https://dev.writer.com/home/models",
+    "Palmyra by Writer",
+)
diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
index e00326b30..4509b0c5b 100644
--- a/fastchat/serve/api_provider.py
+++ b/fastchat/serve/api_provider.py
@@ -246,6 +246,16 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             conversation_id=state.conv_id,
         )
+    elif model_api_dict["api_type"] == "writer":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = writer_api_stream_iter(
+            model_name,
+            prompt,
+            temperature,
+            top_p,
+            max_tokens=max_new_tokens,
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
 
@@ -1264,3 +1274,41 @@ def metagen_api_stream_iter(
             "text": f"**API REQUEST ERROR** Reason: Unknown.",
             "error_code": 1,
         }
+
+
+def writer_api_stream_iter(
+    model_name, messages, temperature, top_p, max_tokens, api_key
+):
+    from writerai import Writer
+
+    api_key = api_key or os.environ["WRITER_API_KEY"]
+
+    client = Writer(api_key=api_key)
+
+    # Make requests
+    gen_params = {
+        "model": model_name,
+        "messages": messages,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    res = client.chat.chat(
+        messages=messages,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        model=model_name,
+        stream=True,
+    )
+    text = ""
+    for chunk in res:
+        if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
+            text += chunk.choices[0].delta.content
+            data = {
+                "text": text,
+                "error_code": 0,
+            }
+            yield data
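
---

Usage note (not part of the patch): a minimal smoke-test sketch for the new iterator, assuming the writerai package is installed and WRITER_API_KEY is exported. The keyword names mirror the signature added above, and the model name is the one registered in model_registry.py; the prompt text here is only an illustration.

    from fastchat.serve.api_provider import writer_api_stream_iter

    # Each yielded dict carries the accumulated text so far plus error_code 0,
    # the same streaming contract the other providers in api_provider.py follow.
    for data in writer_api_stream_iter(
        model_name="palmyra-x-004",
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
        top_p=0.9,
        max_tokens=128,
        api_key=None,  # falls back to the WRITER_API_KEY environment variable
    ):
        print(data["text"])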