Skip to content

Commit

Permalink
feat: Add new chat methods for message management and filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
presbrey committed Feb 2, 2025
1 parent 8973daa commit 1827190
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 17 deletions.
26 changes: 14 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@ A Go package for managing AI chat sessions with support for message history, too
### Creating a New Chat

```go
import "github.com/presbrey/aichat"
// Create new chat in-memory
chat := new(aichat.Chat)

// Initialize with S3 storage
s3Storage := YourS3Implementation{} // Implements aichat.S3 interface
options := aichat.Options{S3: s3Storage}

// Create new chat
chat := aichat.NewChat("chat-123", options)

// Or use storage wrapper
storage := aichat.NewChatStorage(options)
chat, err := storage.Load("chat-123")
// Or use persistent/S3-compatible storage wrapper
opts := aichat.Options{...}
storage := aichat.NewChatStorage(opts)
chat, err := storage.Load("chat-f00ba0ba0")
```

### Convinence Methods and Direct Access

The `Chat` and `Message` structs are designed to be transparent - you are welcome to access their members directly in your applications. For example, you can directly access `chat.Messages`, `chat.Meta`, or `message.Role`.
The `Chat`, `Message`, and `ToolCall` structs are designed to be transparent - you are welcome to access their members directly in your applications. For example, you can directly access `chat.Messages`, `chat.Meta`, or `message.Role`.

For convenience, the package also provides several helper methods:

Expand All @@ -50,7 +45,14 @@ For convenience, the package also provides several helper methods:
- `AddAssistantToolCall(toolCalls)`: Add an assistant message with tool calls
- `LastMessage()`: Get the most recent message
- `LastMessageRole()`: Get the role of the most recent message
- `LastMessageByRole(role)`: Get the last message with a specific role
- `LastMessageByType(contentType)`: Get the last message with a specific content type
- `Range(fn)`: Iterate through messages with a callback function
- `RangeByRole(role, fn)`: Iterate through messages with a specific role
- `MessageCount()`: Get the total number of messages in the chat
- `MessageCountByRole(role)`: Get the count of messages with a specific role
- `ClearMessages()`: Remove all messages from the chat
- `RemoveLastMessage()`: Remove and return the last message from the chat

