From 1188aaedac4313bbbb1526b14fb9e03672a61eb8 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 8 Jan 2025 00:35:07 +0200 Subject: [PATCH 1/2] fix: reranking probabilities --- src/evaluator/LlamaRankingContext.ts | 21 +++++- test/modelDependent/bgeReranker/rank.test.ts | 78 +++++++++++--------- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/evaluator/LlamaRankingContext.ts b/src/evaluator/LlamaRankingContext.ts index 4c9ed593..d6f844a5 100644 --- a/src/evaluator/LlamaRankingContext.ts +++ b/src/evaluator/LlamaRankingContext.ts @@ -76,6 +76,9 @@ export class LlamaRankingContext { /** * Get the ranking score for a document for a query. + * + * A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query. + * @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query. */ public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) { if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null) @@ -96,6 +99,9 @@ export class LlamaRankingContext { /** * Get the ranking scores for all the given documents for a query. + * + * A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query. + * @returns an array of ranking scores between 0 and 1 representing the probability that the document is relevant to the query. */ public async rankAll(query: Token[] | string | LlamaText, documents: Array): Promise { const resolvedTokens = documents.map((document) => this._getEvaluationInput(query, document)); @@ -120,9 +126,15 @@ export class LlamaRankingContext { /** * Get the ranking scores for all the given documents for a query and sort them by score from highest to lowest. + * + * A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query. */ public async rankAndSort(query: Token[] | string | LlamaText, documents: T[]): Promise> { const scores = await this.rankAll(query, documents); @@ -190,7 +202,10 @@ export class LlamaRankingContext { if (embedding.length === 0) return 0; - return embedding[0]!; + const logit = embedding[0]!; + const probability = logitToSigmoid(logit); + + return probability; }); } @@ -249,3 +264,7 @@ function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffi return undefined; } + +function logitToSigmoid(logit: number) { + return 1 / (1 + Math.exp(-logit)); +} diff --git a/test/modelDependent/bgeReranker/rank.test.ts b/test/modelDependent/bgeReranker/rank.test.ts index 5966b214..5d3dee3b 100644 --- a/test/modelDependent/bgeReranker/rank.test.ts +++ b/test/modelDependent/bgeReranker/rank.test.ts @@ -40,19 +40,19 @@ describe("bgeReranker", () => { const highestRankDocument = documents[highestRankIndex]; expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world"); - expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4"); + expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156"); expect(simplifyRanks(ranks)).toMatchInlineSnapshot(` [ - -11, - -11, - -11, - -5.6, - -11, - -4, - -11, - -11, - -11, - -11, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, + 0.003684239899435989, + 0.00001670142184809518, + 0.01798620996209156, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, ] `); }); @@ -91,19 +91,19 @@ describe("bgeReranker", () => { const highestRankDocument = documents[highestRankIndex]; expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world"); - expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4"); + expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156"); expect(simplifyRanks(ranks)).toMatchInlineSnapshot(` [ - -11, - -11, - -11, - -5.6, - -11, - -4, - -11, - -11, - -11, - -11, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, + 0.003684239899435989, + 0.00001670142184809518, + 0.01798620996209156, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, + 0.00001670142184809518, ] `); }); @@ -141,42 +141,42 @@ describe("bgeReranker", () => { expect(simplifySortedRanks([topDocument])[0]).toMatchInlineSnapshot(` { "document": "Mount Everest is the tallest mountain in the world", - "score": -4, + "score": 0.01798620996209156, } `); expect(simplifySortedRanks(rankedDocuments)).toMatchInlineSnapshot(` [ { "document": "Mount Everest is the tallest mountain in the world", - "score": -4, + "score": 0.01798620996209156, }, { "document": "The capital of France is Paris", - "score": -5.6, + "score": 0.003684239899435989, }, { "document": "Not all the things that shine are made of gold", - "score": -11, + "score": 0.00001670142184809518, }, { "document": "I love eating pizza with extra cheese", - "score": -11, + "score": 0.00001670142184809518, }, { "document": "Dogs love to play fetch with their owners", - "score": -11, + "score": 0.00001670142184809518, }, { "document": "The sky is clear and blue today", - "score": -11, + "score": 0.00001670142184809518, }, { "document": "Cleaning the house is a good way to keep it tidy", - "score": -11, + "score": 0.00001670142184809518, }, { "document": "A warm cup of tea is perfect for a cold winter day", - "score": -11, + "score": 0.00001670142184809518, }, ] `); @@ -185,16 +185,28 @@ describe("bgeReranker", () => { }); function simplifyRanks(ranks: T): T { - return ranks.map((rank) => parseFloat(roundToPrecision(rank, 0.2).toFixed(1))) as T; + return ranks.map((rank) => simplifyScore(rank)) as T; } function simplifySortedRanks(values: T): T { return values.map((item) => ({ document: item.document, - score: parseFloat(roundToPrecision(item.score, 0.2).toFixed(1)) + score: simplifyScore(item.score) })) as T; } +function simplifyScore(score: number) { + return toSigmoid(parseFloat(roundToPrecision(toLogit(score), 0.2).toFixed(1))); +} + function roundToPrecision(value: number, precision: number): number { return Math.round(value / precision) * precision; } + +function toLogit(sigmoid: number) { + return Math.log(sigmoid / (1 - sigmoid)); +} + +function toSigmoid(logit: number) { + return 1 / (1 + Math.exp(-logit)); +} From d4c5b468ce47bbd55c6e81c16b091ddf120156ac Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 8 Jan 2025 00:37:35 +0200 Subject: [PATCH 2/2] chore: update module --- package-lock.json | 10 +++++----- package.json | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/package-lock.json b/package-lock.json index bc63811e..c2973dd0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -86,7 +86,7 @@ "typedoc": "^0.27.6", "typedoc-plugin-markdown": "^4.4.1", "typedoc-plugin-mdn-links": "^4.0.7", - "typedoc-vitepress-theme": "^1.1.1", + "typedoc-vitepress-theme": "^1.1.2", "typescript": "^5.7.2", "typescript-eslint": "^8.19.1", "vite-node": "^2.1.8", @@ -18508,13 +18508,13 @@ } }, "node_modules/typedoc-vitepress-theme": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/typedoc-vitepress-theme/-/typedoc-vitepress-theme-1.1.1.tgz", - "integrity": "sha512-1UbhZdQIkGKLkIZCbw8putrel+Vo7KKFfd8RhQRSBgetUZGUJkum89kIyF3+Kzy+1nqE56/MLKVxpPgQYubYYg==", + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/typedoc-vitepress-theme/-/typedoc-vitepress-theme-1.1.2.tgz", + "integrity": "sha512-hQvCZRr5uKDqY1bRuY1+eNTNn6d4TE4OP5pnw65Y7WGgajkJW9X1/lVJK2UJpcwCmwkdjw1QIO49H9JQlxWhhw==", "dev": true, "license": "MIT", "peerDependencies": { - "typedoc-plugin-markdown": ">=4.3.0" + "typedoc-plugin-markdown": ">=4.4.0" } }, "node_modules/typescript": { diff --git a/package.json b/package.json index 73825289..18b5230b 100644 --- a/package.json +++ b/package.json @@ -167,7 +167,7 @@ "typedoc": "^0.27.6", "typedoc-plugin-markdown": "^4.4.1", "typedoc-plugin-mdn-links": "^4.0.7", - "typedoc-vitepress-theme": "^1.1.1", + "typedoc-vitepress-theme": "^1.1.2", "typescript": "^5.7.2", "typescript-eslint": "^8.19.1", "vite-node": "^2.1.8",