From 3b44d238dd63ec94076c322385c1b02b9c5c5054 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Jul 2024 20:31:20 -0400 Subject: [PATCH] [Go] plugins/googleai: use batch embed RPC Use the BatchEmbedContents RPC so a single call can handle multiple documents. --- go/plugins/googleai/googleai.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index aabcf9080..6bd4b3eae 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -163,19 +163,23 @@ func DefineEmbedder(name string) *ai.Embedder { // requires state.mu func defineEmbedder(name string) *ai.Embedder { return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) { - // TODO: use the batch embedding API. em := state.client.EmbeddingModel(name) - var res ai.EmbedResponse + // TODO: set em.TaskType from EmbedRequest.Options? + batch := em.NewBatch() for _, doc := range input.Documents { parts, err := convertParts(doc.Content) if err != nil { return nil, err } - eres, err := em.EmbedContent(ctx, parts...) - if err != nil { - return nil, err - } - res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: eres.Embedding.Values}) + batch.AddContent(parts...) + } + bres, err := em.BatchEmbedContents(ctx, batch) + if err != nil { + return nil, err + } + var res ai.EmbedResponse + for _, emb := range bres.Embeddings { + res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: emb.Values}) } return &res, nil })