```go
// Example of helper method usage
Expand Down
58 changes: 58 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ func (chat *Chat) AddAssistantToolCall(toolCalls []ToolCall) {
chat.LastUpdated = time.Now()
}

// ClearMessages removes all messages from the chat
func (chat *Chat) ClearMessages() {
chat.Messages = []Message{}
chat.LastUpdated = time.Now()
}

// LastMessage returns the last message in the chat
func (chat *Chat) LastMessage() *Message {
if len(chat.Messages) == 0 {
Expand Down Expand Up @@ -118,6 +124,35 @@ func (chat *Chat) LastMessageRole() string {
return msg.Role
}

// LastMessageByType returns the last message in the chat with the given content type
func (chat *Chat) LastMessageByType(contentType string) *Message {
for i := len(chat.Messages) - 1; i >= 0; i-- {
msg := chat.Messages[i]
if content, ok := msg.Content.(map[string]interface{}); ok {
if t, ok := content["type"].(string); ok && t == contentType {
return &chat.Messages[i]
}
}
}
return nil
}

// MessageCount returns the total number of messages in the chat
func (chat *Chat) MessageCount() int {
return len(chat.Messages)
}

// MessageCountByRole returns the number of messages with a specific role
func (chat *Chat) MessageCountByRole(role string) int {
count := 0
for _, msg := range chat.Messages {
if msg.Role == role {
count++
}
}
return count
}

// Range iterates through messages
func (chat *Chat) Range(fn func(msg Message) error) error {
for _, msg := range chat.Messages {
Expand All @@ -128,6 +163,29 @@ func (chat *Chat) Range(fn func(msg Message) error) error {
return nil
}

// RangeByRole iterates through messages with a specific role
func (chat *Chat) RangeByRole(role string, fn func(msg Message) error) error {
for _, msg := range chat.Messages {
if msg.Role == role {
if err := fn(msg); err != nil {
return err
}
}
}
return nil
}

// RemoveLastMessage removes and returns the last message from the chat
func (chat *Chat) RemoveLastMessage() *Message {
if len(chat.Messages) == 0 {
return nil
}
lastMsg := chat.Messages[len(chat.Messages)-1]
chat.Messages = chat.Messages[:len(chat.Messages)-1]
chat.LastUpdated = time.Now()
return &lastMsg
}

// MarshalJSON implements custom JSON marshaling for the chat
func (chat *Chat) MarshalJSON() ([]byte, error) {
type Alias Chat
Expand Down
160 changes: 155 additions & 5 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package aichat_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"testing"

"github.com/presbrey/aichat"
Expand Down Expand Up @@ -220,10 +223,10 @@ func TestAddToolContent(t *testing.T) {

func TestAddToolContentError(t *testing.T) {
chat := &aichat.Chat{}

// Create a struct that will fail JSON marshaling
badContent := make(chan int)

err := chat.AddToolContent("test", "test-id", badContent)
if err == nil {
t.Error("Expected error when marshaling invalid content, got nil")
Expand All @@ -232,10 +235,10 @@ func TestAddToolContentError(t *testing.T) {

func TestUnmarshalJSONError(t *testing.T) {
chat := &aichat.Chat{}

// Invalid JSON that will cause an unmarshal error
invalidJSON := []byte(`{"messages": [{"role": "user", "content": invalid}]}`)

err := chat.UnmarshalJSON(invalidJSON)
if err == nil {
t.Error("Expected error when unmarshaling invalid JSON, got nil")
Expand All @@ -248,7 +251,7 @@ func TestContentPartsError(t *testing.T) {
// Content that will fail JSON marshaling
Content: []interface{}{make(chan int)},
}

parts, err := msg.ContentParts()
if err == nil {
t.Error("Expected error when processing invalid content parts, got nil")
Expand Down Expand Up @@ -296,3 +299,150 @@ func TestLastMessageByRole(t *testing.T) {
t.Error("Expected nil for non-existent role")
}
}

func TestLastMessageByType(t *testing.T) {
chat := new(aichat.Chat)

// Add messages with different content types
chat.AddRoleContent("user", map[string]interface{}{
"type": "text",
"text": "Hello",
})
chat.AddRoleContent("assistant", map[string]interface{}{
"type": "image",
"url": "test.jpg",
})
chat.AddRoleContent("user", map[string]interface{}{
"type": "text",
"text": "World",
})

// Test finding last message of each type
textMsg := chat.LastMessageByType("text")
if textMsg == nil || textMsg.Content.(map[string]interface{})["text"] != "World" {
t.Error("Expected last text message to be 'World'")
}

imageMsg := chat.LastMessageByType("image")
if imageMsg == nil || imageMsg.Content.(map[string]interface{})["url"] != "test.jpg" {
t.Error("Expected last image message to have URL 'test.jpg'")
}

// Test non-existent type
audioMsg := chat.LastMessageByType("audio")
if audioMsg != nil {
t.Error("Expected no message for non-existent type")
}
}

func TestMessageCount(t *testing.T) {
chat := new(aichat.Chat)

if chat.MessageCount() != 0 {
t.Error("Expected empty chat to have 0 messages")
}

chat.AddUserContent("Hello")
chat.AddAssistantContent("Hi")
chat.AddUserContent("How are you?")

if chat.MessageCount() != 3 {
t.Errorf("Expected 3 messages, got %d", chat.MessageCount())
}

chat.ClearMessages()

if chat.MessageCount() != 0 {
t.Error("Expected empty chat to have 0 messages")
}
}

func TestMessageCountByRole(t *testing.T) {
chat := new(aichat.Chat)

chat.AddUserContent("Hello")
chat.AddAssistantContent("Hi")
chat.AddUserContent("How are you?")
chat.AddToolRawContent("test-tool", "123", "result")

tests := []struct {
role string
expected int
}{
{"user", 2},
{"assistant", 1},
{"tool", 1},
{"system", 0},
}

for _, test := range tests {
count := chat.MessageCountByRole(test.role)
if count != test.expected {
t.Errorf("Expected %d messages for role '%s', got %d", test.expected, test.role, count)
}
}
}

func TestRangeByRole(t *testing.T) {
chat := new(aichat.Chat)

// Add test messages
chat.AddUserContent("U1")
chat.AddAssistantContent("A1")
chat.AddUserContent("U2")
chat.AddAssistantContent("A2")

// Test ranging over user messages
userMsgs := []string{}
err := chat.RangeByRole("user", func(msg aichat.Message) error {
content, ok := msg.Content.(string)
if !ok {
return fmt.Errorf("expected string content")
}
userMsgs = append(userMsgs, content)
return nil
})

if err != nil {
t.Error("Unexpected error:", err)
}
if !reflect.DeepEqual(userMsgs, []string{"U1", "U2"}) {
t.Errorf("Expected user messages [U1, U2], got %v", userMsgs)
}

// Test ranging with error
expectedErr := errors.New("test error")
err = chat.RangeByRole("assistant", func(msg aichat.Message) error {
return expectedErr
})
if err != expectedErr {
t.Error("Expected error to be propagated")
}
}

func TestRemoveLastMessage(t *testing.T) {
chat := new(aichat.Chat)

// Test removing from empty chat
if msg := chat.RemoveLastMessage(); msg != nil {
t.Error("Expected nil when removing from empty chat")
}

// Add and remove messages
chat.AddUserContent("First")
chat.AddAssistantContent("Second")
chat.AddUserContent("Third")

initialCount := chat.MessageCount()
lastMsg := chat.RemoveLastMessage()

if lastMsg == nil || lastMsg.Content != "Third" {
t.Error("Expected last message content to be 'Third'")
}
if chat.MessageCount() != initialCount-1 {
t.Error("Expected message count to decrease by 1")
}
if last := chat.LastMessage(); last == nil || last.Content != "Second" {
t.Error("Expected new last message to be 'Second'")
}
}

0 comments on commit 1827190

Please sign in to comment.