Skip to content

Commit

Permalink
feat: add knnQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
KennyLindahl authored and Andreas Franzon committed Mar 11, 2024
1 parent 09f73cc commit 89c150f
Show file tree
Hide file tree
Showing 6 changed files with 416 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/core/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation');

exports.Query = require('./query');

exports.KNN = require('./knn');

exports.Suggester = require('./suggester');

exports.Script = require('./script');
Expand Down
133 changes: 133 additions & 0 deletions src/core/knn.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
'use strict';

const { recursiveToJSON, checkType } = require('./util');
const Query = require('./query');

/**
* Class representing a k-Nearest Neighbors (k-NN) query.
* This is a standalone builder (it does not extend `Query`) covering the specifics of k-NN search: the field,
* query vector, number of neighbors (k), and number of candidates.
*
* NOTE: kNN search was added to Elasticsearch in v8.0
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
*/
class KNN {
    /**
     * Creates a k-NN search clause.
     *
     * @param {string} field - The field to perform the k-NN search on.
     * @param {number} k - The number of nearest neighbors to retrieve.
     * @param {number} numCandidates - The number of candidates to consider; serialized as `num_candidates`.
     * @throws {Error} If `k` is greater than `numCandidates`.
     */
    constructor(field, k, numCandidates) {
        if (k > numCandidates)
            throw new Error('KNN numCandidates cannot be less than k');
        // Key insertion order is kept as-is because it determines the order of the serialized DSL.
        this._body = {};
        this._body.field = field;
        this._body.k = k;
        this._body.filter = [];
        this._body.num_candidates = numCandidates;
    }

    /**
     * Sets the query vector for the k-NN search.
     * Mutually exclusive with `queryVectorBuilder`.
     *
     * @param {Array<number>} vector - The query vector.
     * @returns {KNN} Returns the instance of KNN for method chaining.
     * @throws {Error} If a query vector builder has already been set.
     */
    queryVector(vector) {
        if (this._body.query_vector_builder)
            throw new Error(
                'cannot provide both query_vector_builder and query_vector'
            );
        this._body.query_vector = vector;
        return this;
    }

    /**
     * Sets the query vector builder for the k-NN search.
     * The builder is serialized as a `text_embeddings` object carrying the model ID and model text.
     * Mutually exclusive with `queryVector`.
     *
     * @param {string} modelId - The ID of the model to be used for generating the query vector.
     * @param {string} modelText - The text input based on which the query vector is generated.
     * @returns {KNN} Returns the instance of KNN for method chaining.
     * @throws {Error} If a query vector has already been set.
     *
     * Usage example:
     * let knn = new KNN('my_field', 5, 10);
     * knn.queryVectorBuilder('model_123', 'Sample model text');
     */
    queryVectorBuilder(modelId, modelText) {
        if (this._body.query_vector)
            throw new Error(
                'cannot provide both query_vector_builder and query_vector'
            );
        this._body.query_vector_builder = {
            text_embeddings: {
                model_id: modelId,
                model_text: modelText
            }
        };
        return this;
    }

    /**
     * Adds one or more filter queries to the k-NN search.
     *
     * Accepts either a single query or an array of queries; each call appends to the
     * filters added so far. Every entry is checked against `Query` before anything is
     * appended, so a rejected call leaves the existing filter list untouched.
     *
     * @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
     * @returns {KNN} Returns `this` to allow method chaining.
     * @throws {TypeError} If any of the provided queries is not an instance of `Query`.
     *
     * Usage example:
     * let knn = new KNN('my_field', 5, 10);
     * knn.filter(new TermQuery('field', 'value')); // Applying a single filter query
     * knn.filter([new TermQuery('field1', 'value1'), new TermQuery('field2', 'value2')]); // Applying multiple filter queries
     */
    filter(queries) {
        const queryArray = Array.isArray(queries) ? queries : [queries];
        // Validate first, append second: a bad entry must not leave a half-applied filter list.
        queryArray.forEach(query => checkType(query, Query));
        queryArray.forEach(query => {
            this._body.filter.push(query);
        });
        return this;
    }

    /**
     * Sets the `boost` value of the k-NN search.
     *
     * @param {number} boost - The boost value.
     * @returns {KNN} Returns the instance of KNN for method chaining.
     */
    boost(boost) {
        this._body.boost = boost;
        return this;
    }

