
Commit
Misc Fixes (#15) [release]
* UTF-8 fix

* Support cosine similarity loss for training sentence transformers

* Fix coverage

* Fix peft

* Fix UTF-8 setting

* Fix UTF-8 setting

* Fix coverage

* Support new OpenAI models

* Update openai

* Update LLMs

* Add chat prompt templates for new models

* Bump version
AjayP13 authored Mar 17, 2024
1 parent 9d834d5 commit b48b58e
Showing 11 changed files with 208 additions and 89 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "DataDreamer"
version = "0.23.0"
version = "0.24.0"
description = "Prompt. Generate Synthetic Data. Train & Align Models."
license = "MIT"
authors= [
29 changes: 28 additions & 1 deletion src/datadreamer.py
@@ -1,6 +1,8 @@
import inspect
import locale
import logging
import os
import sys
from collections import UserDict, defaultdict
from multiprocessing.context import SpawnProcess
from threading import Lock
@@ -368,12 +370,37 @@ def _disable_setfit_logging():
setfit_logging.set_verbosity_error()
setfit_logging.disable_progress_bar()

def __enter__(self):
def __enter__(self): # noqa: C901
from .utils.distributed_utils import is_distributed

if hasattr(DataDreamer.ctx, "steps"):
raise RuntimeError("Only one DataDreamer context may be active at a time.")

# Set encoding to UTF-8
is_utf_8_encoding = lambda: any( # noqa: E731
utf8_locale in (locale.getlocale(locale.LC_CTYPE)[1] or "").lower()
for utf8_locale in ["utf8", "utf-8"]
)
if not is_utf_8_encoding(): # pragma: no cover
# The default encoding is not UTF-8; check whether a UTF-8 locale is available
# and switch to it. This works around a bug on some improperly configured older
# Linux systems.
# See: https://github.com/datadreamer-dev/DataDreamer/issues/13
for locale_string in ["C.UTF8", "C.UTF-8", "en_US.UTF-8"]:
try:
locale.setlocale(locale.LC_CTYPE, locale_string)
if is_utf_8_encoding():
# It worked; we were able to reset the encoding back to UTF-8.
# Now apply workarounds to set the encoding to UTF-8 across the standard
# places where Python may otherwise use the wrong encoding.
sys.stdin.reconfigure(encoding="utf-8") # type:ignore[attr-defined]
sys.stdout.reconfigure(encoding="utf-8") # type:ignore[attr-defined]
sys.stderr.reconfigure(encoding="utf-8") # type:ignore[attr-defined]
locale.getpreferredencoding = lambda do_setlocale=True: "utf-8"
break
except locale.Error:
pass

# Initialize
_DATADREAMER_CTX_LOCK.acquire()
if self.output_folder_path:
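For reference, a minimal standalone sketch (standard library only, not part of this commit) of the encoding check used above; running it on a target machine shows whether the fallback locale-switching path would trigger:

import locale
import sys

def is_utf_8_encoding() -> bool:
    # locale.getlocale(LC_CTYPE) returns (language, encoding); the encoding can be
    # None on misconfigured systems, hence the `or ""` guard before lowercasing.
    encoding = (locale.getlocale(locale.LC_CTYPE)[1] or "").lower()
    return any(tag in encoding for tag in ("utf8", "utf-8"))

if __name__ == "__main__":
    print("LC_CTYPE locale:", locale.getlocale(locale.LC_CTYPE))
    print("stdout encoding:", sys.stdout.encoding)
    print("UTF-8 active:", is_utf_8_encoding())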
18 changes: 16 additions & 2 deletions src/llms/_chat_prompt_templates.py
@@ -6,6 +6,8 @@
CHAT_PROMPT_TEMPLATES = {
"llama_system": "[INST] <<SYS>>\n{{system_prompt}}\n<</SYS>>\n\n{{prompt}} [/INST] ",
"llama": "[INST] {{prompt}} [/INST] ",
"olmo": "<|endoftext|><|user|>\n{{prompt}}\n<|assistant|>\n",
"command_r": "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{prompt}}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
"phi": "Instruct: {{prompt}}\nOutput: ",
"openchat": "GPT4 Correct User: {{prompt}}<|end_of_turn|>GPT4 Correct Assistant: ",
"orca_hashes": "### System:\n{{system_prompt}}\n\n### User:\n{{prompt}}\n\n### Assistant:\n",
@@ -16,6 +18,7 @@
"chatml": "<|im_start|>user\n{{prompt}}<|im_end|>\n<|im_start|>assistant\n",
"tinyllama": "<|system|>\n{{system_prompt}}</s>\n<|user|>\n{{prompt}}</s>\n<|assistant|>\n",
"zephyr": "<|system|>\n</s>\n<|user|>\n{{prompt}}</s>\n<|assistant|>\n",
"zephyr_system": "<|system|>\n{{system_prompt}}</s>\n<|user|>\n{{prompt}}</s>\n<|assistant|>\n",
"oasst_system": "<|system|>{{system_prompt}}</s><|prompter|>{{prompt}}</s><|assistant|>",
"oasst": "<|prompter|>{{prompt}}<|endoftext|><|assistant|>",
"oasst_h2o": "<|prompt|>{{prompt}}<|endoftext|><|answer|>",
@@ -31,6 +34,7 @@
"decilm": "### System:\n{{system_prompt}}\n### User:\n{{prompt}}\n### Assistant:\n",
"vicuna": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {{prompt}}\nASSISTANT: ",
"vicuna_v1": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: {{prompt}}\n### Assistant: ",
"xwin": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{prompt}} ASSISTANT: ",
"vicuna_simple": "USER: {{prompt}}\nASSISTANT: ",
"minotaur": "The following is a chat between a USER and a friendly and helpful ASSISTANT.\nUSER: {{prompt}}\nASSISTANT: ",
"bluemoon": "{{system_prompt}}\nLEAD: {{prompt}}\nASSOCIATE: ",
@@ -61,6 +65,7 @@

SYSTEM_PROMPT_TYPES = {
"llama_system": "llama_system",
"zephyr_system": "llama_system",
"tinyllama": "llama_system",
"orca_hashes": "llama_system",
"openorca_openchat": "openorca_openchat",
@@ -139,7 +144,10 @@ def _model_name_to_chat_prompt_template_type( # noqa: C901
):
chat_prompt_template_type = "llama"
elif "zephyr-" in model_name_lower and "stablelm" not in model_name_lower:
chat_prompt_template_type = "zephyr"
if "-beta" in model_name_lower:
chat_prompt_template_type = "zephyr_system"
else:
chat_prompt_template_type = "zephyr"
elif all(
fragment in model_name_lower for fragment in ["mistral-", "-instruct"]
):
@@ -148,8 +156,14 @@
fragment in model_name_lower for fragment in ["mixtral-", "-instruct"]
):
chat_prompt_template_type = "llama"
elif all(fragment in model_name_lower for fragment in ["olmo-", "-instruct"]):
chat_prompt_template_type = "olmo"
elif all(fragment in model_name_lower for fragment in ["c4ai-", "-command-r"]):
chat_prompt_template_type = "command_r"
elif all(fragment in model_name_lower for fragment in ["phi-", "-2"]):
chat_prompt_template_type = "phi"
elif "xwin" in model_name_lower:
chat_prompt_template_type = "xwin"
elif all(fragment in model_name_lower for fragment in ["solar-", "-instruct"]):
chat_prompt_template_type = "solar"
elif all(fragment in model_name_lower for fragment in ["yi-", "-chat"]):
@@ -355,7 +369,7 @@ def _model_name_to_system_prompt(
return None

# Try to get the system prompt from `transformers`
# Skipping due to https://github.com/huggingface/transformers/pull/26765
# TODO: Skipping due to https://github.com/huggingface/transformers/pull/26765
# result = _chat_prompt_template_and_system_prompt(
# model_name=model_name, revision=revision
# )
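As a quick illustration of what the new template entries mean (this is not DataDreamer's actual formatting path, only the substitution implied by the {{prompt}} placeholder):

# Illustration only: how a CHAT_PROMPT_TEMPLATES entry maps a user prompt to the
# string the model receives. The real substitution happens inside DataDreamer's
# prompt-formatting code.
olmo_template = "<|endoftext|><|user|>\n{{prompt}}\n<|assistant|>\n"
formatted = olmo_template.replace("{{prompt}}", "Write a haiku about data.")
print(formatted)
# <|endoftext|><|user|>
# Write a haiku about data.
# <|assistant|>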
22 changes: 21 additions & 1 deletion src/llms/llm.py
@@ -25,11 +25,21 @@ def _check_max_new_tokens_possible(

# Check max_new_tokens
max_context_length = self.get_max_context_length(max_new_tokens=0)
max_output_length = self._get_max_output_length()
max_new_tokens_possible = self.get_max_context_length(
max_new_tokens=max_prompt_length
)
if max_new_tokens_possible > 0 and max_new_tokens is None:
max_new_tokens = max_new_tokens_possible
max_new_tokens = min(
max_new_tokens_possible, max_output_length or max_new_tokens_possible
)
elif max_output_length is not None and (
max_new_tokens is not None and max_new_tokens > max_output_length
):
raise ValueError(
"The requested max_new_tokens exceeds the maximum output length of the"
" model."
)
elif (
max_new_tokens_possible <= 0
or (max_new_tokens is not None and max_new_tokens_possible < max_new_tokens)
@@ -117,6 +127,16 @@ def get_max_context_length(self, max_new_tokens: int) -> int:
"""
pass

def _get_max_output_length(self) -> None | int:
"""Gets the maximum output length for the model. If there is no maximum output
limit and the only limit is the context length of the model, ``None`` is
returned.
Returns:
The maximum output length.
"""
return None

def format_prompt( # noqa: C901
self,
max_new_tokens: None | int = None,
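A simplified sketch of the clamping rule introduced above (a hypothetical helper, not the library's API; the real _check_max_new_tokens_possible does additional validation): when no value is requested, max_new_tokens defaults to the smaller of the remaining context budget and the model's hard output cap, and an explicit request above the cap is rejected.

def clamp_max_new_tokens(
    max_context_budget: int,        # tokens left after the longest prompt
    max_output_length: int | None,  # _get_max_output_length(); None means no hard cap
    max_new_tokens: int | None,     # user-requested value, if any
) -> int:
    if max_new_tokens is None:
        # Default: as many tokens as the context allows, capped by the output limit.
        return min(max_context_budget, max_output_length or max_context_budget)
    if max_output_length is not None and max_new_tokens > max_output_length:
        raise ValueError(
            "The requested max_new_tokens exceeds the maximum output length of the model."
        )
    return max_new_tokens

# For example, a 128k-context model with a 4,096-token output cap and a 100-token
# prompt defaults to 4,096 new tokens rather than ~127,900.
assert clamp_max_new_tokens(127882, 4096, None) == 4096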
51 changes: 39 additions & 12 deletions src/llms/openai.py
@@ -41,8 +41,22 @@ def _normalize_model_name(model_name: str) -> str:
def _is_gpt_3(model_name: str):
model_name = _normalize_model_name(model_name)
return any(
gpt3_name in model_name
for gpt3_name in ["davinci", "ada", "curie", "gpt-3-", "gpt-3.5-", "gpt-35-"]
gpt3_name in model_name for gpt3_name in ["davinci", "ada", "curie", "gpt-3-"]
)


@lru_cache(maxsize=None)
def _is_gpt_3_5(model_name: str):
model_name = _normalize_model_name(model_name)
return any(gpt35_name in model_name for gpt35_name in ["gpt-3.5-", "gpt-35-"])


@lru_cache(maxsize=None)
def _is_gpt_3_5_legacy(model_name: str):
model_name = _normalize_model_name(model_name)
return _is_gpt_3_5(model_name) and (
"-0613" in model_name
or (_is_instruction_tuned(model_name) and not _is_chat_model(model_name))
)


@@ -54,11 +68,17 @@ def _is_gpt_4(model_name: str):
)


@lru_cache(maxsize=None)
def _is_preview_model(model_name: str):
model_name = _normalize_model_name(model_name)
return "-preview" in model_name


@lru_cache(maxsize=None)
def _is_chat_model(model_name: str):
model_name = _normalize_model_name(model_name)
return (
"gpt-3.5-" in model_name or "gpt-35-" in model_name or _is_gpt_4(model_name)
_is_gpt_3_5(model_name) or _is_gpt_4(model_name)
) and not model_name.endswith("-instruct")


@@ -180,7 +200,7 @@ def client(self) -> openai.OpenAI | openai.AzureOpenAI:
def tokenizer(self) -> Encoding:
try:
return tiktoken.encoding_for_model(self.model_name)
except KeyError:
except KeyError: # pragma: no cover
return tiktoken.get_encoding("cl100k_base")

@ring.lru(maxsize=128)
Expand All @@ -202,17 +222,17 @@ def get_max_context_length(self, max_new_tokens: int) -> int: # pragma: no cove
# (system prompt, user prompt, assistant response)
# and then we have to account for the system prompt
format_tokens = 4 * 3 + self.count_tokens(cast(str, self.system_prompt))
if "-preview" in model_name:
max_context_length = 128000
if "32k" in model_name:
max_context_length = 32768
elif "16k" in model_name:
max_context_length = 16384
elif "gpt-3.5-turbo" in model_name or "gpt-35-turbo" in model_name:
if "-1106" in model_name:
max_context_length = 16385
else:
elif _is_preview_model(model_name):
max_context_length = 128000
elif _is_gpt_3_5(self.model_name):
if _is_gpt_3_5_legacy(self.model_name):
max_context_length = 4096
else:
max_context_length = 16385
elif model_name.startswith("text-davinci"):
max_context_length = 4097
elif model_name.startswith("code-davinci"):
@@ -226,6 +246,13 @@ def get_max_context_length(self, max_new_tokens: int) -> int: # pragma: no cover
max_context_length = 8192
return max_context_length - max_new_tokens - format_tokens

def _get_max_output_length(self) -> None | int: # pragma: no cover
if (_is_gpt_4(self.model_name) and _is_preview_model(self.model_name)) or (
_is_gpt_3_5(self.model_name) and not (_is_gpt_3_5_legacy(self.model_name))
):
return 4096
return None

@ring.lru(maxsize=5000)
def count_tokens(self, value: str) -> int:
"""Counts the number of tokens in a string.
@@ -398,7 +425,7 @@ def max_length_func(prompts: list[str]) -> int:

@cached_property
def model_card(self) -> None | str:
if _is_gpt_3(self.model_name):
if _is_gpt_3(self.model_name) or _is_gpt_3_5(self.model_name):
return (
"https://github.com/openai/gpt-3/blob/"
"d7a9bb505df6f630f9bab3b30c889e52f22eb9ea/model-card.md"
@@ -414,7 +441,7 @@ def license(self) -> None | str:
@cached_property
def citation(self) -> None | list[str]:
citations = []
if _is_gpt_3(self.model_name):
if _is_gpt_3(self.model_name) or _is_gpt_3_5(self.model_name):
citations.append(
"""
@article{brown2020language,
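The concrete numbers asserted in the tests below follow from the format_tokens arithmetic above. This sketch assumes the default system prompt is "You are a helpful assistant." (6 tokens under cl100k_base); that default is an assumption, as it is not shown in this diff:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
# Assumed default system prompt (not shown in this diff).
system_prompt_tokens = len(enc.encode("You are a helpful assistant."))  # 6
format_tokens = 4 * 3 + system_prompt_tokens                            # 18

print(8192 - format_tokens)    # 8174   (gpt-4)
print(128000 - format_tokens)  # 127982 (gpt-4-turbo-preview)
print(16385 - format_tokens)   # 16367  (gpt-3.5-turbo)
# gpt-3.5-turbo-instruct is a completion model, so no chat formatting overhead is
# subtracted and the test expects the bare 4096-token context.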
72 changes: 69 additions & 3 deletions src/tests/llms/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from ...llms._litellm import LiteLLM
from ...llms.hf_transformers import CachedTokenizer
from ...llms.llm import _check_temperature_and_top_p
from ...llms.llm import _check_max_new_tokens_possible, _check_temperature_and_top_p
from ...utils.hf_model_utils import get_orig_model
from ...utils.import_utils import (
ignore_litellm_warnings,
@@ -378,6 +378,55 @@ def test_check_temperature_and_top_p(self):
supports_one_top_p=False,
) == (0.3, 0.999)

def test_check_max_new_tokens_possible(self, create_datadreamer):
# Check that max_new_tokens is clamped to the model's max output length
llm = OpenAI("gpt-4")
assert _check_max_new_tokens_possible(
self=llm,
max_length_func=lambda prompts: 100,
prompts=[],
max_new_tokens=None,
) == (8174 - 100)
assert (
_check_max_new_tokens_possible(
self=llm,
max_length_func=lambda prompts: 100,
prompts=[],
max_new_tokens=5000,
)
== 5000
)
llm = OpenAI("gpt-4-turbo-preview")
assert (
_check_max_new_tokens_possible(
self=llm,
max_length_func=lambda prompts: 100,
prompts=[],
max_new_tokens=None,
)
== 4096
)
assert (
_check_max_new_tokens_possible(
self=llm,
max_length_func=lambda prompts: 100,
prompts=[],
max_new_tokens=4096,
)
== 4096
)
# Make sure an error is raised if the model's maximum output length is exceeded
with pytest.raises(ValueError):
assert (
_check_max_new_tokens_possible(
self=llm,
max_length_func=lambda prompts: 100,
prompts=[],
max_new_tokens=5000,
)
== 4096
)

def test_run_with_no_prompts(self, create_datadreamer):
llm = OpenAI("gpt-3.5-turbo-instruct")
generated_texts = llm.run(
@@ -997,10 +1046,27 @@ def test_count_tokens(self, create_datadreamer):

def test_get_max_context_length(self, create_datadreamer):
with create_datadreamer():
# Check max context length
llm = OpenAI("gpt-4")
assert llm.get_max_context_length(max_new_tokens=0) == 8174
llm = OpenAI("gpt-4-turbo-preview")
assert llm.get_max_context_length(max_new_tokens=0) == 127982
llm = OpenAI("gpt-3.5-turbo")
assert llm.get_max_context_length(max_new_tokens=0) == 16367
llm = OpenAI("gpt-3.5-turbo-instruct")
assert llm.get_max_context_length(max_new_tokens=0) == 4096

def test_get_max_output_length(self, create_datadreamer):
with create_datadreamer():
# Check max output length
llm = OpenAI("gpt-4")
assert llm._get_max_output_length() is None
llm = OpenAI("gpt-4-turbo-preview")
assert llm._get_max_output_length() == 4096
llm = OpenAI("gpt-3.5-turbo")
assert llm.get_max_context_length(max_new_tokens=0) == 4078
assert llm._get_max_output_length() == 4096
llm = OpenAI("gpt-3.5-turbo-instruct")
assert llm._get_max_output_length() is None

@pytest.mark.skipif(
"OPENAI_API_KEY" not in os.environ, reason="requires OpenAI API key"
@@ -2903,7 +2969,7 @@ def test_metadata(self, create_datadreamer):
@pytest.mark.order("last")
def test_petals_network(self, create_datadreamer):
with create_datadreamer():
llm = Petals("bigscience/bloom-560m", dtype=torch.float32)
llm = Petals("petals-team/StableBeluga2", dtype=torch.float32)
generated_texts = llm.run(
["A", "B"],
max_new_tokens=1,