Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knn support #4

Merged
merged 5 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:

services:
elasticsearch:
image: elasticsearch:8.0.0
image: elasticsearch:8.4.0
ports:
- 9200:9200
env:
Expand Down
222 changes: 222 additions & 0 deletions src/Knn/Knn.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
<?php

/*
* This file is part of the ONGR package.
*
* (c) NFQ Technologies UAB <info@nfq.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace ONGR\ElasticsearchDSL\Knn;

use ONGR\ElasticsearchDSL\BuilderInterface;
use ONGR\ElasticsearchDSL\FieldAwareTrait;

class Knn implements BuilderInterface
{
use FieldAwareTrait;

/**
* @var string
*/
private $field;

/**
* @var array
*/
private $queryVector;

/**
* @var int
*/
private $k;

/**
* @var int
*/
private $numCandidates;

/**
* @var float|null
*/
private $boost;

/**
* @var float
*/
private $similarity = null;

/**
* @var BuilderInterface
*/
private $filter = null;


/**
* TermSuggest constructor.
* @param string $field
* @param array $queryVector
* @param int $k
* @param int $numCandidates
*/
public function __construct(
string $field,
array $queryVector,
int $k,
int $numCandidates
) {
$this->setField($field);
$this->setQueryVector($queryVector);
$this->setK($k);
$this->setNumCandidates($numCandidates);
}

/**
* @return string
*/
public function getField(): string
{
return $this->field;
}

/**
* @param string $field
*/
public function setField(string $field): void
{
$this->field = $field;
}

/**
* @return array
*/
public function getQueryVector(): array
{
return $this->queryVector;
}

/**
* @param array $queryVector
*/
public function setQueryVector(array $queryVector): void
{
$this->queryVector = $queryVector;
}

/**
* @return int
*/
public function getK(): int
{
return $this->k;
}

/**
* @param int $k
*/
public function setK(int $k): void
{
$this->k = $k;
}

/**
* @return int
*/
public function getNumCandidates(): int
{
return $this->numCandidates;
}

/**
* @param int $numCandidates
*/
public function setNumCandidates(int $numCandidates): void
{
$this->numCandidates = $numCandidates;
}

/**
* @return float|null
*/
public function getSimilarity(): ?float
{
return $this->similarity;
}

/**
* @param float $similarity
*/
public function setSimilarity(float $similarity): void
{
$this->similarity = $similarity;
}

/**
* @return float|null
*/
public function getBoost(): ?float
{
return $this->boost;
}

/**
* @param float $boost
*/
public function setBoost(float $boost): void
{
$this->boost = $boost;
}

/**
* @return BuilderInterface|null
*/
public function getFilter(): ?BuilderInterface
{
return $this->filter;
}

/**
* @param BuilderInterface $filter
*/
public function setFilter(BuilderInterface $filter): void
{
$this->filter = $filter;
}

/**
* {@inheritdoc}
*/
public function getType()
{
return 'knn';
}

/**
* {@inheritdoc}
*/
public function toArray()
{
$output = [
'field' => $this->getField(),
'query_vector' => $this->getQueryVector(),
'k' => $this->getK(),
'num_candidates' => $this->getNumCandidates(),
];

if ($this->getSimilarity()) {
$output['similarity'] = $this->getSimilarity();
}

if ($this->getBoost()) {
$output['boost'] = $this->getBoost();
}

if ($this->getFilter()) {
$output['filter'] = $this->getFilter()->toArray();
}

return $output;
}
}
17 changes: 17 additions & 0 deletions src/Search.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\KnnEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint;
use ONGR\ElasticsearchDSL\SearchEndpoint\SearchEndpointFactory;
Expand Down Expand Up @@ -229,6 +230,22 @@ public function addQuery(BuilderInterface $query, $boolType = BoolQuery::MUST, $
return $this;
}


/**
* Adds Knn to the search.
*
* @param BuilderInterface $query
*
* @return $this
*/
public function addKnn(BuilderInterface $query)
{
$endpoint = $this->getEndpoint(KnnEndpoint::NAME);
$endpoint->add($query);

return $this;
}

