Skip to content

Commit

Permalink
genai: omit empty text parts from session history (#226)
Browse files Browse the repository at this point in the history
When adding to session history, remove empty text parts.

This is the same change as
googleapis/google-cloud-go#10362.

It fixes TestLive/tools/direct.
  • Loading branch information
jba authored Nov 18, 2024
1 parent ae7597a commit 30b5de5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
14 changes: 13 additions & 1 deletion genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,20 @@ func (cs *ChatSession) addToHistory(cands []*Candidate) bool {
return false
}
c.Role = roleModel
cs.History = append(cs.History, c)
cs.History = append(cs.History, copySanitizedModelContent(c))
return true
}
return false
}

// copySanitizedModelContent creates a (shallow) copy of c with role set to
// model and empty text parts removed.
func copySanitizedModelContent(c *Content) *Content {
newc := &Content{Role: roleModel}
for _, part := range c.Parts {
if t, ok := part.(Text); !ok || len(string(t)) > 0 {
newc.Parts = append(newc.Parts, part)
}
}
return newc
}
7 changes: 6 additions & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"testing"
"time"

"github.com/googleapis/gax-go/v2/apierror"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
Expand Down Expand Up @@ -391,13 +392,17 @@ func TestLive(t *testing.T) {
if c := "Mountain View"; !strings.Contains(locArg, c) {
t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, locArg, c)
}
res, err = session.SendMessage(ctx, FunctionResponse{
res, err = session.SendMessage(ctx, Text("response:"), FunctionResponse{
Name: movieTool.FunctionDeclarations[0].Name,
Response: map[string]any{
"theater": "AMC16",
},
})
if err != nil {
if ae, ok := err.(*apierror.APIError); ok {
t.Fatal(ae.Unwrap())

}
t.Fatal(err)
}
checkMatch(t, responseString(res), "AMC")
Expand Down

0 comments on commit 30b5de5

Please sign in to comment.