Commit 1c2c07c

Accept tool use messages via public Chat APIs

Closes #20

ncoghlan committed Mar 3, 2025
1 parent 200fa67 · commit 1c2c07c

Showing 6 changed files with 147 additions and 79 deletions.
examples/tool-use-multiple.py (6 changes: 4 additions & 2 deletions)

@@ -18,9 +18,11 @@ def is_prime(n: int) -> bool:
             return False
     return True
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "Is the result of 12345 + 45668 a prime? Think step by step.",
     [add, is_prime],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
examples/tool-use.py (6 changes: 4 additions & 2 deletions)

@@ -7,9 +7,11 @@ def multiply(a: float, b: float) -> float:
     """Given two numbers a and b. Returns the product of them."""
     return a * b
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "What is the result of 12345 multiplied by 54321?",
     [multiply],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
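
Both example updates follow the same pattern: instead of printing each message as it arrives, the on_message callback now appends every assistant response and tool result message to a Chat instance, and the accumulated transcript is printed once the action round completes. A minimal sketch of the resulting pattern (the body of the add tool is an assumption here — the examples reference it but do not show it):

import lmstudio as lms

def add(a: int, b: int) -> int:
    """Given two numbers a and b. Returns the sum of them."""
    return a + b

chat = lms.Chat()
model = lms.llm("qwen2.5-7b-instruct-1m")
model.act(
    "What is the result of 12345 + 45668?",
    [add],
    on_message=chat.append,  # accumulate the transcript, including tool call messages
)
print(chat)  # the printed history now includes tool requests and their results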
src/lmstudio/history.py (118 changes: 73 additions & 45 deletions)

@@ -48,15 +48,15 @@
     ChatMessagePartFileDataDict as _FileHandleDict,
     ChatMessagePartTextData as TextData,
     ChatMessagePartTextDataDict as TextDataDict,
-    ChatMessagePartToolCallRequestData as _ToolCallRequestData,
-    ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict,
-    ChatMessagePartToolCallResultData as _ToolCallResultData,
-    ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict,
+    ChatMessagePartToolCallRequestData as ToolCallRequestData,
+    ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict,
+    ChatMessagePartToolCallResultData as ToolCallResultData,
+    ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict,
     # Private until LM Studio file handle support stabilizes
     # FileType,
     FilesRpcUploadFileBase64Parameter,
     # Private until user level tool call request management is defined
-    ToolCallRequest as _ToolCallRequest,
+    ToolCallRequest as ToolCallRequest,
+    FunctionToolCallRequestDict as ToolCallRequestDict,
 )
 
 __all__ = [
@@ -81,8 +81,8 @@
     "TextData",
     "TextDataDict",
     # Private until user level tool call request management is defined
-    "_ToolCallRequest", # Other modules need this to be exported
-    "_ToolCallResultData", # Other modules need this to be exported
+    "ToolCallRequest",
+    "ToolCallResultData",
     # "ToolCallRequest",
     # "ToolCallResult",
     "UserMessageContent",
@@ -109,11 +109,11 @@
 SystemPromptContentDict = TextDataDict
 UserMessageContent = TextData | _FileHandle
 UserMessageContentDict = TextDataDict | _FileHandleDict
-AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
-AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
-ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
+AssistantResponseContent = TextData | _FileHandle
+AssistantResponseContentDict = TextDataDict | _FileHandleDict
+ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
 ChatMessageContentDict = (
-    TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
+    TextDataDict | _FileHandleDict | ToolCallRequestData | ToolCallResultDataDict
 )
 
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
 AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
 AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
 AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
-_ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
+ToolCallRequestInput = (
+    ToolCallRequest
+    | ToolCallRequestDict
+    | ToolCallRequestData
+    | ToolCallRequestDataDict
+)
+ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
 ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
 ChatMessageMultiPartInput = UserMessageMultiPartInput
 AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         if role == "user":
             messages = cast(AnyUserMessageInput, content)
             return self.add_user_message(messages)
+        # Assistant responses consist of a text response with zero or more tool requests
+        if role == "assistant":
+            if _is_chat_message_input(content):
+                response = cast(AssistantResponseInput, content)
+                return self.add_assistant_response(response)
+            try:
+                (response_content, *tool_request_contents) = content
+            except ValueError:
+                raise LMStudioValueError(
+                    f"Unable to parse assistant response content: {content}"
+                ) from None
+            response = cast(AssistantResponseInput, response_content)
+            tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
+            return self.add_assistant_response(response, tool_requests)
+
         # Other roles do not accept multi-part messages, so ensure there
         # is exactly one content item given. We still accept iterables because
         # that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         except ValueError:
             err_msg = f"{role!r} role does not support multi-part message content."
             raise LMStudioValueError(err_msg) from None
-
         match role:
             case "system":
                 prompt = cast(SystemPromptInput, content_item)
                 result = self.add_system_prompt(prompt)
-            case "assistant":
-                response = cast(AssistantResponseInput, content_item)
-                result = self.add_assistant_response(response)
             case "tool":
-                tool_result = cast(_ToolCallResultInput, content_item)
-                result = self._add_tool_result(tool_result)
+                tool_result = cast(ToolCallResultInput, content_item)
+                result = self.add_tool_result(tool_result)
             case _:
                 raise LMStudioValueError(f"Unknown history role: {role}")
         return result
@@ -556,11 +573,14 @@ def add_user_message(
     @classmethod
     def _parse_assistant_response(
         cls, response: AnyAssistantResponseInput
-    ) -> AssistantResponseContent:
+    ) -> TextData | _FileHandle:
+        # Note: tool call requests are NOT accepted here, as they're expected
+        # to follow an initial text response
+        # It's not clear if file handles should be accepted as it's not obvious
+        # how client applications should process those (even though the API
+        # format nominally permits them here)
         match response:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case TextData() | _FileHandle() | _ToolCallRequestData():
+            case TextData() | _FileHandle():
                 return response
             case str():
                 return TextData(text=response)
@@ -575,59 +595,67 @@ def _parse_assistant_response(
             }:
                 # We accept snake_case here for consistency, but don't really expect it
                 return _FileHandle._from_any_dict(response)
-            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
-                # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallRequestData._from_any_dict(response)
             case _:
                 raise LMStudioValueError(
                     f"Unable to parse assistant response content: {response}"
                 )
 
+    @classmethod
+    def _parse_tool_call_request(
+        cls, request: ToolCallRequestInput
+    ) -> ToolCallRequestData:
+        match request:
+            case ToolCallRequestData():
+                return request
+            case ToolCallRequest():
+                return ToolCallRequestData(tool_call_request=request)
+            case {"type": "toolCallRequest"}:
+                return ToolCallRequestData._from_any_dict(request)
+            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
+                request_details = ToolCallRequest._from_any_dict(request)
+                return ToolCallRequestData(tool_call_request=request_details)
+            case _:
+                raise LMStudioValueError(
+                    f"Unable to parse tool call request content: {request}"
+                )
+
     @sdk_public_api()
     def add_assistant_response(
-        self, response: AnyAssistantResponseInput
+        self,
+        response: AnyAssistantResponseInput,
+        tool_call_requests: Iterable[ToolCallRequestInput] = (),
     ) -> AssistantResponse:
         """Add a new 'assistant' response to the chat history."""
         self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
-        message_data = self._parse_assistant_response(response)
-        message = AssistantResponse(content=[message_data])
-        self._messages.append(message)
-        return message
-
-    def _add_assistant_tool_requests(
-        self, response: _ServerAssistantResponse, requests: Iterable[_ToolCallRequest]
-    ) -> AssistantResponse:
-        self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
         message_text = self._parse_assistant_response(response)
         request_parts = [
-            _ToolCallRequestData(tool_call_request=req) for req in requests
+            self._parse_tool_call_request(req) for req in tool_call_requests
         ]
         message = AssistantResponse(content=[message_text, *request_parts])
         self._messages.append(message)
         return message
 
     @classmethod
-    def _parse_tool_result(cls, result: _ToolCallResultInput) -> _ToolCallResultData:
+    def _parse_tool_result(cls, result: ToolCallResultInput) -> ToolCallResultData:
         match result:
             # Sadly, we can't use the union type aliases for matching,
             # since the compiler needs visibility into every match target
-            case _ToolCallResultData():
+            case ToolCallResultData():
                 return result
             case {"toolCallId": _, "content": _} | {"tool_call_id": _, "content": _}:
                 # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallResultData.from_dict(result)
+                return ToolCallResultData.from_dict(result)
             case _:
                 raise LMStudioValueError(f"Unable to parse tool result: {result}")
 
-    def _add_tool_results(
-        self, results: Iterable[_ToolCallResultInput]
+    def add_tool_results(
+        self, results: Iterable[ToolCallResultInput]
     ) -> ToolResultMessage:
+        """Add multiple tool results to the chat history as a single message."""
         message_content = [self._parse_tool_result(result) for result in results]
         message = ToolResultMessage(content=message_content)
         self._messages.append(message)
         return message
 
-    def _add_tool_result(self, result: _ToolCallResultInput) -> ToolResultMessage:
+    def add_tool_result(self, result: ToolCallResultInput) -> ToolResultMessage:
         """Add a new tool result to the chat history."""
         # Consecutive tool result messages are allowed,
        # so skip checking if the last message was a tool result
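
The net effect of the history.py changes is that tool call requests and results can now be recorded through public Chat methods rather than the former underscore-prefixed helpers. A hedged sketch of the new surface, based only on the signatures visible above (the exact ToolCallRequest constructor fields are an assumption inferred from the request.name/request.id attribute access in json_api.py below; the arguments field name in particular is unverified):

from lmstudio import Chat
from lmstudio.history import ToolCallRequest

chat = Chat()
chat.add_user_message("What is 2 + 3?")

# An assistant response may now carry zero or more tool call requests
request = ToolCallRequest(id="call-1", name="add", arguments={"a": 2, "b": 3})  # field names assumed
chat.add_assistant_response("Let me add those numbers.", [request])

# Tool results are accepted as structs or as dicts (camelCase wire format or snake_case)
chat.add_tool_result({"toolCallId": "call-1", "content": "5"})

# add_entry() dispatches on role: for "assistant", an iterable is unpacked into
# the text response plus any trailing tool call requests; a plain string is
# treated as a text-only response
chat.add_entry("assistant", "The sum of 2 and 3 is 5.")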
src/lmstudio/json_api.py (37 changes: 24 additions & 13 deletions)

@@ -40,7 +40,7 @@
     sdk_public_type,
     _truncate_traceback,
 )
-from .history import AssistantResponse, Chat, _ToolCallRequest, _ToolCallResultData
+from .history import AssistantResponse, Chat, ToolCallRequest, ToolCallResultData
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
@@ -1067,7 +1067,7 @@ class PredictionFragmentEvent(ChannelRxEvent[LlmPredictionFragment]):
     pass
 
 
-class PredictionToolCallEvent(ChannelRxEvent[_ToolCallRequest]):
+class PredictionToolCallEvent(ChannelRxEvent[ToolCallRequest]):
     pass
 
 
@@ -1114,7 +1114,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, _ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1224,7 +1224,7 @@ def iter_message_events(
                     "toolCallRequest": tool_call_request,
                 }:
                     yield PredictionToolCallEvent(
-                        _ToolCallRequest._from_api_dict(tool_call_request)
+                        ToolCallRequest._from_api_dict(tool_call_request)
                     )
                 case {
                     "type": "toolCallGenerationFailed",
@@ -1267,10 +1267,17 @@ def handle_rx_event(self, event: PredictionRxEvent) -> None:
                 self._report_prompt_processing_progress(progress)
             case PredictionFragmentEvent(_fragment):
                 if self._on_first_token is not None:
-                    self._on_first_token()
+                    self._logger.debug("Invoking on_first_token callback")
+                    err_msg = f"First token callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_first_token()
                     self._on_first_token = None
                 if self._on_prediction_fragment is not None:
-                    self._on_prediction_fragment(_fragment)
+                    # TODO: Define an even-spammier-than-debug trace logging level for this
+                    # self._logger.trace("Invoking on_prediction_fragment callback")
+                    err_msg = f"Prediction fragment callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_prediction_fragment(_fragment)
                 pass
             case PredictionToolCallEvent(_tool_call_request):
                 # Handled externally when iterating over events
@@ -1294,32 +1301,34 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         assert self._on_prompt_processing_progress is not None
         err_msg = f"Prediction progress callback failed for {self!r}"
         with sdk_callback_invocation(err_msg, self._logger):
+            self._logger.debug("Invoking on_prompt_processing_progress callback")
             self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: _ToolCallRequest | None = None
+        self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
         exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
+            self._logger.debug("Invoking on_handle_invalid_tool_request callback")
             err_msg = _on_handle_invalid_tool_request(exc, request)
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
     def request_tool_call(
-        self, request: _ToolCallRequest
-    ) -> Callable[[], _ToolCallResultData]:
+        self, request: ToolCallRequest
+    ) -> Callable[[], ToolCallResultData]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
         if client_tool is None:
             err_msg = self._handle_invalid_tool_request(
                 f"Cannot find tool with name {tool_name}.", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         # Validate parameters against their specification
         params_struct, implementation = client_tool
@@ -1330,14 +1339,14 @@ def request_tool_call(
             err_msg = self._handle_invalid_tool_request(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         kwds = to_builtins(parsed_kwds)
 
         # Allow caller to schedule the tool call request for background execution
-        def _call_requested_tool() -> _ToolCallResultData:
+        def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
-            return _ToolCallResultData(
+            return ToolCallResultData(
                 content=json.dumps(call_result), tool_call_id=tool_call_id
             )
 
@@ -1980,6 +1989,8 @@ def __init__(self, model_identifier: str, session: TSession) -> None:
         """Initialize the LM Studio model reference."""
         self.identifier = model_identifier
         self._session = session
+        self._logger = logger = get_logger(type(self).__name__)
+        logger.update_context(model_identifier=model_identifier)
 
     def __repr__(self) -> str:
         return f"{type(self).__name__}(identifier={self.identifier!r})"
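
Beyond the renames, json_api.py now routes the first-token and fragment callbacks through sdk_callback_invocation, so an exception in user-supplied callback code is logged against the prediction rather than propagating, and debug logging is added around each callback invocation. The handle_invalid_tool_request hook likewise now receives the public ToolCallRequest type. A sketch of a conforming hook follows; the import path for LMStudioPredictionError is an assumption (the diff only shows the type used within this module), and per the diff's comments, returning a string overrides the error message while raising forces an exception:

from lmstudio.history import ToolCallRequest
from lmstudio.json_api import LMStudioPredictionError  # import path assumed

def on_invalid_tool_request(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str:
    # Returning a string substitutes the error message reported to the model;
    # raising an exception here aborts the prediction round instead.
    if request is not None:
        return f"Rejected call to {request.name!r}: {exc}"
    # With no associated request, the SDK raises regardless of the return value
    return str(exc)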