    /**
     * Sets the `similarity` value of the k-NN search.
     * NOTE(review): in the Elasticsearch kNN API this is the minimum similarity a document
     * needs in order to match — confirm against the Elasticsearch reference.
     *
     * @param {number} similarity - The similarity value.
     * @returns {KNN} Returns the instance of KNN for method chaining.
     */
    similarity(similarity) {
        this._body.similarity = similarity;
        return this;
    }

    /**
     * Returns the DSL representation of this k-NN search.
     *
     * @returns {Object} returns an Object which maps to the elasticsearch query DSL
     * @throws {Error} If neither a query vector nor a query vector builder has been set.
     */
    toJSON() {
        if (!this._body.query_vector && !this._body.query_vector_builder)
            throw new Error(
                'either query_vector_builder or query_vector must be provided'
            );
        return recursiveToJSON(this._body);
    }
}

module.exports = KNN;
21 changes: 20 additions & 1 deletion src/core/request-body-search.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ const Query = require('./query'),
Rescore = require('./rescore'),
Sort = require('./sort'),
Highlight = require('./highlight'),
InnerHits = require('./inner-hits');
InnerHits = require('./inner-hits'),
KNN = require('./knn');

const { checkType, setDefault, recursiveToJSON } = require('./util');

Expand Down Expand Up @@ -69,6 +70,7 @@ class RequestBodySearch {
constructor() {
// Maybe accept some optional parameter?
this._body = {};
this._knn = [];
this._aggs = [];
this._suggests = [];
this._suggestText = null;
Expand All @@ -87,6 +89,21 @@ class RequestBodySearch {
return this;
}

/**
* Sets knn on the search request body.
*
* @param {KNN|KNN[]} knn A single `KNN` instance or an array of `KNN` instances.
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
*/
kNN(knn) {
const knns = Array.isArray(knn) ? knn : [knn];
knns.forEach(_knn => {
checkType(_knn, KNN);
this._knn.push(_knn);
});
return this;
}

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -785,6 +802,8 @@ class RequestBodySearch {
toJSON() {
const dsl = recursiveToJSON(this._body);

if (!isEmpty(this._knn)) dsl.knn = this._knn;

if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs);

if (!isEmpty(this._suggests) || !isNil(this._suggestText)) {
Expand Down
167 changes: 166 additions & 1 deletion src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ declare namespace esb {
*/
query(query: Query): this;

/**
* Adds one or more kNN searches to the request body.
* Each entry is type-checked against `KNN` and appended to the kNN
* searches already set on this request.
*
* @param {KNN|KNN[]} knn A single `KNN` instance or an array of `KNN` instances
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
*/
kNN(knn: KNN | KNN[]): this;

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -3074,7 +3081,7 @@ declare namespace esb {

/**
* Sets the script used to compute the score of documents returned by the query.
*
*
* @param {Script} script A valid `Script` object
*/
script(script: Script): this;
Expand Down Expand Up @@ -3268,6 +3275,86 @@ declare namespace esb {
*/
export function distanceFeatureQuery(field?: string): DistanceFeatureQuery;

/**
* The `rank_feature` query boosts the relevance score on the numeric value of
* document with a rank_feature/rank_features field.
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-rank-feature-query.html)
*
* NOTE: This query was added in elasticsearch v7.0.
*
* @example
* const query = new RankFeatureQuery('rank_feature_field');
* query
* .linear()
* .toJSON();
* @param {string} field The field inside the document to be used in the query
* @return {RankFeatureQuery}
*/
export class RankFeatureQuery extends Query {
/**
* @param {string} field The field inside the document to be used in the query
*/
constructor(field?: string);

/**
* Sets the field for the `rank_feature` query
* @param {string} fieldName Name of the field inside the document
* @returns {RankFeatureQuery} Instance of the query
*/
field(fieldName: string): RankFeatureQuery;

/**
* Linear function to boost relevance scores based on the value of the rank feature field
* @returns {RankFeatureQuery}
*/
linear(): RankFeatureQuery;

/**
* Saturation function to boost relevance scores based on the value of the rank feature field.
* Uses a default pivot value computed by Elasticsearch.
* @returns {RankFeatureQuery}
*/
saturation(): RankFeatureQuery;

/**
* Saturation function to boost relevance scores based on the value of the rank feature field,
* using an explicitly supplied pivot.
* @param {number} pivot
* @returns {RankFeatureQuery}
*/
saturationPivot(pivot: number): RankFeatureQuery;

/**
* The log function gives a score equal to log(scaling_factor + S), where S
* is the value of the rank feature field and scaling_factor is a configurable
* scaling factor.
* @param {number} scalingFactor The scaling factor (`scaling_factor` in the formula above)
* @returns {RankFeatureQuery}
*/
log(scalingFactor: number): RankFeatureQuery;

/**
* The sigmoid function extends the saturation function with a configurable exponent.
* @param {number} pivot
* @param {number} exponent
* @returns {RankFeatureQuery}
*/
sigmoid(pivot: number, exponent: number): RankFeatureQuery;
}

/**
* Function form of `RankFeatureQuery`: takes the same optional `field`
* argument and returns a `RankFeatureQuery`.
*
* The `rank_feature` query boosts the relevance score on the numeric value of
* document with a rank_feature/rank_features field.
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-rank-feature-query.html)
*
* NOTE: This query was added in elasticsearch v7.0.
*
* @example
* const query = rankFeatureQuery('rank_feature_field');
* query
* .linear()
* .toJSON();
* @param {string} field The field inside the document to be used in the query
* @return {RankFeatureQuery}
*/
export function rankFeatureQuery(field?: string): RankFeatureQuery;

/**
* Interface-like class used to group and identify various implementations of Span queries.
*
Expand Down Expand Up @@ -3614,6 +3701,84 @@ declare namespace esb {
spanQry?: SpanQueryBase
): SpanFieldMaskingQuery;

/**
* Knn performs k-nearest neighbor (KNN) searches.
* This class allows configuring the KNN search with various parameters such as field, query vector,
* number of nearest neighbors (k), number of candidates, boost factor, and similarity metric.
*
* NOTE: Only available in Elasticsearch v8.0+
*/
export class KNN {
/**
* Creates an instance of KNN, initializing the internal state for the k-NN search.
*
* @param {string} field - The field against which to perform the k-NN search.
* @param {number} k - The number of nearest neighbors to retrieve.
* @param {number} numCandidates - The number of candidate neighbors to consider during the search.
* @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k).
*/
constructor(field: string, k: number, numCandidates: number);

/**
* Sets the query vector for the KNN search, an array of numbers representing the reference point.
* Either a query vector or a query vector builder can be provided, but not both.
*
* @param {number[]} vector
* @returns {KNN} Returns the instance of KNN for method chaining.
* @throws {Error} If a query vector builder has already been set.
*/
queryVector(vector: number[]): this;

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* Note that either a direct query vector or a query vector builder can be provided, but not both.
*
* @param {string} modelId - The ID of the model used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @returns {KNN} Returns the instance of KNN for method chaining.
* @throws {Error} If a query vector has already been set.
*/
queryVectorBuilder(modelId: string, modelText: string): this;

/**
* Adds one or more filter queries to the k-NN search.
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {KNN} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*/
filter(queries: Query | Query[]): this;

/**
* Sets the `boost` value of the k-NN search.
*
* @param {number} boost
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
boost(boost: number): this;

/**
* Sets the `similarity` value of the k-NN search.
* NOTE(review): in the Elasticsearch kNN API this is the minimum similarity a document
* needs in order to match, not the name of a metric — confirm against the Elasticsearch reference.
*
* @param {number} similarity
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
similarity(similarity: number): this;

/**
* Returns the DSL representation of the k-NN search.
*
* @throws {Error} If neither a query vector nor a query vector builder has been provided.
*/
toJSON(): object;
}

/**
* Factory function to instantiate a new KNN object.
*
* @param {string} field - The field against which to perform the k-NN search.
* @param {number} k - The number of nearest neighbors to retrieve.
* @param {number} numCandidates - The number of candidate neighbors to consider during the search.
* @returns {KNN}
*/
export function kNN(field: string, k: number, numCandidates: number): KNN;

/**
* Base class implementation for all aggregation types.
*
Expand Down
Loading

0 comments on commit 89c150f

Please sign in to comment.