From 30b5de5ce554c8b44df27e41190f42929997dc2c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 18 Nov 2024 15:01:57 -0500 Subject: [PATCH] genai: omit empty text parts from session history (#226) When adding to session history, remove empty text parts. This is the same change as https://github.com/googleapis/google-cloud-go/pull/10362. It fixes TestLive/tools/direct. --- genai/chat.go | 14 +++++++++++++- genai/client_test.go | 7 ++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/genai/chat.go b/genai/chat.go index 334df1d..5e44e42 100644 --- a/genai/chat.go +++ b/genai/chat.go @@ -70,8 +70,20 @@ func (cs *ChatSession) addToHistory(cands []*Candidate) bool { return false } c.Role = roleModel - cs.History = append(cs.History, c) + cs.History = append(cs.History, copySanitizedModelContent(c)) return true } return false } + +// copySanitizedModelContent creates a (shallow) copy of c with role set to +// model and empty text parts removed. +func copySanitizedModelContent(c *Content) *Content { + newc := &Content{Role: roleModel} + for _, part := range c.Parts { + if t, ok := part.(Text); !ok || len(string(t)) > 0 { + newc.Parts = append(newc.Parts, part) + } + } + return newc +} diff --git a/genai/client_test.go b/genai/client_test.go index 0717a03..777761a 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -30,6 +30,7 @@ import ( "testing" "time" + "github.com/googleapis/gax-go/v2/apierror" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -391,13 +392,17 @@ func TestLive(t *testing.T) { if c := "Mountain View"; !strings.Contains(locArg, c) { t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, locArg, c) } - res, err = session.SendMessage(ctx, FunctionResponse{ + res, err = session.SendMessage(ctx, Text("response:"), FunctionResponse{ Name: movieTool.FunctionDeclarations[0].Name, Response: map[string]any{ "theater": "AMC16", }, }) if err != nil { + if ae, ok := err.(*apierror.APIError); ok { + t.Fatal(ae.Unwrap()) + + } t.Fatal(err) } checkMatch(t, responseString(res), "AMC")