
Commit

chore: lint
winglian committed Oct 3, 2024
1 parent 69c1943 commit 7b7fc55
Showing 11 changed files with 239 additions and 115 deletions.
Empty file.
43 changes: 10 additions & 33 deletions src/axolotl/core/chat/format/chatml.py
@@ -1,9 +1,16 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional

from ..messages import MessageContents, Messages
from .shared import wrap_tools


def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message

@@ -21,37 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
)
message.content.append(MessageContents(type="text", value="\n", weight=0))

# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i, MessageContents(type="text", value="<tool_call>\n", weight=message.weight)
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i, MessageContents(type="text", value="<tool_response>\n", weight=message.weight)
)
message = wrap_tools(message)

message.is_chat_formatted = True
return message
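
For orientation, a sketch of the flattened text this wrapping is meant to produce around a single tool call; the JSON payload is a made-up example, but the surrounding tags and trailing newline follow the insertions performed by wrap_tools (added in shared.py below).

# Illustrative only: expected text around one tool_call entry after wrapping.
# The payload is hypothetical; the tags and trailing newline mirror wrap_tools.
wrapped = (
    "<tool_call>\n"
    '{"name": "get_weather", "arguments": {"unit": "celsius"}}\n'  # has_newline=True
    "</tool_call>\n"
)
print(wrapped)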
37 changes: 5 additions & 32 deletions src/axolotl/core/chat/format/llama3x.py
@@ -1,6 +1,10 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional

from ..messages import MessageContents, Messages
from .shared import wrap_tools


def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
@@ -24,38 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)

# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i, MessageContents(type="text", value="<tool_call>\n", weight=message.weight)
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i, MessageContents(type="text", value="<tool_response>\n", weight=message.weight)
)
message = wrap_tools(message)

if message_index == 0:
message.content.insert(
47 changes: 47 additions & 0 deletions src/axolotl/core/chat/format/shared.py
@@ -0,0 +1,47 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages


def wrap_tools(message: Messages) -> Messages:
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_response> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_response> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)

return message
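
As a standalone illustration of why the loop above walks the list backwards: inserting while iterating forward would shift the indices of entries not yet visited. This minimal sketch uses plain strings in place of MessageContents objects.

# Standalone illustration of the reverse-iteration insert pattern used by wrap_tools.
content = ["text", "tool_call", "text", "tool_call"]

for i in range(len(content) - 1, -1, -1):
    if content[i] == "tool_call":
        # Insert the closing tag first, then the opening tag; because we walk
        # backwards, indices of earlier (unvisited) items are unaffected.
        content.insert(i + 1, "</tool_call>\n")
        content.insert(i, "<tool_call>\n")

print(content)
# ['text', '<tool_call>\n', 'tool_call', '</tool_call>\n',
#  'text', '<tool_call>\n', 'tool_call', '</tool_call>\n']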
70 changes: 60 additions & 10 deletions src/axolotl/core/chat/messages.py
@@ -1,25 +1,35 @@
"""
internal message representations of chat messages
"""
import json
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Union

from transformers import PreTrainedTokenizer


class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""

system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = (
"ipython" # pylint: disable=invalid-name # for responses from builtin tools
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)


class MessageContentTypes(str, Enum):
special_token = "special_token" # pylint: disable=invalid-name
"""
Message content types for text, image, audio, tool calls, and tool responses
"""

special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
@@ -29,28 +39,44 @@ class MessageContentTypes(str, Enum):

@dataclass
class SpecialToken(str, Enum):
bos_token = "bos_token" # pylint: disable=invalid-name
eos_token = "eos_token" # pylint: disable=invalid-name
"""
Special tokens for beginning of string and end of string
"""

bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105


@dataclass
class ToolCallFunction:
"""
Tool call function with name and arguments
"""

name: str
arguments: dict[str, str]


@dataclass
class Tool:
"""
Tool with description, function, and parameters
"""

description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties


@dataclass
class ToolCallContents:
"""
Tool call contents with name, arguments, and optional id
"""

name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None
id: Optional[str] = None # pylint: disable=invalid-name

def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
@@ -61,9 +87,13 @@ def __str__(self) -> str:

@dataclass
class ToolResponseContents:
"""
Tool response contents with name, content, and optional id
"""

name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None
id: Optional[str] = None # pylint: disable=invalid-name

def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
@@ -74,6 +104,10 @@ def __str__(self) -> str:

@dataclass
class MessageContents:
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""

type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
@@ -90,6 +124,10 @@ def __str__(self) -> str:

@dataclass
class Messages:
"""
Messages with role, content, metadata, weight, and chat formatting
"""

role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
@@ -148,6 +186,10 @@ def tokenized(

@dataclass
class Chats:
"""
top level data structure for chat conversations
"""

conversation: List[Messages]

def __str__(self) -> str:
@@ -173,7 +215,11 @@ def tokenized(

@dataclass
class ChatFormattedChats(Chats):
formatter: Callable # [[Union[dict, Chats]], Chats]
"""
Chat formatted chats with formatter and optional train on inputs
"""

formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False

def __post_init__(self):
@@ -185,6 +231,10 @@ def __post_init__(self):

@dataclass
class PreferenceChats:
"""
representation for preference data for chat
"""

prompt: List[Messages]
chosen: Messages
rejected: Messages
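
A rough usage sketch of the dataclasses above, assuming the axolotl package as shown here is importable and that the field defaults hidden by the collapsed hunks behave as expected; the tool-call values are made up.

from axolotl.core.chat.messages import MessageContents, Messages, ToolCallContents

# Hypothetical assistant turn carrying one tool call; keyword names follow the
# dataclass fields visible in this diff (weight defaults are assumed).
tool_call = ToolCallContents(name="get_weather", arguments={"city": "Paris"})
message = Messages(
    role="assistant",
    content=[MessageContents(type="tool_call", value=tool_call, weight=1)],
)
print(message.role, message.content[0].type)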
14 changes: 12 additions & 2 deletions src/axolotl/core/datasets/chat.py
@@ -1,14 +1,21 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union

from dacite import from_dict
from datasets import Dataset
from transformers import PreTrainedTokenizer

from axolotl.core.chat.messages import ChatFormattedChats, Chats
from axolotl.core.chat.messages import ChatFormattedChats


class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""

def __init__(
self,
data: Dataset,
@@ -35,7 +42,10 @@ def map_fn(ex):
)
return ex.tokenized(model_transform)

num_proc = min(64, process_count if process_count else os.cpu_count())
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
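
The worker-count selection in the hunk above can be checked in isolation; a small sketch with no axolotl dependency (the extra `or 1` guard is added here because os.cpu_count() may return None).

import os

def pick_num_proc(process_count=None, cap=64):
    # Roughly mirrors TokenizedChatDataset: prefer an explicit process_count,
    # otherwise fall back to the machine's CPU count, capped at 64 workers.
    return min(cap, process_count or os.cpu_count() or 1)

print(pick_num_proc())   # CPU count of this machine, capped at 64
print(pick_num_proc(8))  # 8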