Skip to content

Commit

Permalink
feat: add LastMessage, Range, and improve tool call handling methods
Browse files Browse the repository at this point in the history
  • Loading branch information
presbrey committed Jan 29, 2025
1 parent 99548f0 commit 8e2c68e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 47 deletions.
69 changes: 52 additions & 17 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ type Chat struct {
Created time.Time `json:"created"`
// LastUpdated is the timestamp of the most recent message or modification
LastUpdated time.Time `json:"last_updated"`
// Metadata stores arbitrary session-related data
Metadata map[string]any `json:"metadata,omitempty"`
// Meta stores arbitrary session-related data
Meta map[string]any `json:"meta,omitempty"`
// Options contains the configuration for these chat sessions
Options Options `json:"-"`
}

// AddRoleContent adds a role and content to the session
// AddRoleContent adds a role and content to the c
func (chat *Chat) AddRoleContent(role string, content any) {
chat.Messages = append(chat.Messages, Message{
Role: role,
Expand All @@ -39,16 +39,44 @@ func (chat *Chat) AddRoleContent(role string, content any) {
chat.LastUpdated = time.Now()
}

// AddUserMessage adds a user message to the session
// AddUserMessage adds a user message to the chat
func (chat *Chat) AddUserMessage(content any) {
chat.AddRoleContent("user", content)
}

// AddAssistantMessage adds an assistant message to the session
// AddAssistantMessage adds an assistant message to the chat
func (chat *Chat) AddAssistantMessage(content any) {
chat.AddRoleContent("assistant", content)
}

// AddToolRawContent adds a raw content to the chat
func (chat *Chat) AddToolRawContent(name string, toolCallID string, content any) {
chat.Messages = append(chat.Messages, Message{
Role: "tool",
Name: name,
ToolCallID: toolCallID,
Content: content,
})
chat.LastUpdated = time.Now()
}

// AddToolContent adds a tool content to the chat
func (chat *Chat) AddToolContent(name string, toolCallID string, content any) error {
switch contentT := content.(type) {
case []byte:
content = string(contentT)
case string:
default:
b, err := json.Marshal(contentT)
if err != nil {
return err
}
content = string(b)
}
chat.AddToolRawContent(name, toolCallID, content)
return nil
}

// AddAssistantToolCall adds an assistant message with tool calls
func (chat *Chat) AddAssistantToolCall(toolCalls []ToolCall) {
chat.Messages = append(chat.Messages, Message{
Expand All @@ -59,18 +87,25 @@ func (chat *Chat) AddAssistantToolCall(toolCalls []ToolCall) {
chat.LastUpdated = time.Now()
}

// AddToolResponse adds a tool response message
func (chat *Chat) AddToolResponse(name, toolCallID, content string) {
chat.Messages = append(chat.Messages, Message{
Role: "tool",
Name: name,
ToolCallID: toolCallID,
Content: content,
})
chat.LastUpdated = time.Now()
// LastMessage returns the last message in the chat
func (chat *Chat) LastMessage() *Message {
if len(chat.Messages) == 0 {
return nil
}
return &chat.Messages[len(chat.Messages)-1]
}

// Range iterates through messages
func (chat *Chat) Range(fn func(msg Message) error) error {
for _, msg := range chat.Messages {
if err := fn(msg); err != nil {
return err
}
}
return nil
}

// MarshalJSON implements custom JSON marshaling for the session
// MarshalJSON implements custom JSON marshaling for the chat
func (chat *Chat) MarshalJSON() ([]byte, error) {
type Alias Chat
return json.Marshal(&struct {
Expand All @@ -80,7 +115,7 @@ func (chat *Chat) MarshalJSON() ([]byte, error) {
})
}

// UnmarshalJSON implements custom JSON unmarshaling for the session
// UnmarshalJSON implements custom JSON unmarshaling for the chat
func (chat *Chat) UnmarshalJSON(data []byte) error {
type Alias Chat
aux := &struct {
Expand All @@ -89,7 +124,7 @@ func (chat *Chat) UnmarshalJSON(data []byte) error {
Alias: (*Alias)(chat),
}
if err := json.Unmarshal(data, &aux); err != nil {
return fmt.Errorf("failed to unmarshal session: %v", err)
return fmt.Errorf("failed to unmarshal chat: %v", err)
}
return nil
}
2 changes: 1 addition & 1 deletion chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestChat(t *testing.T) {
session.AddAssistantToolCall(toolCalls)

// Test adding tool response
session.AddToolResponse(
session.AddToolContent(
"get_current_weather",
"call_9pw1qnYScqvGrCH58HWCvFH6",
`{"temperature": "22", "unit": "celsius", "description": "Sunny"}`,
Expand Down
8 changes: 4 additions & 4 deletions storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func TestChatStorage(t *testing.T) {

// Add some test data
session.AddUserMessage("Test message")
session.Metadata = make(map[string]any)
session.Metadata["test"] = "value"
session.Meta = make(map[string]any)
session.Meta["test"] = "value"

// Test saving
err := session.Save(ctx, "test-key")
Expand All @@ -75,8 +75,8 @@ func TestChatStorage(t *testing.T) {
t.Errorf("Expected %d messages, got %d", len(session.Messages), len(loadedSession.Messages))
}

if loadedSession.Metadata["test"] != "value" {
t.Errorf("Expected metadata value 'value', got %v", loadedSession.Metadata["test"])
if loadedSession.Meta["test"] != "value" {
t.Errorf("Expected metadata value 'value', got %v", loadedSession.Meta["test"])
}

err = loadedSession.Delete(ctx, "test-key")
Expand Down
46 changes: 24 additions & 22 deletions tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (f *Function) ArgumentsMap() (map[string]interface{}, error) {
// RangePendingToolCalls iterates through messages to find and process tool calls that haven't received a response.
// It performs two passes: first to identify which tool calls have responses, then to process pending calls.
// The provided function is called for each pending tool call.
func (chat *Chat) RangePendingToolCalls(fn func(toolCall *ToolCallMessage) error) error {
func (chat *Chat) RangePendingToolCalls(fn func(toolCallContext *ToolCallContext) error) error {
// Create a map to track which tool calls have responses
responded := make(map[string]bool)

Expand All @@ -45,47 +45,49 @@ func (chat *Chat) RangePendingToolCalls(fn func(toolCall *ToolCallMessage) error
}
}

// Second pass: call the tool
for _, msg := range chat.Messages {
return chat.Range(func(msg Message) error {
if len(msg.ToolCalls) == 0 {
return nil
}
for _, call := range msg.ToolCalls {
if !responded[call.ID] {
if err := fn(&ToolCallMessage{
Chat: chat,
ToolCall: &call,
}); err != nil {
return err
}
responded[call.ID] = true
if responded[call.ID] {
continue
}
if err := fn(&ToolCallContext{
Chat: chat,
ToolCall: &call,
}); err != nil {
return err
}
responded[call.ID] = true
}
}

return nil
return nil
})
}

// ToolCallMessage represents a tool call within a chat context, managing the lifecycle
// ToolCallContext represents a tool call within a chat context, managing the lifecycle
// of a single tool invocation including its execution and response handling.
type ToolCallMessage struct {
type ToolCallContext struct {
ToolCall *ToolCall
Chat *Chat
}

// Name returns the name of the function
func (tcs *ToolCallMessage) Name() string {
return tcs.ToolCall.Function.Name
func (tcc *ToolCallContext) Name() string {
return tcc.ToolCall.Function.Name
}

// Arguments returns the arguments to the function as a map
func (tcs *ToolCallMessage) Arguments() (map[string]any, error) {
return tcs.ToolCall.Function.ArgumentsMap()
func (tcc *ToolCallContext) Arguments() (map[string]any, error) {
return tcc.ToolCall.Function.ArgumentsMap()
}

// Return sends the result of the function call back to the chat
func (tcs *ToolCallMessage) Return(result map[string]any) error {
func (tcc *ToolCallContext) Return(result map[string]any) error {
jsonData, err := json.Marshal(result)
if err != nil {
return fmt.Errorf("failed to marshal result: %v", err)
}
tcs.Chat.AddToolResponse(tcs.Name(), tcs.ToolCall.ID, string(jsonData))
tcc.Chat.AddToolContent(tcc.Name(), tcc.ToolCall.ID, string(jsonData))
return nil
}
6 changes: 3 additions & 3 deletions tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestRangePendingToolCalls(t *testing.T) {
chat := &aichat.Chat{Messages: tt.messages}
var gotIDs []string

err := chat.RangePendingToolCalls(func(toolCall *aichat.ToolCallMessage) error {
err := chat.RangePendingToolCalls(func(toolCall *aichat.ToolCallContext) error {
gotIDs = append(gotIDs, toolCall.ToolCall.ID)
return nil
})
Expand All @@ -206,7 +206,7 @@ func TestRangePendingToolCalls(t *testing.T) {
}

expectedErr := "test error"
err := chat.RangePendingToolCalls(func(toolCall *aichat.ToolCallMessage) error {
err := chat.RangePendingToolCalls(func(toolCall *aichat.ToolCallContext) error {
return errors.New(expectedErr)
})

Expand Down Expand Up @@ -269,7 +269,7 @@ func TestToolCallMessage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chat := &aichat.Chat{}
tcm := &aichat.ToolCallMessage{
tcm := &aichat.ToolCallContext{
ToolCall: tt.toolCall,
Chat: chat,
}
Expand Down

0 comments on commit 8e2c68e

Please sign in to comment.