diff --git a/.changeset/rare-birds-bow.md b/.changeset/rare-birds-bow.md
new file mode 100644
index 00000000..3233c614
--- /dev/null
+++ b/.changeset/rare-birds-bow.md
@@ -0,0 +1,5 @@
+---
+"@google/generative-ai": patch
+---
+
+Fix countTokens to include any params set on the model instance.
diff --git a/package.json b/package.json
index f6e6a7a1..472f5bc6 100644
--- a/package.json
+++ b/package.json
@@ -31,6 +31,7 @@
     "@web/test-runner": "^0.18.0",
     "chai": "^4.3.10",
     "chai-as-promised": "^7.1.1",
+    "chai-deep-equal-ignore-undefined": "^1.1.1",
     "eslint": "^8.52.0",
     "eslint-plugin-import": "^2.29.0",
     "eslint-plugin-unused-imports": "^3.0.0",
diff --git a/packages/main/src/models/generative-model.test.ts b/packages/main/src/models/generative-model.test.ts
index 61cb630e..0245e01b 100644
--- a/packages/main/src/models/generative-model.test.ts
+++ b/packages/main/src/models/generative-model.test.ts
@@ -240,6 +240,7 @@ describe("GenerativeModel", () => {
       "apiKey",
       {
         model: "my-model",
+        systemInstruction: "you are a cat",
       },
       {
         apiVersion: "v2000",
@@ -257,7 +258,9 @@ describe("GenerativeModel", () => {
       request.Task.COUNT_TOKENS,
       match.any,
       false,
-      match.any,
+      match((value: string) => {
+        return value.includes("hello") && value.includes("you are a cat");
+      }),
       match((value) => {
         return value.apiVersion === "v2000";
       }),
diff --git a/packages/main/src/models/generative-model.ts b/packages/main/src/models/generative-model.ts
index 029ac95a..bbc17601 100644
--- a/packages/main/src/models/generative-model.ts
+++ b/packages/main/src/models/generative-model.ts
@@ -164,7 +164,15 @@ export class GenerativeModel {
   async countTokens(
     request: CountTokensRequest | string | Array<string | Part>,
   ): Promise<CountTokensResponse> {
-    const formattedParams = formatCountTokensInput(request, this.model);
+    const formattedParams = formatCountTokensInput(request, {
+      model: this.model,
+      generationConfig: this.generationConfig,
+      safetySettings: this.safetySettings,
+      tools: this.tools,
+      toolConfig: this.toolConfig,
+      systemInstruction: this.systemInstruction,
+      cachedContent: this.cachedContent,
+    });
     return countTokens(
       this.apiKey,
       this.model,
diff --git a/packages/main/src/requests/request-helpers.test.ts b/packages/main/src/requests/request-helpers.test.ts
index 74c29092..f3d46cd0 100644
--- a/packages/main/src/requests/request-helpers.test.ts
+++ b/packages/main/src/requests/request-helpers.test.ts
@@ -17,10 +17,15 @@
 import { expect, use } from "chai";
 import * as sinonChai from "sinon-chai";
+import chaiDeepEqualIgnoreUndefined from "chai-deep-equal-ignore-undefined";
 import { Content } from "../../types";
-import { formatGenerateContentInput } from "./request-helpers";
+import {
+  formatCountTokensInput,
+  formatGenerateContentInput,
+} from "./request-helpers";
 
 use(sinonChai);
+use(chaiDeepEqualIgnoreUndefined);
 
 describe("request formatting methods", () => {
   describe("formatGenerateContentInput", () => {
@@ -172,4 +177,102 @@
       });
     });
   });
+  describe("formatCountTokensInput", () => {
+    it("formats a text string into a count request", () => {
+      const result = formatCountTokensInput("some text content", {
+        model: "gemini-1.5-flash",
+      });
+      expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
+        model: "gemini-1.5-flash",
+        contents: [
+          {
+            role: "user",
+            parts: [{ text: "some text content" }],
+          },
+        ],
+      });
+    });
+    it("formats a text string into a count request, along with model params", () => {
+      const result = formatCountTokensInput("some text content", {
+        model: "gemini-1.5-flash",
+        systemInstruction: "hello",
+        tools: [{ codeExecution: {} }],
+        cachedContent: { name: "mycache", contents: [] },
+      });
+      expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
+        model: "gemini-1.5-flash",
+        contents: [
+          {
+            role: "user",
+            parts: [{ text: "some text content" }],
+          },
+        ],
+        systemInstruction: "hello",
+        tools: [{ codeExecution: {} }],
+        cachedContent: "mycache",
+      });
+    });
+    it("formats a 'contents' style count request, along with model params", () => {
+      const result = formatCountTokensInput(
+        {
+          contents: [
+            {
+              role: "user",
+              parts: [{ text: "some text content" }],
+            },
+          ],
+        },
+        {
+          model: "gemini-1.5-flash",
+          systemInstruction: "hello",
+          tools: [{ codeExecution: {} }],
+          cachedContent: { name: "mycache", contents: [] },
+        },
+      );
+      expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
+        model: "gemini-1.5-flash",
+        contents: [
+          {
+            role: "user",
+            parts: [{ text: "some text content" }],
+          },
+        ],
+        systemInstruction: "hello",
+        tools: [{ codeExecution: {} }],
+        cachedContent: "mycache",
+      });
+    });
+    it("formats a 'generateContentRequest' style count request, along with model params", () => {
+      const result = formatCountTokensInput(
+        {
+          generateContentRequest: {
+            contents: [
+              {
+                role: "user",
+                parts: [{ text: "some text content" }],
+              },
+            ],
+          },
+        },
+        {
+          model: "gemini-1.5-flash",
+          systemInstruction: "hello",
+          tools: [{ codeExecution: {} }],
+          cachedContent: { name: "mycache", contents: [] },
+        },
+      );
+      expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
+        model: "gemini-1.5-flash",
+        contents: [
+          {
+            role: "user",
+            parts: [{ text: "some text content" }],
+          },
+        ],
+        systemInstruction: "hello",
+        tools: [{ codeExecution: {} }],
+        cachedContent: "mycache",
+      });
+    });
+  });
 });
diff --git a/packages/main/src/requests/request-helpers.ts b/packages/main/src/requests/request-helpers.ts
index 5a7396e6..58232e05 100644
--- a/packages/main/src/requests/request-helpers.ts
+++ b/packages/main/src/requests/request-helpers.ts
@@ -20,8 +20,10 @@ import {
   CountTokensRequest,
   EmbedContentRequest,
   GenerateContentRequest,
+  ModelParams,
   Part,
   _CountTokensRequestInternal,
+  _GenerateContentRequestInternal,
 } from "../../types";
 import {
   GoogleGenerativeAIError,
@@ -111,9 +113,18 @@ function assignRoleToPartsAndValidateSendMessageRequest(
 
 export function formatCountTokensInput(
   params: CountTokensRequest | string | Array<string | Part>,
-  model: string,
+  modelParams?: ModelParams,
 ): _CountTokensRequestInternal {
-  let formattedRequest: _CountTokensRequestInternal = {};
+  let formattedGenerateContentRequest: _GenerateContentRequestInternal = {
+    model: modelParams?.model,
+    generationConfig: modelParams?.generationConfig,
+    safetySettings: modelParams?.safetySettings,
+    tools: modelParams?.tools,
+    toolConfig: modelParams?.toolConfig,
+    systemInstruction: modelParams?.systemInstruction,
+    cachedContent: modelParams?.cachedContent?.name,
+    contents: [],
+  };
   const containsGenerateContentRequest =
     (params as CountTokensRequest).generateContentRequest != null;
   if ((params as CountTokensRequest).contents) {
@@ -122,16 +133,20 @@
         "CountTokensRequest must have one of contents or generateContentRequest, not both.",
       );
     }
-    formattedRequest = { ...(params as CountTokensRequest) };
+    formattedGenerateContentRequest.contents = (
+      params as CountTokensRequest
+    ).contents;
   } else if (containsGenerateContentRequest) {
-    formattedRequest = { ...(params as CountTokensRequest) };
-    formattedRequest.generateContentRequest.model = model;
+    formattedGenerateContentRequest = {
+      ...formattedGenerateContentRequest,
+      ...(params as CountTokensRequest).generateContentRequest,
+    };
   } else {
     // Array or string
     const content = formatNewContent(params as string | Array<string | Part>);
-    formattedRequest.contents = [content];
+    formattedGenerateContentRequest.contents = [content];
  }
-  return formattedRequest;
+  return { generateContentRequest: formattedGenerateContentRequest };
 }
 
 export function formatGenerateContentInput(
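For orientation (this note and sketch are commentary, not part of the patch): the formatter now always returns a single nested `generateContentRequest`, seeded with the model-level params and then merged with whatever the caller passed. Two consequences are worth spelling out. First, unset model params flow through as explicit `undefined` keys, which is why the tests above register `chai-deep-equal-ignore-undefined`: chai's plain `deep.equal` treats a key set to `undefined` as different from a missing key. Second, because a caller-supplied `generateContentRequest` is spread last, its fields override the model-level defaults. A minimal sketch, with made-up prompt values:

```ts
import { formatCountTokensInput } from "./request-helpers";

// A bare string prompt picks up the params set on the model.
const fromString = formatCountTokensInput("hello", {
  model: "models/gemini-1.5-flash",
  systemInstruction: "you are a cat",
});
// fromString is roughly:
// {
//   generateContentRequest: {
//     model: "models/gemini-1.5-flash",
//     systemInstruction: "you are a cat",
//     contents: [{ role: "user", parts: [{ text: "hello" }] }],
//     generationConfig: undefined, // unset params surface as undefined keys
//   },
// }

// A caller-supplied generateContentRequest is spread last, so its
// fields win over the model-level defaults.
const fromRequest = formatCountTokensInput(
  {
    generateContentRequest: {
      contents: [{ role: "user", parts: [{ text: "hello" }] }],
      systemInstruction: "per-request override",
    },
  },
  { model: "models/gemini-1.5-flash", systemInstruction: "model default" },
);
// fromRequest.generateContentRequest.systemInstruction === "per-request override"
```

This merge order also explains why the old `formattedRequest.generateContentRequest.model = model;` special case could be dropped: the model name is seeded into the defaults and merged uniformly with everything else.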
diff --git a/samples/count_tokens.js b/samples/count_tokens.js
index 447f76ea..d0fc4124 100644
--- a/samples/count_tokens.js
+++ b/samples/count_tokens.js
@@ -287,36 +287,30 @@ async function tokensCachedContent() {
   });
 
   const genAI = new GoogleGenerativeAI(process.env.API_KEY);
-  const model = genAI.getGenerativeModel({
-    model: "models/gemini-1.5-flash",
-  });
+  const model = genAI.getGenerativeModelFromCachedContent(cacheResult);
+
+  const prompt = "Please give a short summary of this file.";
 
   // Call `countTokens` to get the input token count
   // of the combined text and file (`totalTokens`).
-  const result = await model.countTokens({
-    generateContentRequest: {
-      contents: [
-        {
-          role: "user",
-          parts: [{ text: "Please give a short summary of this file." }],
-        },
-      ],
-      cachedContent: cacheResult.name,
-    },
-  });
+  const result = await model.countTokens(prompt);
 
   console.log(result.totalTokens);
   // 10
-  const generateResult = await model.generateContent(
-    "Please give a short summary of this file.",
-  );
+  const generateResult = await model.generateContent(prompt);
 
   // On the response for `generateContent`, use `usageMetadata`
   // to get separate input and output token counts
   // (`promptTokenCount` and `candidatesTokenCount`, respectively),
-  // as well as the combined token count (`totalTokenCount`).
+  // as well as the cached content token count and the combined total
+  // token count.
   console.log(generateResult.response.usageMetadata);
-  // { promptTokenCount: 10, candidatesTokenCount: 31, totalTokenCount: 41 }
+  // {
+  //   promptTokenCount: 323396,
+  //   candidatesTokenCount: 113,
+  //   totalTokenCount: 323509,
+  //   cachedContentTokenCount: 323386
+  // }
 
   await cacheManager.delete(cacheResult.name);
   // [END tokens_cached_content]
@@ -329,22 +323,12 @@ async function tokensSystemInstruction() {
   const genAI = new GoogleGenerativeAI(process.env.API_KEY);
   const model = genAI.getGenerativeModel({
     model: "models/gemini-1.5-flash",
+    systemInstruction: "You are a cat. Your name is Neko.",
   });
 
-  const result = await model.countTokens({
-    generateContentRequest: {
-      contents: [
-        {
-          role: "user",
-          parts: [{ text: "The quick brown fox jumps over the lazy dog." }],
-        },
-      ],
-      systemInstruction: {
-        role: "system",
-        parts: [{ text: "You are a cat. Your name is Neko." }],
-      },
-    },
-  });
+  const result = await model.countTokens(
+    "The quick brown fox jumps over the lazy dog.",
+  );
 
   console.log(result);
   // {
@@ -360,9 +344,6 @@ async function tokensTools() {
   // Make sure to include these imports:
   // import { GoogleGenerativeAI } from "@google/generative-ai";
   const genAI = new GoogleGenerativeAI(process.env.API_KEY);
-  const model = genAI.getGenerativeModel({
-    model: "models/gemini-1.5-flash",
-  });
 
   const functionDeclarations = [
     { name: "add" },
@@ -371,22 +352,15 @@ async function tokensTools() {
     { name: "divide" },
   ];
 
-  const result = await model.countTokens({
-    generateContentRequest: {
-      contents: [
-        {
-          role: "user",
-          parts: [
-            {
-              text: "I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
-            },
-          ],
-        },
-      ],
-      tools: [{ functionDeclarations }],
-    },
+  const model = genAI.getGenerativeModel({
+    model: "models/gemini-1.5-flash",
+    tools: [{ functionDeclarations }],
   });
 
+  const result = await model.countTokens(
+    "I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
+  );
+
   console.log(result);
   // {
   //   totalTokens: 99,
diff --git a/yarn.lock b/yarn.lock
index b58b094a..c0445408 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -2389,6 +2389,11 @@ chai-as-promised@^7.1.1:
   dependencies:
     check-error "^1.0.2"
 
+chai-deep-equal-ignore-undefined@^1.1.1:
+  version "1.1.1"
+  resolved "https://registry.yarnpkg.com/chai-deep-equal-ignore-undefined/-/chai-deep-equal-ignore-undefined-1.1.1.tgz#c9e3736fed06c83572f03c592c025cf2703fd1a1"
+  integrity sha512-BE4nUR2Jbqmmv8A0EuAydFRB/lXgXWAfa9TvO3YzHeGHAU7ZRwPZyu074oDl/CZtNXM7jXINpQxKBOe7N0P4bg==
+
 chai@^4.3.10:
   version "4.3.10"
   resolved "https://registry.yarnpkg.com/chai/-/chai-4.3.10.tgz#d784cec635e3b7e2ffb66446a63b4e33bd390384"
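Taken together, the sample updates all land on the same pattern: set `systemInstruction`, `tools`, or cached content once when creating the model, then pass `countTokens` a bare prompt. A condensed sketch of that usage (the wrapper function and logging are illustrative; the model name, system instruction, and prompt text come from the samples above):

```ts
import { GoogleGenerativeAI } from "@google/generative-ai";

async function countWithInstanceParams() {
  const genAI = new GoogleGenerativeAI(process.env.API_KEY!);

  // Params configured here now apply to countTokens as well as
  // generateContent.
  const model = genAI.getGenerativeModel({
    model: "models/gemini-1.5-flash",
    systemInstruction: "You are a cat. Your name is Neko.",
  });

  // The system instruction above is folded into the counted request,
  // so no hand-built generateContentRequest is needed.
  const { totalTokens } = await model.countTokens(
    "The quick brown fox jumps over the lazy dog.",
  );
  console.log(totalTokens);
}

countWithInstanceParams();
```

This mirrors the generative-model.ts change earlier in the diff: whatever is configured on the instance is forwarded by `countTokens`, so the samples no longer hand-build a `generateContentRequest`.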