/**
* Returns endpoint instance.
*
Expand Down
63 changes: 63 additions & 0 deletions src/SearchEndpoint/KnnEndpoint.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
<?php

/*
* This file is part of the ONGR package.
*
* (c) NFQ Technologies UAB <info@nfq.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace ONGR\ElasticsearchDSL\SearchEndpoint;

use ONGR\ElasticsearchDSL\BuilderInterface;
use ONGR\ElasticsearchDSL\Knn\Knn;
use Symfony\Component\Serializer\Normalizer\NormalizerInterface;

/**
* Search suggest dsl endpoint.
*/
class KnnEndpoint extends AbstractSearchEndpoint
{
/**
* Endpoint name
*/
const NAME = 'knn';

public function add(BuilderInterface $builder, $key = null)
{
if ($builder instanceof Knn) {
return parent::add($builder, $key);
}

throw new \LogicException('Add Knn builder instead!');
}

/**
* {@inheritdoc}
*/
public function normalize(
NormalizerInterface $normalizer,
$format = null,
array $context = []
): array|string|int|float|bool {
$knns = $this->getAll();
if (count($knns) === 1) {
/** @var Knn $knn */
$knn = array_values($knns)[0];
return $knn->toArray();
}

if (count($knns) > 1) {
$output = [];
/** @var Knn $knn */
foreach ($knns as $knn) {
$output[] = $knn->toArray();
}
return $output;
}

return [];
}
}
15 changes: 8 additions & 7 deletions src/SearchEndpoint/SearchEndpointFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class SearchEndpointFactory
* @var array Holds namespaces for endpoints.
*/
private static $endpoints = [
'query' => 'ONGR\ElasticsearchDSL\SearchEndpoint\QueryEndpoint',
'post_filter' => 'ONGR\ElasticsearchDSL\SearchEndpoint\PostFilterEndpoint',
'sort' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SortEndpoint',
'highlight' => 'ONGR\ElasticsearchDSL\SearchEndpoint\HighlightEndpoint',
'aggregations' => 'ONGR\ElasticsearchDSL\SearchEndpoint\AggregationsEndpoint',
'suggest' => 'ONGR\ElasticsearchDSL\SearchEndpoint\SuggestEndpoint',
'inner_hits' => 'ONGR\ElasticsearchDSL\SearchEndpoint\InnerHitsEndpoint',
'query' => QueryEndpoint::class,
'knn' => KnnEndpoint::class,
'post_filter' => PostFilterEndpoint::class,
'sort' => SortEndpoint::class,
'highlight' => HighlightEndpoint::class,
'aggregations' => AggregationsEndpoint::class,
'suggest' => SuggestEndpoint::class,
'inner_hits' => InnerHitsEndpoint::class,
];

/**
Expand Down
4 changes: 2 additions & 2 deletions src/SearchEndpoint/SuggestEndpoint.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace ONGR\ElasticsearchDSL\SearchEndpoint;

use ONGR\ElasticsearchDSL\Suggest\TermSuggest;
use ONGR\ElasticsearchDSL\Suggest\Suggest;
use Symfony\Component\Serializer\Normalizer\NormalizerInterface;

/**
Expand All @@ -34,7 +34,7 @@ public function normalize(
): array|string|int|float|bool {
$output = [];
if (count($this->getAll()) > 0) {
/** @var TermSuggest $suggest */
/** @var Suggest $suggest */
foreach ($this->getAll() as $suggest) {
$output = array_merge($output, $suggest->toArray());
}
Expand Down
1 change: 0 additions & 1 deletion src/Suggest/Suggest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

use ONGR\ElasticsearchDSL\NamedBuilderInterface;
use ONGR\ElasticsearchDSL\ParametersTrait;
use Symfony\Component\Serializer\Exception\InvalidArgumentException;

class Suggest implements NamedBuilderInterface
{
Expand Down
Loading
Loading