From 61850ef50075367918f2a73f1eda47020278cdbf Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Fri, 6 Oct 2023 14:36:21 +0300 Subject: [PATCH 1/5] knn endpoint support --- src/Knn/Knn.php | 193 ++++++++++++++++++ src/Search.php | 17 ++ src/SearchEndpoint/KnnEndpoint.php | 47 +++++ src/SearchEndpoint/SearchEndpointFactory.php | 15 +- src/SearchEndpoint/SuggestEndpoint.php | 4 +- src/Suggest/Suggest.php | 1 - .../SearchEndpointFactoryTest.php | 4 +- 7 files changed, 269 insertions(+), 12 deletions(-) create mode 100644 src/Knn/Knn.php create mode 100644 src/SearchEndpoint/KnnEndpoint.php diff --git a/src/Knn/Knn.php b/src/Knn/Knn.php new file mode 100644 index 00000000..984e3a69 --- /dev/null +++ b/src/Knn/Knn.php @@ -0,0 +1,193 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Knn; + +use ONGR\ElasticsearchDSL\BuilderInterface; +use ONGR\ElasticsearchDSL\FieldAwareTrait; + +class Knn implements BuilderInterface +{ + use FieldAwareTrait; + + /** + * @var string + */ + private $field; + + /** + * @var array + */ + private $queryVector; + + /** + * @var int + */ + private $k; + + /** + * @var int + */ + private $numCandidates; + + /** + * @var int + */ + private $similarity = null; + + /** + * @var BuilderInterface + */ + private $filter = null; + + + /** + * TermSuggest constructor. + * @param string $field + * @param array $queryVector + * @param int $k + * @param int $numCandidates + */ + public function __construct( + string $field, + array $queryVector, + int $k, + int $numCandidates + ) { + $this->setField($field); + $this->setQueryVector($queryVector); + $this->setK($k); + $this->setNumCandidates($numCandidates); + } + + /** + * @return string + */ + public function getField(): string + { + return $this->field; + } + + /** + * @param string $field + */ + public function setField(string $field): void + { + $this->field = $field; + } + + /** + * @return array + */ + public function getQueryVector(): array + { + return $this->queryVector; + } + + /** + * @param array $queryVector + */ + public function setQueryVector(array $queryVector): void + { + $this->queryVector = $queryVector; + } + + /** + * @return int + */ + public function getK(): int + { + return $this->k; + } + + /** + * @param int $k + */ + public function setK(int $k): void + { + $this->k = $k; + } + + /** + * @return int + */ + public function getNumCandidates(): int + { + return $this->numCandidates; + } + + /** + * @param int $numCandidates + */ + public function setNumCandidates(int $numCandidates): void + { + $this->numCandidates = $numCandidates; + } + + /** + * @return int|null + */ + public function getSimilarity(): ?int + { + return $this->similarity; + } + + /** + * @param int $similarity + */ + public function setSimilarity(int $similarity): void + { + $this->similarity = $similarity; + } + + /** + * @return BuilderInterface|null + */ + public function getFilter(): ?BuilderInterface + { + return $this->filter; + } + + /** + * @param BuilderInterface $filter + */ + public function setFilter(BuilderInterface $filter): void + { + $this->filter = $filter; + } + + /** + * {@inheritdoc} + */ + public function getType() + { + return 'knn'; + } + + /** + * {@inheritdoc} + */ + public function toArray() + { + $output = [ + 'field' => $this->getField(), + 'query_vector' => $this->getQueryVector(), + 'k' => $this->getK(), + 'num_candidates' => $this->getNumCandidates(), + ]; + + if ($this->getFilter()) { + $output['filter'] = $this->getFilter()->toArray(); + } + + return $output; + } +} diff --git a/src/Search.php b/src/Search.php index 91c68fd7..1b6fe9fb 100644 --- a/src/Search.php +++ b/src/Search.php @@ -19,6 +19,7 @@ use ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint; +use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\SearchEndpointFactory; @@ -229,6 +230,22 @@ public function addQuery(BuilderInterface $query, $boolType = BoolQuery::MUST, $ return $this; } + + /** + * Adds Knn to the search. + * + * @param BuilderInterface $query + * + * @return $this + */ + public function addKnn(BuilderInterface $query) + { + $endpoint = $this->getEndpoint(KnnEndpoint::NAME); + $endpoint->add($query); + + return $this; + } + /** * Returns endpoint instance. * diff --git a/src/SearchEndpoint/KnnEndpoint.php b/src/SearchEndpoint/KnnEndpoint.php new file mode 100644 index 00000000..7b20469e --- /dev/null +++ b/src/SearchEndpoint/KnnEndpoint.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\SearchEndpoint; + +use ONGR\ElasticsearchDSL\Knn\Knn; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * Search suggest dsl endpoint. + */ +class KnnEndpoint extends AbstractSearchEndpoint +{ + /** + * Endpoint name + */ + const NAME = 'knn'; + + /** + * {@inheritdoc} + */ + public function normalize( + NormalizerInterface $normalizer, + $format = null, + array $context = [] + ): array|string|int|float|bool { + $output = []; + if (count($this->getAll()) > 0) { + /** @var Knn $knn */ + foreach ($this->getAll() as $knn) { + if ($knn instanceof Knn) { + $output = $knn->toArray(); + } + } + } + + return $output; + } +} diff --git a/src/SearchEndpoint/SearchEndpointFactory.php b/src/SearchEndpoint/SearchEndpointFactory.php index 17e6838b..fb9149c2 100644 --- a/src/SearchEndpoint/SearchEndpointFactory.php +++ b/src/SearchEndpoint/SearchEndpointFactory.php @@ -20,13 +20,14 @@ class SearchEndpointFactory * @var array Holds namespaces for endpoints. */ private static $endpoints = [ - 'query' => 'ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint', - 'post_filter' => 'ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint', - 'sort' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SortEndpoint', - 'highlight' => 'ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint', - 'aggregations' => 'ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint', - 'suggest' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SuggestEndpoint', - 'inner_hits' => 'ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint', + 'query' => QueryEndpoint::class, + 'knn' => KnnEndpoint::class, + 'post_filter' => PostFilterEndpoint::class, + 'sort' => SortEndpoint::class, + 'highlight' => HighlightEndpoint::class, + 'aggregations' => AggregationsEndpoint::class, + 'suggest' => SuggestEndpoint::class, + 'inner_hits' => InnerHitsEndpoint::class, ]; /** diff --git a/src/SearchEndpoint/SuggestEndpoint.php b/src/SearchEndpoint/SuggestEndpoint.php index 9c9d8367..c996b2f7 100644 --- a/src/SearchEndpoint/SuggestEndpoint.php +++ b/src/SearchEndpoint/SuggestEndpoint.php @@ -11,7 +11,7 @@ namespace ONGR\ElasticsearchDSL\SearchEndpoint; -use ONGR\ElasticsearchDSL\Suggest\TermSuggest; +use ONGR\ElasticsearchDSL\Suggest\Suggest; use Symfony\Component\Serializer\Normalizer\NormalizerInterface; /** @@ -34,7 +34,7 @@ public function normalize( ): array|string|int|float|bool { $output = []; if (count($this->getAll()) > 0) { - /** @var TermSuggest $suggest */ + /** @var Suggest $suggest */ foreach ($this->getAll() as $suggest) { $output = array_merge($output, $suggest->toArray()); } diff --git a/src/Suggest/Suggest.php b/src/Suggest/Suggest.php index 1738d84d..6a639c48 100644 --- a/src/Suggest/Suggest.php +++ b/src/Suggest/Suggest.php @@ -13,7 +13,6 @@ use ONGR\ElasticsearchDSL\NamedBuilderInterface; use ONGR\ElasticsearchDSL\ParametersTrait; -use Symfony\Component\Serializer\Exception\InvalidArgumentException; class Suggest implements NamedBuilderInterface { diff --git a/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php b/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php index c928f19c..1a3725a9 100644 --- a/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php +++ b/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php @@ -34,8 +34,8 @@ public function testGet() */ public function testFactory() { - $endpoinnt = SearchEndpointFactory::get(AggregationsEndpoint::NAME); + $endpoint = SearchEndpointFactory::get(AggregationsEndpoint::NAME); - $this->assertInstanceOf(SearchEndpointInterface::class, $endpoinnt); + $this->assertInstanceOf(SearchEndpointInterface::class, $endpoint); } } From 2e303e0f62ec3a107cf581795049a8f6fadfe2a6 Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Fri, 6 Oct 2023 23:30:21 +0300 Subject: [PATCH 2/5] multiple vector support. test improvements. --- src/Knn/Knn.php | 17 +++-- src/SearchEndpoint/KnnEndpoint.php | 14 +++- tests/Unit/Knn/KnnTest.php | 66 +++++++++++++++++++ tests/Unit/SearchEndpoint/KnnEndpointTest.php | 64 ++++++++++++++++++ 4 files changed, 154 insertions(+), 7 deletions(-) create mode 100644 tests/Unit/Knn/KnnTest.php create mode 100644 tests/Unit/SearchEndpoint/KnnEndpointTest.php diff --git a/src/Knn/Knn.php b/src/Knn/Knn.php index 984e3a69..fe18b0b5 100644 --- a/src/Knn/Knn.php +++ b/src/Knn/Knn.php @@ -41,6 +41,11 @@ class Knn implements BuilderInterface /** * @var int */ + private $boost; + + /** + * @var float + */ private $similarity = null; /** @@ -133,17 +138,17 @@ public function setNumCandidates(int $numCandidates): void } /** - * @return int|null + * @return float|null */ - public function getSimilarity(): ?int + public function getSimilarity(): ?float { return $this->similarity; } /** - * @param int $similarity + * @param float $similarity */ - public function setSimilarity(int $similarity): void + public function setSimilarity(float $similarity): void { $this->similarity = $similarity; } @@ -184,6 +189,10 @@ public function toArray() 'num_candidates' => $this->getNumCandidates(), ]; + if ($this->getSimilarity()) { + $output['similarity'] = $this->getSimilarity(); + } + if ($this->getFilter()) { $output['filter'] = $this->getFilter()->toArray(); } diff --git a/src/SearchEndpoint/KnnEndpoint.php b/src/SearchEndpoint/KnnEndpoint.php index 7b20469e..a25d32bf 100644 --- a/src/SearchEndpoint/KnnEndpoint.php +++ b/src/SearchEndpoint/KnnEndpoint.php @@ -11,6 +11,7 @@ namespace ONGR\ElasticsearchDSL\SearchEndpoint; +use ONGR\ElasticsearchDSL\BuilderInterface; use ONGR\ElasticsearchDSL\Knn\Knn; use Symfony\Component\Serializer\Normalizer\NormalizerInterface; @@ -24,6 +25,15 @@ class KnnEndpoint extends AbstractSearchEndpoint */ const NAME = 'knn'; + public function add(BuilderInterface $builder, $key = null) + { + if ($builder instanceof Knn) { + return parent::add($builder, $key); + } + + throw new \LogicException('You need to add Knn builder!'); + } + /** * {@inheritdoc} */ @@ -36,9 +46,7 @@ public function normalize( if (count($this->getAll()) > 0) { /** @var Knn $knn */ foreach ($this->getAll() as $knn) { - if ($knn instanceof Knn) { - $output = $knn->toArray(); - } + $output[] = $knn->toArray(); } } diff --git a/tests/Unit/Knn/KnnTest.php b/tests/Unit/Knn/KnnTest.php new file mode 100644 index 00000000..3e479bcc --- /dev/null +++ b/tests/Unit/Knn/KnnTest.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Tests\Unit\Knn; + +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\MatchAllQuery; + +class KnnTest extends \PHPUnit\Framework\TestCase +{ + /** + * Tests toArray(). + */ + public function testToArray(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100 + ], $query->toArray()); + } + + /** + * Tests toArray(). + */ + public function testToArrayWithFilter(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $query->setFilter(new MatchAllQuery()); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100, + 'filter' => [ + 'match_all' => new \stdClass() + ] + ], $query->toArray()); + } + + /** + * Tests toArray(). + */ + public function testToArrayWithSimilarity(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $query->setSimilarity(1); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100, + 'similarity' => 1 + ], $query->toArray()); + } +} diff --git a/tests/Unit/SearchEndpoint/KnnEndpointTest.php b/tests/Unit/SearchEndpoint/KnnEndpointTest.php new file mode 100644 index 00000000..84999d3b --- /dev/null +++ b/tests/Unit/SearchEndpoint/KnnEndpointTest.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Tests\Unit\SearchEndpoint; + +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint; +use PHPUnit\Framework\TestCase; + +/** + * Class AggregationsEndpointTest. + */ +class KnnEndpointTest extends TestCase +{ + /** + * Tests constructor. + */ + public function testItCanBeInstantiated(): void + { + $this->assertInstanceOf( + KnnEndpoint::class, + new KnnEndpoint() + ); + } + + /** + * Tests if endpoint returns builders. + */ + public function testEndpointGetter(): void + { + $knn = new Knn('acme', [1, 2, 3], 10, 100); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn, 'knn'); + $builders = $endpoint->getAll(); + + $this->assertCount(1, $builders); + $this->assertSame($knn, $builders['knn']); + } + + /** + * Tests if endpoint returns builders. + */ + public function testEndpointWithMultipleVector(): void + { + $knn1 = new Knn('acme', [1, 2, 3], 10, 100); + $knn2 = new Knn('acme', [1, 2, 4], 10, 100); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn1, 'knn1'); + $endpoint->add($knn2, 'knn2'); + $builders = $endpoint->getAll(); + + $this->assertCount(2, $builders); + $this->assertSame($knn1, $builders['knn1']); + $this->assertSame($knn2, $builders['knn2']); + } +} From 48b8e3d3b5c2334560e428bd5a47e89dd7760b53 Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Fri, 6 Oct 2023 23:32:33 +0300 Subject: [PATCH 3/5] test improvement for exception --- src/SearchEndpoint/KnnEndpoint.php | 2 +- tests/Unit/SearchEndpoint/KnnEndpointTest.php | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/SearchEndpoint/KnnEndpoint.php b/src/SearchEndpoint/KnnEndpoint.php index a25d32bf..4bd89c00 100644 --- a/src/SearchEndpoint/KnnEndpoint.php +++ b/src/SearchEndpoint/KnnEndpoint.php @@ -31,7 +31,7 @@ public function add(BuilderInterface $builder, $key = null) return parent::add($builder, $key); } - throw new \LogicException('You need to add Knn builder!'); + throw new \LogicException('Add Knn builder instead!'); } /** diff --git a/tests/Unit/SearchEndpoint/KnnEndpointTest.php b/tests/Unit/SearchEndpoint/KnnEndpointTest.php index 84999d3b..ad1e447e 100644 --- a/tests/Unit/SearchEndpoint/KnnEndpointTest.php +++ b/tests/Unit/SearchEndpoint/KnnEndpointTest.php @@ -12,6 +12,7 @@ namespace ONGR\ElasticsearchDSL\Tests\Unit\SearchEndpoint; use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\MatchAllQuery; use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint; use PHPUnit\Framework\TestCase; @@ -45,6 +46,18 @@ public function testEndpointGetter(): void $this->assertSame($knn, $builders['knn']); } + /** + * Tests if endpoint returns builders. + */ + public function testEndpointException(): void + { + $this->expectException(\LogicException::class); + $this->expectExceptionMessage('Add Knn builder instead!'); + $knn = new MatchAllQuery(); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn); + } + /** * Tests if endpoint returns builders. */ From 96d4f5ac0a917d5cd3226e8a24e0c2ca17dc9486 Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Sat, 7 Oct 2023 21:10:04 +0300 Subject: [PATCH 4/5] Knn functional tests elasticsearch version changed as 8.4.0 for github ci/cd --- .github/workflows/test-application.yaml | 2 +- src/SearchEndpoint/KnnEndpoint.php | 16 +++- .../AbstractElasticsearchTestCase.php | 21 ++--- tests/Functional/Knn/KnnTest.php | 85 +++++++++++++++++++ 4 files changed, 109 insertions(+), 15 deletions(-) create mode 100644 tests/Functional/Knn/KnnTest.php diff --git a/.github/workflows/test-application.yaml b/.github/workflows/test-application.yaml index 7c93b42e..607dbae9 100644 --- a/.github/workflows/test-application.yaml +++ b/.github/workflows/test-application.yaml @@ -50,7 +50,7 @@ jobs: services: elasticsearch: - image: elasticsearch:8.0.0 + image: elasticsearch:8.4.0 ports: - 9200:9200 env: diff --git a/src/SearchEndpoint/KnnEndpoint.php b/src/SearchEndpoint/KnnEndpoint.php index 4bd89c00..0cc55676 100644 --- a/src/SearchEndpoint/KnnEndpoint.php +++ b/src/SearchEndpoint/KnnEndpoint.php @@ -42,14 +42,22 @@ public function normalize( $format = null, array $context = [] ): array|string|int|float|bool { - $output = []; - if (count($this->getAll()) > 0) { + $knns = $this->getAll(); + if (count($knns) === 1) { /** @var Knn $knn */ - foreach ($this->getAll() as $knn) { + $knn = array_values($knns)[0]; + return $knn->toArray(); + } + + if (count($knns) > 1) { + $output = []; + /** @var Knn $knn */ + foreach ($knns as $knn) { $output[] = $knn->toArray(); } + return $output; } - return $output; + return []; } } diff --git a/tests/Functional/AbstractElasticsearchTestCase.php b/tests/Functional/AbstractElasticsearchTestCase.php index 6fa1a107..67b175e7 100644 --- a/tests/Functional/AbstractElasticsearchTestCase.php +++ b/tests/Functional/AbstractElasticsearchTestCase.php @@ -21,7 +21,7 @@ abstract class AbstractElasticsearchTestCase extends TestCase /** * Test index name in the elasticsearch. */ - const INDEX_NAME = 'elasticsaerch-dsl-test'; + const INDEX_NAME = 'elasticsearch-dsl-test'; /** * @var Client @@ -35,17 +35,18 @@ protected function setUp(): void { parent::setUp(); - $this->client = ClientBuilder::create()->build(); + $this->client = ClientBuilder::create() + ->build(); $this->deleteIndex(); - $this->client->indices()->create( - array_filter( - [ - 'index' => self::INDEX_NAME, - 'mapping' => $this->getMapping() - ] - ) - ); + $params = [ + 'index' => self::INDEX_NAME + ]; + if ($this->getMapping()) { + $params['body']['mappings'] = $this->getMapping(); + } + + $response = $this->client->indices()->create($params); $bulkBody = []; diff --git a/tests/Functional/Knn/KnnTest.php b/tests/Functional/Knn/KnnTest.php new file mode 100644 index 00000000..8eae63fc --- /dev/null +++ b/tests/Functional/Knn/KnnTest.php @@ -0,0 +1,85 @@ + + */ + +namespace ONGR\ElasticsearchDSL\Tests\Functional\Knn; + +use ONGR\ElasticsearchDSL\Aggregation\Bucketing\DateHistogramAggregation; +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\TermLevel\TermQuery; +use ONGR\ElasticsearchDSL\Search; +use ONGR\ElasticsearchDSL\Tests\Functional\AbstractElasticsearchTestCase; + +class KnnTest extends AbstractElasticsearchTestCase +{ + /** + * {@inheritdoc} + */ + protected function getDataArray(): array + { + return [ + 'knn_data' => [ + 'doc_1' => [ + 'label' => 1, + 'vector_field' => [1, 2, 3], + ], + 'doc_2' => [ + 'label' => 1, + 'vector_field' => [1, 2, 4], + ], + 'doc_3' => [ + 'label' => 2, + 'vector_field' => [1, 2, 30], + ], + ] + ]; + } + + protected function getMapping(): array + { + return [ + 'properties' => [ + 'label' => [ + 'type' => 'long' + ], + 'vector_field' => [ + 'type' => 'dense_vector', + 'similarity' => 'cosine', + 'index' => true, + 'dims' => 3 + ] + ] + ]; + } + + /** + * Match all test + */ + public function testKnnSearch(): void + { + $knn = new Knn('vector_field', [1, 2, 3], 1, 1); + + $search = new Search(); + $search->addKnn($knn); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_1', $results['hits']['hits'][0]['_id']); + } + + /** + * Match all test + */ + public function testKnnSearchWithFilter(): void + { + $knn = new Knn('vector_field', [1, 2, 3], 1, 1); + $knn->setFilter(new TermQuery('label', 2)); + + $search = new Search(); + $search->addKnn($knn); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']); + } +} From 7ab7a8389632a8792e08d35bf78f031ce83be936 Mon Sep 17 00:00:00 2001 From: Haydar Kulekci Date: Fri, 17 Nov 2023 15:13:59 +0300 Subject: [PATCH 5/5] added boost field setter and getter for Knn Search --- src/Knn/Knn.php | 22 +++++++++++++++++++++- tests/Functional/Knn/KnnTest.php | 19 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/Knn/Knn.php b/src/Knn/Knn.php index fe18b0b5..07b6489a 100644 --- a/src/Knn/Knn.php +++ b/src/Knn/Knn.php @@ -39,7 +39,7 @@ class Knn implements BuilderInterface private $numCandidates; /** - * @var int + * @var float|null */ private $boost; @@ -153,6 +153,22 @@ public function setSimilarity(float $similarity): void $this->similarity = $similarity; } + /** + * @return float|null + */ + public function getBoost(): ?float + { + return $this->boost; + } + + /** + * @param float $boost + */ + public function setBoost(float $boost): void + { + $this->boost = $boost; + } + /** * @return BuilderInterface|null */ @@ -193,6 +209,10 @@ public function toArray() $output['similarity'] = $this->getSimilarity(); } + if ($this->getBoost()) { + $output['boost'] = $this->getBoost(); + } + if ($this->getFilter()) { $output['filter'] = $this->getFilter()->toArray(); } diff --git a/tests/Functional/Knn/KnnTest.php b/tests/Functional/Knn/KnnTest.php index 8eae63fc..9aac8566 100644 --- a/tests/Functional/Knn/KnnTest.php +++ b/tests/Functional/Knn/KnnTest.php @@ -6,7 +6,8 @@ namespace ONGR\ElasticsearchDSL\Tests\Functional\Knn; -use ONGR\ElasticsearchDSL\Aggregation\Bucketing\DateHistogramAggregation; +use Composer\InstalledVersions; +use Elastic\Elasticsearch\Client; use ONGR\ElasticsearchDSL\Knn\Knn; use ONGR\ElasticsearchDSL\Query\TermLevel\TermQuery; use ONGR\ElasticsearchDSL\Search; @@ -82,4 +83,20 @@ public function testKnnSearchWithFilter(): void $this->assertCount(1, $results['hits']['hits']); $this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']); } + + /** + * Match all test + */ + public function testMultipleKnnSearchWithBoost(): void + { + $knn1 = new Knn('vector_field', [1, 2, 3], 1, 1); + $knn1->setFilter(new TermQuery('label', 2)); + $knn1->setBoost(0.5); + + $search = new Search(); + $search->addKnn($knn1); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']); + } }