diff --git a/.changeset/dirty-wolves-sin.md b/.changeset/dirty-wolves-sin.md new file mode 100644 index 00000000..170f56df --- /dev/null +++ b/.changeset/dirty-wolves-sin.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": minor +--- + +Expand the model's `countTokens` method to alternatively accept a `GenerateContentRequest`. diff --git a/docs/reference/main/generative-ai.counttokensrequest.contents.md b/docs/reference/main/generative-ai.counttokensrequest.contents.md index 693fe3d5..948e6c79 100644 --- a/docs/reference/main/generative-ai.counttokensrequest.contents.md +++ b/docs/reference/main/generative-ai.counttokensrequest.contents.md @@ -7,5 +7,5 @@ **Signature:** ```typescript -contents: Content[]; +contents?: Content[]; ``` diff --git a/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md b/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md new file mode 100644 index 00000000..86bf60ad --- /dev/null +++ b/docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md @@ -0,0 +1,11 @@ + + +[Home](./index.md) > [@google/generative-ai](./generative-ai.md) > [CountTokensRequest](./generative-ai.counttokensrequest.md) > [generateContentRequest](./generative-ai.counttokensrequest.generatecontentrequest.md) + +## CountTokensRequest.generateContentRequest property + +**Signature:** + +```typescript +generateContentRequest?: GenerateContentRequest; +``` diff --git a/docs/reference/main/generative-ai.counttokensrequest.md b/docs/reference/main/generative-ai.counttokensrequest.md index 7895bfcc..5217e1b3 100644 --- a/docs/reference/main/generative-ai.counttokensrequest.md +++ b/docs/reference/main/generative-ai.counttokensrequest.md @@ -4,7 +4,9 @@ ## CountTokensRequest interface -Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md) +Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md). + +The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md), but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown. **Signature:** @@ -16,5 +18,6 @@ export interface CountTokensRequest | Property | Modifiers | Type | Description | | --- | --- | --- | --- | -| [contents](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)\[\] | | +| [contents?](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)\[\] | _(Optional)_ | +| [generateContentRequest?](./generative-ai.counttokensrequest.generatecontentrequest.md) | | [GenerateContentRequest](./generative-ai.generatecontentrequest.md) | _(Optional)_ | diff --git a/docs/reference/main/generative-ai.md b/docs/reference/main/generative-ai.md index ec5a3a7c..e2719d97 100644 --- a/docs/reference/main/generative-ai.md +++ b/docs/reference/main/generative-ai.md @@ -40,7 +40,7 @@ | [CitationSource](./generative-ai.citationsource.md) | A single citation source. | | [Content](./generative-ai.content.md) | Content type for both prompts and response candidates. | | [ContentEmbedding](./generative-ai.contentembedding.md) | A single content embedding. | -| [CountTokensRequest](./generative-ai.counttokensrequest.md) | Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md) | +| [CountTokensRequest](./generative-ai.counttokensrequest.md) |

Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md).

The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md), but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown.

