diff --git a/.github/workflows/test-application.yaml b/.github/workflows/test-application.yaml index 7c93b42e..607dbae9 100644 --- a/.github/workflows/test-application.yaml +++ b/.github/workflows/test-application.yaml @@ -50,7 +50,7 @@ jobs: services: elasticsearch: - image: elasticsearch:8.0.0 + image: elasticsearch:8.4.0 ports: - 9200:9200 env: diff --git a/src/Knn/Knn.php b/src/Knn/Knn.php new file mode 100644 index 00000000..07b6489a --- /dev/null +++ b/src/Knn/Knn.php @@ -0,0 +1,222 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Knn; + +use ONGR\ElasticsearchDSL\BuilderInterface; +use ONGR\ElasticsearchDSL\FieldAwareTrait; + +class Knn implements BuilderInterface +{ + use FieldAwareTrait; + + /** + * @var string Name of the vector field to search; emitted as `field`. NOTE(review): FieldAwareTrait is used above - confirm it does not already declare $field, getField() and setField(). + */ + private $field; + + /** + * @var array Query vector; emitted as `query_vector`. + */ + private $queryVector; + + /** + * @var int Number of nearest neighbours to return; emitted as `k`. + */ + private $k; + + /** + * @var int Number of candidates to consider per shard; emitted as `num_candidates`. + */ + private $numCandidates; + + /** + * @var float|null Optional boost; emitted as `boost` unless null. + */ + private $boost; + + /** + * @var float|null Optional minimum similarity; emitted as `similarity` unless null. + */ + private $similarity = null; + + /** + * @var BuilderInterface|null Optional filter query; its toArray() output is emitted as `filter` unless null. + */ + private $filter = null; + + + /** + * Knn constructor. 
+ * @param string $field + * @param array $queryVector + * @param int $k + * @param int $numCandidates + */ + public function __construct( + string $field, + array $queryVector, + int $k, + int $numCandidates + ) { + $this->setField($field); + $this->setQueryVector($queryVector); + $this->setK($k); + $this->setNumCandidates($numCandidates); + } + + /** + * @return string + */ + public function getField(): string + { + return $this->field; + } + + /** + * @param string $field + */ + public function setField(string $field): void + { + $this->field = $field; + } + + /** + * @return array + */ + public function getQueryVector(): array + { + return $this->queryVector; + } + + /** + * @param array $queryVector + */ + public function setQueryVector(array $queryVector): void + { + $this->queryVector = $queryVector; + } + + /** + * @return int + */ + public function getK(): int + { + return $this->k; + } + + /** + * @param int $k + */ + public function setK(int $k): void + { + $this->k = $k; + } + + /** + * @return int + */ + public function getNumCandidates(): int + { + return $this->numCandidates; + } + + /** + * @param int $numCandidates + */ + public function setNumCandidates(int $numCandidates): void + { + $this->numCandidates = $numCandidates; + } + + /** + * @return float|null + */ + public function getSimilarity(): ?float + { + return $this->similarity; + } + + /** + * @param float $similarity + */ + public function setSimilarity(float $similarity): void + { + $this->similarity = $similarity; + } + + /** + * @return float|null + */ + public function getBoost(): ?float + { + return $this->boost; + } + + /** + * @param float $boost + */ + public function setBoost(float $boost): void + { + $this->boost = $boost; + } + + /** + * @return BuilderInterface|null + */ + public function getFilter(): ?BuilderInterface + { + return $this->filter; + } + + /** + * @param BuilderInterface $filter + */ + public function setFilter(BuilderInterface $filter): void + { + $this->filter 
= $filter; + } + + /** + * {@inheritdoc} + */ + public function getType() + { + return 'knn'; + } + + /** + * {@inheritdoc} Always emits field, query_vector, k and num_candidates; similarity, boost and filter are added whenever they are not null, so an explicit 0.0 is kept. + */ + public function toArray() + { + $output = [ + 'field' => $this->getField(), + 'query_vector' => $this->getQueryVector(), + 'k' => $this->getK(), + 'num_candidates' => $this->getNumCandidates(), + ]; + + if ($this->getSimilarity() !== null) { + $output['similarity'] = $this->getSimilarity(); + } + + if ($this->getBoost() !== null) { + $output['boost'] = $this->getBoost(); + } + + if ($this->getFilter() !== null) { + $output['filter'] = $this->getFilter()->toArray(); + } + + return $output; + } +} diff --git a/src/Search.php b/src/Search.php index 91c68fd7..1b6fe9fb 100644 --- a/src/Search.php +++ b/src/Search.php @@ -19,6 +19,7 @@ use ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint; +use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint; use ONGR\ElasticsearchDSL\SearchEndpoint\SearchEndpointFactory; @@ -229,6 +230,22 @@ public function addQuery(BuilderInterface $query, $boolType = BoolQuery::MUST, $ return $this; } + + /** + * Adds Knn to the search. + * + * @param BuilderInterface $query Knn builder; the knn endpoint throws \LogicException for any other builder. + * + * @return $this + */ + public function addKnn(BuilderInterface $query) + { + $endpoint = $this->getEndpoint(KnnEndpoint::NAME); + $endpoint->add($query); + + return $this; + } + /** * Returns endpoint instance. 
* diff --git a/src/SearchEndpoint/KnnEndpoint.php b/src/SearchEndpoint/KnnEndpoint.php new file mode 100644 index 00000000..0cc55676 --- /dev/null +++ b/src/SearchEndpoint/KnnEndpoint.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace ONGR\ElasticsearchDSL\SearchEndpoint; + +use ONGR\ElasticsearchDSL\BuilderInterface; +use ONGR\ElasticsearchDSL\Knn\Knn; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * Search knn dsl endpoint. Only Knn builders are accepted; add() throws \LogicException for any other builder. + */ +class KnnEndpoint extends AbstractSearchEndpoint +{ + /** + * Endpoint name + */ + const NAME = 'knn'; + + public function add(BuilderInterface $builder, $key = null) + { + if ($builder instanceof Knn) { + return parent::add($builder, $key); + } + + throw new \LogicException('Add Knn builder instead!'); + } + + /** + * {@inheritdoc} Returns the toArray() output of the single Knn when exactly one was added, a list of such arrays when several were added, and an empty array when none were added. + */ + public function normalize( + NormalizerInterface $normalizer, + $format = null, + array $context = [] + ): array|string|int|float|bool { + $knns = $this->getAll(); + if (count($knns) === 1) { + /** @var Knn $knn */ + $knn = array_values($knns)[0]; + return $knn->toArray(); + } + + if (count($knns) > 1) { + $output = []; + /** @var Knn $knn */ + foreach ($knns as $knn) { + $output[] = $knn->toArray(); + } + return $output; + } + + return []; + } +} diff --git a/src/SearchEndpoint/SearchEndpointFactory.php b/src/SearchEndpoint/SearchEndpointFactory.php index 17e6838b..fb9149c2 100644 --- a/src/SearchEndpoint/SearchEndpointFactory.php +++ b/src/SearchEndpoint/SearchEndpointFactory.php @@ -20,13 +20,14 @@ class SearchEndpointFactory * @var array Holds namespaces for endpoints. 
*/ private static $endpoints = [ - 'query' => 'ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint', - 'post_filter' => 'ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint', - 'sort' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SortEndpoint', - 'highlight' => 'ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint', - 'aggregations' => 'ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint', - 'suggest' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SuggestEndpoint', - 'inner_hits' => 'ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint', + 'query' => QueryEndpoint::class, + 'knn' => KnnEndpoint::class, + 'post_filter' => PostFilterEndpoint::class, + 'sort' => SortEndpoint::class, + 'highlight' => HighlightEndpoint::class, + 'aggregations' => AggregationsEndpoint::class, + 'suggest' => SuggestEndpoint::class, + 'inner_hits' => InnerHitsEndpoint::class, ]; /** diff --git a/src/SearchEndpoint/SuggestEndpoint.php b/src/SearchEndpoint/SuggestEndpoint.php index 9c9d8367..c996b2f7 100644 --- a/src/SearchEndpoint/SuggestEndpoint.php +++ b/src/SearchEndpoint/SuggestEndpoint.php @@ -11,7 +11,7 @@ namespace ONGR\ElasticsearchDSL\SearchEndpoint; -use ONGR\ElasticsearchDSL\Suggest\TermSuggest; +use ONGR\ElasticsearchDSL\Suggest\Suggest; use Symfony\Component\Serializer\Normalizer\NormalizerInterface; /** @@ -34,7 +34,7 @@ public function normalize( ): array|string|int|float|bool { $output = []; if (count($this->getAll()) > 0) { - /** @var TermSuggest $suggest */ + /** @var Suggest $suggest */ foreach ($this->getAll() as $suggest) { $output = array_merge($output, $suggest->toArray()); } diff --git a/src/Suggest/Suggest.php b/src/Suggest/Suggest.php index 1738d84d..6a639c48 100644 --- a/src/Suggest/Suggest.php +++ b/src/Suggest/Suggest.php @@ -13,7 +13,6 @@ use ONGR\ElasticsearchDSL\NamedBuilderInterface; use ONGR\ElasticsearchDSL\ParametersTrait; -use Symfony\Component\Serializer\Exception\InvalidArgumentException; class Suggest implements NamedBuilderInterface { diff --git 
a/tests/Functional/AbstractElasticsearchTestCase.php b/tests/Functional/AbstractElasticsearchTestCase.php index 6fa1a107..67b175e7 100644 --- a/tests/Functional/AbstractElasticsearchTestCase.php +++ b/tests/Functional/AbstractElasticsearchTestCase.php @@ -21,7 +21,7 @@ abstract class AbstractElasticsearchTestCase extends TestCase /** * Test index name in the elasticsearch. */ - const INDEX_NAME = 'elasticsaerch-dsl-test'; + const INDEX_NAME = 'elasticsearch-dsl-test'; /** * @var Client @@ -35,17 +35,18 @@ protected function setUp(): void { parent::setUp(); - $this->client = ClientBuilder::create()->build(); + $this->client = ClientBuilder::create() + ->build(); $this->deleteIndex(); - $this->client->indices()->create( - array_filter( - [ - 'index' => self::INDEX_NAME, - 'mapping' => $this->getMapping() - ] - ) - ); + $params = [ + 'index' => self::INDEX_NAME + ]; + if ($this->getMapping()) { + $params['body']['mappings'] = $this->getMapping(); + } + + $response = $this->client->indices()->create($params); $bulkBody = []; diff --git a/tests/Functional/Knn/KnnTest.php b/tests/Functional/Knn/KnnTest.php new file mode 100644 index 00000000..9aac8566 --- /dev/null +++ b/tests/Functional/Knn/KnnTest.php @@ -0,0 +1,102 @@ + + */ + +namespace ONGR\ElasticsearchDSL\Tests\Functional\Knn; + +use Composer\InstalledVersions; +use Elastic\Elasticsearch\Client; +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\TermLevel\TermQuery; +use ONGR\ElasticsearchDSL\Search; +use ONGR\ElasticsearchDSL\Tests\Functional\AbstractElasticsearchTestCase; + +class KnnTest extends AbstractElasticsearchTestCase +{ + /** + * {@inheritdoc} Three documents, each with a numeric label and a three-element vector_field. 
+ */ + protected function getDataArray(): array + { + return [ + 'knn_data' => [ + 'doc_1' => [ + 'label' => 1, + 'vector_field' => [1, 2, 3], + ], + 'doc_2' => [ + 'label' => 1, + 'vector_field' => [1, 2, 4], + ], + 'doc_3' => [ + 'label' => 2, + 'vector_field' => [1, 2, 30], + ], + ] + ]; + } + + protected function getMapping(): array + { + 
return [ + 'properties' => [ + 'label' => [ + 'type' => 'long' + ], + 'vector_field' => [ + 'type' => 'dense_vector', + 'similarity' => 'cosine', + 'index' => true, + 'dims' => 3 + ] + ] + ]; + } + + /** + * Tests a plain knn search with k = 1; doc_1, whose vector equals the query vector, is expected. + */ + public function testKnnSearch(): void + { + $knn = new Knn('vector_field', [1, 2, 3], 1, 1); + + $search = new Search(); + $search->addKnn($knn); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_1', $results['hits']['hits'][0]['_id']); + } + + /** + * Tests a knn search with a term filter on label = 2; doc_3, the only document with that label, is expected. + */ + public function testKnnSearchWithFilter(): void + { + $knn = new Knn('vector_field', [1, 2, 3], 1, 1); + $knn->setFilter(new TermQuery('label', 2)); + + $search = new Search(); + $search->addKnn($knn); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']); + } + + /** + * Tests a knn search with a term filter and a boost of 0.5; doc_3 is expected. NOTE(review): despite the method name only one Knn builder is added to the search. + */ + public function testMultipleKnnSearchWithBoost(): void + { + $knn1 = new Knn('vector_field', [1, 2, 3], 1, 1); + $knn1->setFilter(new TermQuery('label', 2)); + $knn1->setBoost(0.5); + + $search = new Search(); + $search->addKnn($knn1); + $results = $this->executeSearch($search, true); + $this->assertCount(1, $results['hits']['hits']); + $this->assertEquals('doc_3', $results['hits']['hits'][0]['_id']); + } +} diff --git a/tests/Unit/Knn/KnnTest.php b/tests/Unit/Knn/KnnTest.php new file mode 100644 index 00000000..3e479bcc --- /dev/null +++ b/tests/Unit/Knn/KnnTest.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Tests\Unit\Knn; + +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\MatchAllQuery; + +class KnnTest extends \PHPUnit\Framework\TestCase +{ + /** + * Tests toArray() with only the constructor arguments set. 
+ */ + public function testToArray(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100 + ], $query->toArray()); + } + + /** + * Tests toArray() with a filter set. + */ + public function testToArrayWithFilter(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $query->setFilter(new MatchAllQuery()); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100, + 'filter' => [ + 'match_all' => new \stdClass() + ] + ], $query->toArray()); + } + + /** + * Tests toArray() with a similarity set. + */ + public function testToArrayWithSimilarity(): void + { + $query = new Knn('field', [1, 2, 3], 10, 100); + $query->setSimilarity(1); + $this->assertEquals([ + 'field' => 'field', + 'query_vector' => [1, 2, 3], + 'k' => 10, + 'num_candidates' => 100, + 'similarity' => 1 + ], $query->toArray()); + } +} diff --git a/tests/Unit/SearchEndpoint/KnnEndpointTest.php b/tests/Unit/SearchEndpoint/KnnEndpointTest.php new file mode 100644 index 00000000..ad1e447e --- /dev/null +++ b/tests/Unit/SearchEndpoint/KnnEndpointTest.php @@ -0,0 +1,77 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace ONGR\ElasticsearchDSL\Tests\Unit\SearchEndpoint; + +use ONGR\ElasticsearchDSL\Knn\Knn; +use ONGR\ElasticsearchDSL\Query\MatchAllQuery; +use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint; +use PHPUnit\Framework\TestCase; + +/** + * Class KnnEndpointTest. + */ +class KnnEndpointTest extends TestCase +{ + /** + * Tests constructor. + */ + public function testItCanBeInstantiated(): void + { + $this->assertInstanceOf( + KnnEndpoint::class, + new KnnEndpoint() + ); + } + + /** + * Tests if endpoint returns builders. 
+ */ + public function testEndpointGetter(): void + { + $knn = new Knn('acme', [1, 2, 3], 10, 100); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn, 'knn'); + $builders = $endpoint->getAll(); + + $this->assertCount(1, $builders); + $this->assertSame($knn, $builders['knn']); + } + + /** + * Tests if endpoint throws a LogicException when a builder that is not a Knn is added. + */ + public function testEndpointException(): void + { + $this->expectException(\LogicException::class); + $this->expectExceptionMessage('Add Knn builder instead!'); + $knn = new MatchAllQuery(); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn); + } + + /** + * Tests if endpoint keeps several Knn builders and returns each under the key it was added with. + */ + public function testEndpointWithMultipleVector(): void + { + $knn1 = new Knn('acme', [1, 2, 3], 10, 100); + $knn2 = new Knn('acme', [1, 2, 4], 10, 100); + $endpoint = new KnnEndpoint(); + $endpoint->add($knn1, 'knn1'); + $endpoint->add($knn2, 'knn2'); + $builders = $endpoint->getAll(); + + $this->assertCount(2, $builders); + $this->assertSame($knn1, $builders['knn1']); + $this->assertSame($knn2, $builders['knn2']); + } +} diff --git a/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php b/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php index c928f19c..1a3725a9 100644 --- a/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php +++ b/tests/Unit/SearchEndpoint/SearchEndpointFactoryTest.php @@ -34,8 +34,8 @@ public function testGet() */ public function testFactory() { - $endpoinnt = SearchEndpointFactory::get(AggregationsEndpoint::NAME); + $endpoint = SearchEndpointFactory::get(AggregationsEndpoint::NAME); - $this->assertInstanceOf(SearchEndpointInterface::class, $endpoinnt); + $this->assertInstanceOf(SearchEndpointInterface::class, $endpoint); } }