diff --git a/chat_test.go b/chat_test.go index 294d2f4..95558f9 100644 --- a/chat_test.go +++ b/chat_test.go @@ -217,3 +217,43 @@ func TestAddToolContent(t *testing.T) { }) } } + +func TestAddToolContentError(t *testing.T) { + chat := &aichat.Chat{} + + // Create a struct that will fail JSON marshaling + badContent := make(chan int) + + err := chat.AddToolContent("test", "test-id", badContent) + if err == nil { + t.Error("Expected error when marshaling invalid content, got nil") + } +} + +func TestUnmarshalJSONError(t *testing.T) { + chat := &aichat.Chat{} + + // Invalid JSON that will cause an unmarshal error + invalidJSON := []byte(`{"messages": [{"role": "user", "content": invalid}]}`) + + err := chat.UnmarshalJSON(invalidJSON) + if err == nil { + t.Error("Expected error when unmarshaling invalid JSON, got nil") + } +} + +func TestContentPartsError(t *testing.T) { + msg := &aichat.Message{ + Role: "user", + // Content that will fail JSON marshaling + Content: []interface{}{make(chan int)}, + } + + parts, err := msg.ContentParts() + if err == nil { + t.Error("Expected error when processing invalid content parts, got nil") + } + if parts != nil { + t.Error("Expected nil parts when error occurs") + } +} diff --git a/storage.go b/storage.go index 3f7fe98..5586b7d 100644 --- a/storage.go +++ b/storage.go @@ -31,7 +31,7 @@ func (chat *Chat) Load(ctx context.Context, key string) error { defer reader.Close() if err := json.NewDecoder(reader).Decode(chat); err != nil { - return fmt.Errorf("failed to decode session: %v", err) + return fmt.Errorf("failed to decode chat data: %v", err) } return nil @@ -45,7 +45,7 @@ func (chat *Chat) Save(ctx context.Context, key string) error { data, err := json.Marshal(chat) if err != nil { - return fmt.Errorf("failed to marshal session: %v", err) + return fmt.Errorf("failed to marshal chat data: %v", err) } return chat.Options.S3.Put(ctx, key, bytes.NewReader(data)) diff --git a/storage_test.go b/storage_test.go index a07791e..0c8ec9e 100644 --- a/storage_test.go +++ b/storage_test.go @@ -3,6 +3,7 @@ package aichat_test import ( "context" "errors" + "fmt" "io" "strings" "testing" @@ -43,6 +44,23 @@ func (m *mockS3) Delete(ctx context.Context, key string) error { return nil } +// mockS3WithErrors is a mock S3 implementation that returns errors +type mockS3WithErrors struct { + mockS3 + shouldErrorOnGet bool + returnInvalidJSON bool +} + +func (m *mockS3WithErrors) Get(ctx context.Context, key string) (io.ReadCloser, error) { + if m.shouldErrorOnGet { + return nil, fmt.Errorf("mock get error") + } + if m.returnInvalidJSON { + return io.NopCloser(strings.NewReader("invalid json")), nil + } + return m.mockS3.Get(ctx, key) +} + func TestChatStorage(t *testing.T) { ctx := context.Background() s3 := newMockS3() @@ -98,6 +116,53 @@ func TestStorageErrors(t *testing.T) { if err := session.Load(ctx, "test-key"); err == nil { t.Error("Expected error when loading with nil S3") } + + t.Run("get error", func(t *testing.T) { + s3 := &mockS3WithErrors{shouldErrorOnGet: true} + chat := &aichat.Chat{Options: aichat.Options{S3: s3}} + + err := chat.Load(ctx, "test-key") + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to get session from storage") { + t.Errorf("unexpected error message: %v", err) + } + }) + + t.Run("decode error", func(t *testing.T) { + s3 := &mockS3WithErrors{returnInvalidJSON: true} + chat := &aichat.Chat{Options: aichat.Options{S3: s3}} + + err := chat.Load(ctx, "test-key") + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to decode chat data") { + t.Errorf("unexpected error message: %v", err) + } + }) + + t.Run("marshal error", func(t *testing.T) { + s3 := newMockS3() + chat := &aichat.Chat{ + Options: aichat.Options{S3: s3}, + Messages: []aichat.Message{ + { + Role: "user", + Content: make(chan int), // channels cannot be marshaled to JSON + }, + }, + } + + err := chat.Save(ctx, "test-key") + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to marshal chat data") { + t.Errorf("unexpected error message: %v", err) + } + }) } func TestNewStorage(t *testing.T) {