diff --git a/src/axolotl/core/chat/format/__init__.py b/src/axolotl/core/chat/format/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py index d38265454..315d101a8 100644 --- a/src/axolotl/core/chat/format/chatml.py +++ b/src/axolotl/core/chat/format/chatml.py @@ -1,9 +1,16 @@ +""" +ChatML transformation functions for MessageContents +""" from typing import Optional from ..messages import MessageContents, Messages +from .shared import wrap_tools -def format_message(message: Messages, message_index: Optional[int] = None) -> Messages: +def format_message( + message: Messages, + message_index: Optional[int] = None, # pylint: disable=unused-argument +) -> Messages: if message.is_chat_formatted: return message @@ -21,37 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Me ) message.content.append(MessageContents(type="text", value="\n", weight=0)) - # loop over message.content by index to find tool calls, we need to wrap each with tags, - # so be wary of indexing issues when changing the list while iterating - # iterate over the range in reverse order to avoid index shifting - for i in range(len(message.content) - 1, -1, -1): - if message.content[i].type == "tool_call": - # append a MessageContents text tag after - message.content.insert( - i + 1, - MessageContents( - type="text", value="\n", weight=message.weight - ), - ) - # make sure the actual tool call content ends with a newline - message.content[i].has_newline = True - # prepend a MessageContents text tag before - message.content.insert( - i, MessageContents(type="text", value="\n", weight=message.weight) - ) - elif message.content[i].type == "tool_response": - # append a MessageContents text tag after - message.content.insert( - i + 1, - MessageContents( - type="text", value="\n", weight=message.weight - ), - ) - # make sure the actual tool response content ends with a newline - message.content[i].has_newline = True - # prepend a MessageContents text tag before - message.content.insert( - i, MessageContents(type="text", value="\n", weight=message.weight) - ) + message = wrap_tools(message) + message.is_chat_formatted = True return message diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py index e5bc7b849..352b74160 100644 --- a/src/axolotl/core/chat/format/llama3x.py +++ b/src/axolotl/core/chat/format/llama3x.py @@ -1,6 +1,10 @@ +""" +Llama 3.x chat formatting functions for MessageContents +""" from typing import Optional from ..messages import MessageContents, Messages +from .shared import wrap_tools def format_message(message: Messages, message_index: Optional[int] = None) -> Messages: @@ -24,38 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Me MessageContents(type="text", value="<|eot_id|>", weight=message.weight) ) - # loop over message.content by index to find tool calls, we need to wrap each with tags, - # so be wary of indexing issues when changing the list while iterating - # iterate over the range in reverse order to avoid index shifting - for i in range(len(message.content) - 1, -1, -1): - if message.content[i].type == "tool_call": - # append a MessageContents text tag after - message.content.insert( - i + 1, - MessageContents( - type="text", value="\n", weight=message.weight - ), - ) - # make sure the actual tool call content ends with a newline - message.content[i].has_newline = True - # prepend a MessageContents text tag before - message.content.insert( - i, MessageContents(type="text", value="\n", weight=message.weight) - ) - elif message.content[i].type == "tool_response": - # append a MessageContents text tag after - message.content.insert( - i + 1, - MessageContents( - type="text", value="\n", weight=message.weight - ), - ) - # make sure the actual tool response content ends with a newline - message.content[i].has_newline = True - # prepend a MessageContents text tag before - message.content.insert( - i, MessageContents(type="text", value="\n", weight=message.weight) - ) + message = wrap_tools(message) if message_index == 0: message.content.insert( diff --git a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py new file mode 100644 index 000000000..9efa2353d --- /dev/null +++ b/src/axolotl/core/chat/format/shared.py @@ -0,0 +1,47 @@ +""" +shared functions for format transforms +""" +from axolotl.core.chat.messages import MessageContents, Messages + + +def wrap_tools(message: Messages): + # loop over message.content by index to find tool calls, we need to wrap each with tags, + # so be wary of indexing issues when changing the list while iterating. + # iterate over the range in reverse order to avoid index shifting + for i in range(len(message.content) - 1, -1, -1): + if message.content[i].type == "tool_call": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool call content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + elif message.content[i].type == "tool_response": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool response content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + + return message diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index 95fc334c6..76ab2f733 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -1,25 +1,35 @@ +""" +internal message representations of chat messages +""" import json -import logging from dataclasses import dataclass from enum import Enum from typing import Any, Callable, List, Optional, Union from transformers import PreTrainedTokenizer -from transformers import PreTrainedTokenizer class MessageRoles(str, Enum): + """ + Message roles for the system, user, assistant, and tools + """ + system = "system" # pylint: disable=invalid-name user = "user" # pylint: disable=invalid-name assistant = "assistant" # pylint: disable=invalid-name tool = "tool" # pylint: disable=invalid-name - ipython = ( - "ipython" # pylint: disable=invalid-name # for responses from builtin tools + ipython = ( # pylint: disable=invalid-name + # for responses from builtin tools + "ipython" ) class MessageContentTypes(str, Enum): - special_token = "special_token" # pylint: disable=invalid-name + """ + Message content types for text, image, audio, tool calls, and tool responses + """ + + special_token = "special_token" # pylint: disable=invalid-name # nosec B105 text = "text" # pylint: disable=invalid-name image = "image" # pylint: disable=invalid-name audio = "audio" # pylint: disable=invalid-name @@ -29,18 +39,30 @@ class MessageContentTypes(str, Enum): @dataclass class SpecialToken(str, Enum): - bos_token = "bos_token" # pylint: disable=invalid-name - eos_token = "eos_token" # pylint: disable=invalid-name + """ + Special tokens for beginning of string and end of string + """ + + bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105 + eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105 @dataclass class ToolCallFunction: + """ + Tool call function with name and arguments + """ + name: str arguments: dict[str, str] @dataclass class Tool: + """ + Tool with description, function, and parameters + """ + description: str function: ToolCallFunction parameters: dict[str, str] # .properties @@ -48,9 +70,13 @@ class Tool: @dataclass class ToolCallContents: + """ + Tool call contents with name, arguments, and optional id + """ + name: str arguments: dict[str, Union[str, int]] - id: Optional[str] = None + id: Optional[str] = None # pylint: disable=invalid-name def __str__(self) -> str: data = {"name": self.name, "arguments": self.arguments} @@ -61,9 +87,13 @@ def __str__(self) -> str: @dataclass class ToolResponseContents: + """ + Tool response contents with name, content, and optional id + """ + name: str content: Union[str, dict[str, Union[str, int, float]]] - id: Optional[str] = None + id: Optional[str] = None # pylint: disable=invalid-name def __str__(self) -> str: data = {"name": self.name, "content": self.content} @@ -74,6 +104,10 @@ def __str__(self) -> str: @dataclass class MessageContents: + """ + Message contents with type, value, metadata, weight, newline, and end of contents + """ + type: Union[str, MessageContentTypes] value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken] meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata @@ -90,6 +124,10 @@ def __str__(self) -> str: @dataclass class Messages: + """ + Messages with role, content, metadata, weight, and chat formatting + """ + role: Union[MessageRoles, str] # allows for arbitrary roles content: List["MessageContents"] meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata @@ -148,6 +186,10 @@ def tokenized( @dataclass class Chats: + """ + top level data structure for chat conversations + """ + conversation: List[Messages] def __str__(self) -> str: @@ -173,7 +215,11 @@ def tokenized( @dataclass class ChatFormattedChats(Chats): - formatter: Callable # [[Union[dict, Chats]], Chats] + """ + Chat formatted chats with formatter and optional train on inputs + """ + + formatter: Callable # [[Union[dict, Chats]], Chats] train_on_inputs: bool = False def __post_init__(self): @@ -185,6 +231,10 @@ def __post_init__(self): @dataclass class PreferenceChats: + """ + representation for preference data for chat + """ + prompt: List[Messages] chosen: Messages rejected: Messages diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py index 2caa71975..863642b24 100644 --- a/src/axolotl/core/datasets/chat.py +++ b/src/axolotl/core/datasets/chat.py @@ -1,3 +1,6 @@ +""" +chat dataset module +""" import os from typing import Callable, Optional, Union @@ -5,10 +8,14 @@ from datasets import Dataset from transformers import PreTrainedTokenizer -from axolotl.core.chat.messages import ChatFormattedChats, Chats +from axolotl.core.chat.messages import ChatFormattedChats class TokenizedChatDataset(Dataset): + """ + Tokenized chat dataset + """ + def __init__( self, data: Dataset, @@ -35,7 +42,10 @@ def map_fn(ex): ) return ex.tokenized(model_transform) - num_proc = min(64, process_count if process_count else os.cpu_count()) + process_or_cpu_count: int = ( + process_count or os.cpu_count() # type: ignore[assignment] + ) + num_proc = min(64, process_or_cpu_count) tokenized_data = data.map( map_fn, num_proc=num_proc, diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py index 810967d5d..3499ca09e 100644 --- a/src/axolotl/core/datasets/transforms/chat_builder.py +++ b/src/axolotl/core/datasets/transforms/chat_builder.py @@ -1,12 +1,22 @@ -from typing import Mapping, Union +""" +This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. +""" +from typing import Any, Mapping, Union -def chat_message_transform_builder( +def chat_message_transform_builder( # pylint: disable=dangerous-default-value train_on_inputs=False, conversations_field: str = "conversations", message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" - message_field_content: Union[str, list[str]] = ["value", "text", "content"], # commonly "content" - message_field_training: Union[str, list[str]] = ["train", "weight"], # commonly "weight" + message_field_content: Union[str, list[str]] = [ + "value", + "text", + "content", + ], # commonly "content" + message_field_training: Union[str, list[str]] = [ + "train", + "weight", + ], # commonly "weight" ): """Builds a transform that takes a row from the dataset and converts it to a Chat @@ -28,12 +38,20 @@ def chat_message_transform_builder( A function that takes a list of conversations and returns a list of messages. """ - message_field_role = [message_field_role] if isinstance(message_field_role, str) else message_field_role + message_field_role = ( + [message_field_role] + if isinstance(message_field_role, str) + else message_field_role + ) message_field_content = ( - [message_field_content] if isinstance(message_field_content, str) else message_field_content + [message_field_content] + if isinstance(message_field_content, str) + else message_field_content ) message_weight_fields = ( - [message_field_training] if isinstance(message_field_training, str) else message_field_training + [message_field_training] + if isinstance(message_field_training, str) + else message_field_training ) role_value_mappings = { @@ -62,38 +80,62 @@ def chat_message_transform_builder( "ipython": 0, } - def transform_builder(sample: Mapping[str, any]): + def transform_builder(sample: Mapping[str, Any]): if conversations_field not in sample: raise ValueError(f"Field '{conversations_field}' not found in sample.") # if none of the role fields are in the message, raise an error - if not any(role in sample[conversations_field][0] for role in message_field_role): - raise ValueError(f"No role field found in message.") - else: - role_field = next(role for role in message_field_role if role in sample[conversations_field][0]) - if not any(field in sample[conversations_field][0] for field in message_field_content): - raise ValueError(f"No message_content field found in message.") - else: - message_content_field = next(field for field in message_field_content if field in sample[conversations_field][0]) - if not any(field in sample[conversations_field][0] for field in message_field_training): + if not any( + role in sample[conversations_field][0] for role in message_field_role + ): + raise ValueError("No role field found in message.") + role_field = next( + role + for role in message_field_role + if role in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_content + ): + raise ValueError("No message_content field found in message.") + message_content_field = next( + field + for field in message_field_content + if field in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_training + ): message_weight_field = None else: - message_weight_field = next(field for field in message_weight_fields if field in sample[conversations_field][0]) + message_weight_field = next( + field + for field in message_weight_fields + if field in sample[conversations_field][0] + ) messages = [] for message in sample[conversations_field]: role = role_value_mappings[message[role_field]] - weight = int(message[message_weight_field]) if message_weight_field else role_default_weights_mappings[role] + weight = ( + int(message[message_weight_field]) + if message_weight_field + else role_default_weights_mappings[role] + ) # TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents - messages.append({ - "role": role, - "content": [{ - "type": "text", - "value": message[message_content_field], - }], - "weight": weight, - }) + messages.append( + { + "role": role, + "content": [ + { + "type": "text", + "value": message[message_content_field], + } + ], + "weight": weight, + } + ) return {"conversation": messages} diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 86f98b78e..56e2e87cd 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -30,4 +30,4 @@ def load(strategy, tokenizer, cfg, ds_cfg): except Exception as exc: # pylint: disable=broad-exception-caught LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") raise exc - return None + return None diff --git a/src/axolotl/prompt_strategies/chat.py b/src/axolotl/prompt_strategies/chat.py index c54fdb8ab..35d764902 100644 --- a/src/axolotl/prompt_strategies/chat.py +++ b/src/axolotl/prompt_strategies/chat.py @@ -1,4 +1,7 @@ -from typing import Optional, Dict, Any, Callable +""" +Chat dataset wrapping strategy for new internal messages representations +""" +from typing import Any, Callable, Dict, Optional from axolotl.core.datasets.chat import TokenizedChatDataset from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder @@ -6,7 +9,17 @@ class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): - def __init__(self, processor, message_transform=None, formatter=None, **kwargs): + """ + Chat dataset wrapping strategy for new internal messages representations + """ + + def __init__( + self, + processor, + message_transform=None, + formatter=None, + **kwargs, # pylint: disable=unused-argument + ): """ :param processor: tokenizer or image processor :param kwargs: @@ -21,7 +34,7 @@ def wrap_dataset( dataset, process_count: Optional[int] = None, keep_in_memory: Optional[bool] = False, - **kwargs, + **kwargs, # pylint: disable=unused-argument ): self.dataset = TokenizedChatDataset( dataset, @@ -53,15 +66,19 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): builder_kwargs["message_field_training"] = message_field_training chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml")) - format_message = lambda x: x + format_message = ( + lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment + ) if chat_template == "chatml": - from axolotl.core.chat.format.chatml import format_message + from axolotl.core.chat.format.chatml import format_message # noqa F811 if chat_template.startswith("llama3"): - from axolotl.core.chat.format.llama3x import format_message + from axolotl.core.chat.format.llama3x import format_message # noqa F811 message_transform: Callable = chat_message_transform_builder( train_on_inputs=ds_cfg.get("train_on_inputs", False), **builder_kwargs, ) - strategy = ChatMessageDatasetWrappingStrategy(tokenizer, message_transform=message_transform, formatter=format_message) + strategy = ChatMessageDatasetWrappingStrategy( + tokenizer, message_transform=message_transform, formatter=format_message + ) return strategy diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 12945a812..b5a6dfa7d 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -23,10 +23,11 @@ AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, AlpacaReflectionPTStrategy, + DatasetWrappingStrategy, GPTeacherPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, - SummarizeTLDRPromptTokenizingStrategy, PromptTokenizingStrategy, DatasetWrappingStrategy, + SummarizeTLDRPromptTokenizingStrategy, ) from axolotl.prompters import ( AlpacaPrompter, diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py index 8383b723f..fb63e7230 100644 --- a/tests/core/chat/test_messages.py +++ b/tests/core/chat/test_messages.py @@ -1,3 +1,6 @@ +""" +Tests for the chat messages module +""" import unittest import pytest @@ -8,8 +11,8 @@ from axolotl.core.chat.messages import ChatFormattedChats, Chats -@pytest.fixture(scope="session") -def llama_tokenizer(): +@pytest.fixture(scope="session", name="llama_tokenizer") +def llama_tokenizer_fixture(): return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B") @@ -109,6 +112,10 @@ def llama_tokenizer_w_chatml(llama_tokenizer): class TestMessagesCase: + """ + Test cases for the chat messages module + """ + def test_tool_call_stringify(self): assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str( chat_msgs_as_obj.conversation[2].content[1].value