diff --git a/.changeset/dirty-wolves-sin.md b/.changeset/dirty-wolves-sin.md new file mode 100644 index 00000000..170f56df --- /dev/null +++ b/.changeset/dirty-wolves-sin.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": minor +--- + +Expand the model's `countTokens` method to alternatively accept a `GenerateContentRequest`. diff --git a/docs/reference/main/generative-ai.counttokensrequest.contents.md b/docs/reference/main/generative-ai.counttokensrequest.contents.md index 693fe3d5..948e6c79 100644 --- a/docs/reference/main/generative-ai.counttokensrequest.contents.md +++ b/docs/reference/main/generative-ai.counttokensrequest.contents.md @@ -7,5 +7,5 @@ **Signature:** ```typescript -contents: Content[]; +contents?: Content[]; ``` diff --git a/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md b/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md new file mode 100644 index 00000000..86bf60ad --- /dev/null +++ b/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md @@ -0,0 +1,11 @@ + + +[Home](./index.md) > [@google/generative-ai](./generative-ai.md) > [CountTokensRequest](./generative-ai.counttokensrequest.md) > [generateContentRequest](./generative-ai.counttokensrequest.generatecontentrequest.md) + +## CountTokensRequest.generateContentRequest property + +**Signature:** + +```typescript +generateContentRequest?: GenerateContentRequest; +``` diff --git a/docs/reference/main/generative-ai.counttokensrequest.md b/docs/reference/main/generative-ai.counttokensrequest.md index 7895bfcc..5217e1b3 100644 --- a/docs/reference/main/generative-ai.counttokensrequest.md +++ b/docs/reference/main/generative-ai.counttokensrequest.md @@ -4,7 +4,9 @@ ## CountTokensRequest interface -Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md) +Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md). + +The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md), but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown. **Signature:** @@ -16,5 +18,6 @@ export interface CountTokensRequest | Property | Modifiers | Type | Description | | --- | --- | --- | --- | -| [contents](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)\[\] | | +| [contents?](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)\[\] | _(Optional)_ | +| [generateContentRequest?](./generative-ai.counttokensrequest.generatecontentrequest.md) | | [GenerateContentRequest](./generative-ai.generatecontentrequest.md) | _(Optional)_ | diff --git a/docs/reference/main/generative-ai.md b/docs/reference/main/generative-ai.md index ec5a3a7c..e2719d97 100644 --- a/docs/reference/main/generative-ai.md +++ b/docs/reference/main/generative-ai.md @@ -40,7 +40,7 @@ | [CitationSource](./generative-ai.citationsource.md) | A single citation source. | | [Content](./generative-ai.content.md) | Content type for both prompts and response candidates. | | [ContentEmbedding](./generative-ai.contentembedding.md) | A single content embedding. | -| [CountTokensRequest](./generative-ai.counttokensrequest.md) | Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md) | +| [CountTokensRequest](./generative-ai.counttokensrequest.md) |
Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md).
The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md), but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown.
| | [CountTokensResponse](./generative-ai.counttokensresponse.md) | Response from calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md). | | [EmbedContentRequest](./generative-ai.embedcontentrequest.md) | Params for calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md) | | [EmbedContentResponse](./generative-ai.embedcontentresponse.md) | Response from calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md). | diff --git a/packages/main/src/models/generative-model.test.ts b/packages/main/src/models/generative-model.test.ts index 5f4ce1c9..f61a1375 100644 --- a/packages/main/src/models/generative-model.test.ts +++ b/packages/main/src/models/generative-model.test.ts @@ -18,6 +18,7 @@ import { expect, use } from "chai"; import { GenerativeModel } from "./generative-model"; import * as sinonChai from "sinon-chai"; import { + CountTokensRequest, FunctionCallingMode, HarmBlockThreshold, HarmCategory, @@ -319,4 +320,34 @@ describe("GenerativeModel", () => { ); restore(); }); + it("countTokens errors if contents and generateContentRequest are both defined", async () => { + const genModel = new GenerativeModel( + "apiKey", + { + model: "my-model", + }, + { + apiVersion: "v2000", + }, + ); + const mockResponse = getMockResponse( + "unary-success-basic-reply-short.json", + ); + const makeRequestStub = stub(request, "makeRequest").resolves( + mockResponse as Response, + ); + const countTokensRequest: CountTokensRequest = { + contents: [{ role: "user", parts: [{ text: "hello" }] }], + generateContentRequest: { + contents: [{ role: "user", parts: [{ text: "hello" }] }], + }, + }; + await expect( + genModel.countTokens(countTokensRequest), + ).to.eventually.be.rejectedWith( + "CountTokensRequest must have one of contents or generateContentRequest, not both.", + ); + expect(makeRequestStub).to.not.be.called; + restore(); + }); }); diff --git a/packages/main/src/models/generative-model.ts b/packages/main/src/models/generative-model.ts index df1ecfa9..4bc2c63e 100644 --- a/packages/main/src/models/generative-model.ts +++ b/packages/main/src/models/generative-model.ts @@ -43,6 +43,7 @@ import { ChatSession } from "../methods/chat-session"; import { countTokens } from "../methods/count-tokens"; import { batchEmbedContents, embedContent } from "../methods/embed-content"; import { + formatCountTokensInput, formatEmbedContentInput, formatGenerateContentInput, formatSystemInstruction, @@ -157,7 +158,7 @@ export class GenerativeModel { async countTokens( request: CountTokensRequest | string | Array