Skip to content

Commit

Permalink
feat: bigframes.bigquery.vector_search supports use_brute_force a…
Browse files Browse the repository at this point in the history
…nd `fraction_lists_to_search` parameters (#1158)

* feat: `bigframes.bigquery.vector_search` supports `use_brute_force` and `fraction_lists_to_search` parameters

* fix f-string on lower python versions

---------

Co-authored-by: Chelsea Lin <chelsealin@google.com>
  • Loading branch information
tswast and chelsea-lin authored Nov 21, 2024
1 parent 5b355ef commit 131edc3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 103 deletions.
53 changes: 26 additions & 27 deletions bigframes/bigquery/_operations/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import typing
from typing import Collection, Literal, Mapping, Optional, Union

import bigframes_vendored.constants as constants
import google.cloud.bigquery as bigquery

import bigframes.core.sql
Expand Down Expand Up @@ -96,10 +95,10 @@ def vector_search(
query: Union[dataframe.DataFrame, series.Series],
*,
query_column_to_search: Optional[str] = None,
top_k: Optional[int] = 10,
distance_type: Literal["euclidean", "cosine"] = "euclidean",
top_k: Optional[int] = None,
distance_type: Optional[Literal["euclidean", "cosine", "dot_product"]] = None,
fraction_lists_to_search: Optional[float] = None,
use_brute_force: bool = False,
use_brute_force: Optional[bool] = None,
) -> dataframe.DataFrame:
"""
Conduct vector search which searches embeddings to find semantically similar entities.
Expand Down Expand Up @@ -141,7 +140,8 @@ def vector_search(
... base_table="bigframes-dev.bigframes_tests_sys.base_table",
... column_to_search="my_embedding",
... query=search_query,
... top_k=2)
... top_k=2,
... use_brute_force=True)
embedding id my_embedding distance
dog [1. 2.] 1 [1. 2.] 0.0
cat [3. 5.2] 5 [5. 5.4] 2.009975
Expand Down Expand Up @@ -185,17 +185,18 @@ def vector_search(
find nearest neighbors. The column must have a type of ``ARRAY<FLOAT64>``. All elements in
the array must be non-NULL and all values in the column must have the same array dimensions
as the values in the ``column_to_search`` column. Can only be set when query is a DataFrame.
top_k (int, default 10):
top_k (int):
Sepecifies the number of nearest neighbors to return. Default to 10.
distance_type (str, defalt "euclidean"):
Specifies the type of metric to use to compute the distance between two vectors.
Possible values are "euclidean" and "cosine". Default to "euclidean".
Possible values are "euclidean", "cosine" and "dot_product".
Default to "euclidean".
fraction_lists_to_search (float, range in [0.0, 1.0]):
Specifies the percentage of lists to search. Specifying a higher percentage leads to
higher recall and slower performance, and the converse is true when specifying a lower
percentage. It is only used when a vector index is also used. You can only specify
``fraction_lists_to_search`` when ``use_brute_force`` is set to False.
use_brute_force (bool, default False):
use_brute_force (bool):
Determines whether to use brute force search by skipping the vector index if one is available.
Default to False.
Expand All @@ -204,37 +205,35 @@ def vector_search(
"""
import bigframes.series

if not fraction_lists_to_search and use_brute_force is True:
raise ValueError(
"You can't specify fraction_lists_to_search when use_brute_force is set to True."
)
if (
isinstance(query, bigframes.series.Series)
and query_column_to_search is not None
):
raise ValueError(
"You can't specify query_column_to_search when query is a Series."
)
# TODO(ashleyxu): Support options in vector search. b/344019989
if fraction_lists_to_search is not None or use_brute_force is True:
raise NotImplementedError(
f"fraction_lists_to_search and use_brute_force is not supported. {constants.FEEDBACK_LINK}"
)
options = {
"base_table": base_table,
"column_to_search": column_to_search,
"query_column_to_search": query_column_to_search,
"distance_type": distance_type,
"top_k": top_k,
"fraction_lists_to_search": fraction_lists_to_search,
"use_brute_force": use_brute_force,
}

# Only populate options if not set to the default value.
# This avoids accidentally setting options that are mutually exclusive.
options = None
if fraction_lists_to_search is not None:
options = {} if options is None else options
options["fraction_lists_to_search"] = fraction_lists_to_search
if use_brute_force is not None:
options = {} if options is None else options
options["use_brute_force"] = use_brute_force

(query,) = utils.convert_to_dataframe(query)
sql_string, index_col_ids, index_labels = query._to_sql_query(include_index=True)

sql = bigframes.core.sql.create_vector_search_sql(
sql_string=sql_string, options=options # type: ignore
sql_string=sql_string,
base_table=base_table,
column_to_search=column_to_search,
query_column_to_search=query_column_to_search,
top_k=top_k,
distance_type=distance_type,
options=options,
)
if index_col_ids is not None:
df = query._session.read_gbq(sql, index_col=index_col_ids)
Expand Down
63 changes: 32 additions & 31 deletions bigframes/core/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
"""

import datetime
import json
import math
from typing import cast, Collection, Iterable, Mapping, TYPE_CHECKING, Union
from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union

import bigframes.core.compile.googlesql as googlesql

Expand Down Expand Up @@ -157,43 +158,43 @@ def create_vector_index_ddl(

def create_vector_search_sql(
sql_string: str,
options: Mapping[str, Union[str | int | bool | float]] = {},
*,
base_table: str,
column_to_search: str,
query_column_to_search: Optional[str] = None,
top_k: Optional[int] = None,
distance_type: Optional[str] = None,
options: Optional[Mapping[str, Union[str | int | bool | float]]] = None,
) -> str:
"""Encode the VECTOR SEARCH statement for BigQuery Vector Search."""

base_table = options["base_table"]
column_to_search = options["column_to_search"]
distance_type = options["distance_type"]
top_k = options["top_k"]
query_column_to_search = options.get("query_column_to_search", None)
vector_search_args = [
f"TABLE {googlesql.identifier(cast(str, base_table))}",
f"{simple_literal(column_to_search)}",
f"({sql_string})",
]

if query_column_to_search is not None:
query_str = f"""
SELECT
query.*,
base.*,
distance,
FROM VECTOR_SEARCH(
TABLE {googlesql.identifier(cast(str, base_table))},
{simple_literal(column_to_search)},
({sql_string}),
{simple_literal(query_column_to_search)},
distance_type => {simple_literal(distance_type)},
top_k => {simple_literal(top_k)}
)
"""
else:
query_str = f"""
vector_search_args.append(
f"query_column_to_search => {simple_literal(query_column_to_search)}"
)

if top_k is not None:
vector_search_args.append(f"top_k=> {simple_literal(top_k)}")

if distance_type is not None:
vector_search_args.append(f"distance_type => {simple_literal(distance_type)}")

if options is not None:
vector_search_args.append(
f"options => {simple_literal(json.dumps(options, indent=None))}"
)

args_str = ",\n".join(vector_search_args)
return f"""
SELECT
query.*,
base.*,
distance,
FROM VECTOR_SEARCH(
TABLE {googlesql.identifier(cast(str, base_table))},
{simple_literal(column_to_search)},
({sql_string}),
distance_type => {simple_literal(distance_type)},
top_k => {simple_literal(top_k)}
)
FROM VECTOR_SEARCH({args_str})
"""
return query_str
79 changes: 34 additions & 45 deletions tests/unit/core/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,51 @@


def test_create_vector_search_sql_simple():
sql_string = "SELECT embedding FROM my_embeddings_table WHERE id = 1"
options = {
"base_table": "my_base_table",
"column_to_search": "my_embedding_column",
"distance_type": "COSINE",
"top_k": 10,
"use_brute_force": False,
}

expected_query = f"""
result_query = sql.create_vector_search_sql(
sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1",
base_table="my_base_table",
column_to_search="my_embedding_column",
)
assert (
result_query
== """
SELECT
query.*,
base.*,
distance,
FROM VECTOR_SEARCH(
TABLE `my_base_table`,
'my_embedding_column',
({sql_string}),
distance_type => 'COSINE',
top_k => 10
)
FROM VECTOR_SEARCH(TABLE `my_base_table`,
'my_embedding_column',
(SELECT embedding FROM my_embeddings_table WHERE id = 1))
"""

result_query = sql.create_vector_search_sql(
sql_string, options # type:ignore
)
assert result_query == expected_query


def test_create_vector_search_sql_query_column_to_search():
sql_string = "SELECT embedding FROM my_embeddings_table WHERE id = 1"
options = {
"base_table": "my_base_table",
"column_to_search": "my_embedding_column",
"distance_type": "COSINE",
"top_k": 10,
"query_column_to_search": "new_embedding_column",
"use_brute_force": False,
}

expected_query = f"""
def test_create_vector_search_sql_all_named_parameters():
result_query = sql.create_vector_search_sql(
sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1",
base_table="my_base_table",
column_to_search="my_embedding_column",
query_column_to_search="another_embedding_column",
top_k=10,
distance_type="cosine",
options={
"fraction_lists_to_search": 0.1,
"use_brute_force": False,
},
)
assert (
result_query
== """
SELECT
query.*,
base.*,
distance,
FROM VECTOR_SEARCH(
TABLE `my_base_table`,
'my_embedding_column',
({sql_string}),
'new_embedding_column',
distance_type => 'COSINE',
top_k => 10
)
FROM VECTOR_SEARCH(TABLE `my_base_table`,
'my_embedding_column',
(SELECT embedding FROM my_embeddings_table WHERE id = 1),
query_column_to_search => 'another_embedding_column',
top_k=> 10,
distance_type => 'cosine',
options => '{\\"fraction_lists_to_search\\": 0.1, \\"use_brute_force\\": false}')
"""

result_query = sql.create_vector_search_sql(
sql_string, options # type:ignore
)
assert result_query == expected_query

0 comments on commit 131edc3

Please sign in to comment.