| | [CountTokensResponse](./generative-ai.counttokensresponse.md) | Response from calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md). | | [EmbedContentRequest](./generative-ai.embedcontentrequest.md) | Params for calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md) | | [EmbedContentResponse](./generative-ai.embedcontentresponse.md) | Response from calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md). | diff --git a/packages/main/src/models/generative-model.test.ts b/packages/main/src/models/generative-model.test.ts index 5f4ce1c9..f61a1375 100644 --- a/packages/main/src/models/generative-model.test.ts +++ b/packages/main/src/models/generative-model.test.ts @@ -18,6 +18,7 @@ import { expect, use } from "chai"; import { GenerativeModel } from "./generative-model"; import * as sinonChai from "sinon-chai"; import { + CountTokensRequest, FunctionCallingMode, HarmBlockThreshold, HarmCategory, @@ -319,4 +320,34 @@ describe("GenerativeModel", () => { ); restore(); }); + it("countTokens errors if contents and generateContentRequest are both defined", async () => { + const genModel = new GenerativeModel( + "apiKey", + { + model: "my-model", + }, + { + apiVersion: "v2000", + }, + ); + const mockResponse = getMockResponse( + "unary-success-basic-reply-short.json", + ); + const makeRequestStub = stub(request, "makeRequest").resolves( + mockResponse as Response, + ); + const countTokensRequest: CountTokensRequest = { + contents: [{ role: "user", parts: [{ text: "hello" }] }], + generateContentRequest: { + contents: [{ role: "user", parts: [{ text: "hello" }] }], + }, + }; + await expect( + genModel.countTokens(countTokensRequest), + ).to.eventually.be.rejectedWith( + "CountTokensRequest must have one of contents or generateContentRequest, not both.", + ); + expect(makeRequestStub).to.not.be.called; + restore(); + }); }); diff --git a/packages/main/src/models/generative-model.ts b/packages/main/src/models/generative-model.ts index df1ecfa9..4bc2c63e 100644 --- a/packages/main/src/models/generative-model.ts +++ b/packages/main/src/models/generative-model.ts @@ -43,6 +43,7 @@ import { ChatSession } from "../methods/chat-session"; import { countTokens } from "../methods/count-tokens"; import { batchEmbedContents, embedContent } from "../methods/embed-content"; import { + formatCountTokensInput, formatEmbedContentInput, formatGenerateContentInput, formatSystemInstruction, @@ -157,7 +158,7 @@ export class GenerativeModel { async countTokens( request: CountTokensRequest | string | Array, ): Promise { - const formattedParams = formatGenerateContentInput(request); + const formattedParams = formatCountTokensInput(request, this.model); return countTokens( this.apiKey, this.model, diff --git a/packages/main/src/requests/request-helpers.ts b/packages/main/src/requests/request-helpers.ts index 9c6694a8..afadaefc 100644 --- a/packages/main/src/requests/request-helpers.ts +++ b/packages/main/src/requests/request-helpers.ts @@ -17,11 +17,16 @@ import { Content, + CountTokensRequest, + CountTokensRequestInternal, EmbedContentRequest, GenerateContentRequest, Part, } from "../../types"; -import { GoogleGenerativeAIError } from "../errors"; +import { + GoogleGenerativeAIError, + GoogleGenerativeAIRequestInputError, +} from "../errors"; export function formatSystemInstruction( input?: string | Part | Content, @@ -104,6 +109,31 @@ function assignRoleToPartsAndValidateSendMessageRequest( return functionContent; } +export function formatCountTokensInput( + params: CountTokensRequest | string | Array, + model: string, +): CountTokensRequestInternal { + let formattedRequest: CountTokensRequestInternal = {}; + const containsGenerateContentRequest = + (params as CountTokensRequest).generateContentRequest != null; + if ((params as CountTokensRequest).contents) { + if (containsGenerateContentRequest) { + throw new GoogleGenerativeAIRequestInputError( + "CountTokensRequest must have one of contents or generateContentRequest, not both.", + ); + } + formattedRequest = { ...(params as CountTokensRequest) }; + } else if (containsGenerateContentRequest) { + formattedRequest = { ...(params as CountTokensRequest) }; + formattedRequest.generateContentRequest.model = model; + } else { + // Array or string + const content = formatNewContent(params as string | Array); + formattedRequest.contents = [content]; + } + return formattedRequest; +} + export function formatGenerateContentInput( params: GenerateContentRequest | string | Array, ): GenerateContentRequest { diff --git a/packages/main/test-integration/node/count-tokens.test.ts b/packages/main/test-integration/node/count-tokens.test.ts index 7e8dda45..76288198 100644 --- a/packages/main/test-integration/node/count-tokens.test.ts +++ b/packages/main/test-integration/node/count-tokens.test.ts @@ -18,6 +18,7 @@ import { expect, use } from "chai"; import * as chaiAsPromised from "chai-as-promised"; import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "../.."; +import { CountTokensRequest } from "../../types"; use(chaiAsPromised); @@ -46,4 +47,23 @@ describe("countTokens", function () { expect(response1.totalTokens).to.equal(3); expect(response2.totalTokens).to.equal(3); }); + it("counts tokens with GenerateContentRequest", async () => { + const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || ""); + const model = genAI.getGenerativeModel({ + model: "gemini-1.5-flash-latest", + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH, + }, + ], + }); + const countTokensRequest: CountTokensRequest = { + generateContentRequest: { + contents: [{ role: "user", parts: [{ text: "count me" }] }], + }, + }; + const response = await model.countTokens(countTokensRequest); + expect(response.totalTokens).to.equal(3); + }); }); diff --git a/packages/main/types/requests.ts b/packages/main/types/requests.ts index 6787ea8b..9152f489 100644 --- a/packages/main/types/requests.ts +++ b/packages/main/types/requests.ts @@ -54,6 +54,14 @@ export interface GenerateContentRequest extends BaseParams { systemInstruction?: string | Part | Content; } +/** + * Request sent to `generateContent` endpoint. + * @internal + */ +export interface GenerateContentRequestInternal extends GenerateContentRequest { + model?: string; +} + /** * Safety setting that can be sent as part of request parameters. * @public @@ -95,11 +103,26 @@ export interface StartChatParams extends BaseParams { } /** - * Params for calling {@link GenerativeModel.countTokens} + * Params for calling {@link GenerativeModel.countTokens}. + * + * The request must contain either a {@link Content} array or a + * {@link GenerateContentRequest}, but not both. If both are provided + * then a {@link GoogleGenerativeAIRequestInputError} is thrown. + * * @public */ export interface CountTokensRequest { - contents: Content[]; + generateContentRequest?: GenerateContentRequest; + contents?: Content[]; +} + +/** + * Params for calling {@link GenerativeModel.countTokens} + * @internal + */ +export interface CountTokensRequestInternal { + generateContentRequest?: GenerateContentRequestInternal; + contents?: Content[]; } /**