Skip to content

Commit

Permalink
feat: LIKE operator (#604)
Browse files Browse the repository at this point in the history
* initial changes

* lint fix

* add test

* add binder check

* regex

* nit

* checkpoint

* checkpoint

---------

Co-authored-by: Kaushik Ravichandran <kravicha3@ada-01.cc.gatech.edu>
Co-authored-by: jarulraj <arulraj@gatech.edu>
  • Loading branch information
3 people authored Mar 29, 2023
1 parent 6b73d49 commit 4a84efe
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 4 deletions.
8 changes: 7 additions & 1 deletion eva/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, List

from eva.catalog.catalog_type import TableType
from eva.catalog.catalog_utils import is_video_table
from eva.catalog.catalog_utils import is_string_col, is_video_table
from eva.catalog.sql_config import IDENTIFIER_COLUMN

if TYPE_CHECKING:
Expand Down Expand Up @@ -107,3 +107,9 @@ def check_table_object_is_video(table_ref: TableRef) -> None:
if not is_video_table(table_ref.table.table_obj):
err_msg = "GROUP BY only supported for video tables"
raise BinderError(err_msg)


def check_column_name_is_string(col_ref) -> None:
if not is_string_col(col_ref.col_object):
err_msg = "LIKE only supported for string columns"
raise BinderError(err_msg)
6 changes: 5 additions & 1 deletion eva/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from eva.binder.binder_utils import (
BinderError,
bind_table_info,
check_column_name_is_string,
check_groupby_pattern,
check_table_object_is_video,
extend_star,
)
from eva.binder.statement_binder_context import StatementBinderContext
from eva.catalog.catalog_manager import CatalogManager
from eva.catalog.catalog_type import IndexType, NdArrayType, TableType
from eva.expression.abstract_expression import AbstractExpression
from eva.expression.abstract_expression import AbstractExpression, ExpressionType
from eva.expression.function_expression import FunctionExpression
from eva.expression.tuple_value_expression import TupleValueExpression
from eva.parser.alias import Alias
Expand Down Expand Up @@ -120,6 +121,9 @@ def _bind_select_statement(self, node: SelectStatement):
self.bind(node.from_table)
if node.where_clause:
self.bind(node.where_clause)
if node.where_clause.etype == ExpressionType.COMPARE_LIKE:
check_column_name_is_string(node.where_clause.children[0])

if node.target_list:
# SELECT * support
if (
Expand Down
4 changes: 4 additions & 0 deletions eva/catalog/catalog_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def is_video_table(table: TableCatalogEntry):
return table.table_type == TableType.VIDEO_DATA


def is_string_col(col: ColumnCatalogEntry):
return col.type == ColumnType.TEXT or col.array_type == NdArrayType.STR


def get_video_table_column_definitions() -> List[ColumnDefinition]:
"""
name: video path
Expand Down
3 changes: 3 additions & 0 deletions eva/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class ExpressionType(IntEnum):
COMPARE_NEQ = auto()
COMPARE_CONTAINS = auto()
COMPARE_IS_CONTAINED = auto()
COMPARE_LIKE = auto()

# Logical operators
LOGICAL_AND = auto()
LOGICAL_OR = auto()
LOGICAL_NOT = auto()

# Arithmetic operators
ARITHMETIC_ADD = auto()
ARITHMETIC_SUBTRACT = auto()
Expand Down
3 changes: 3 additions & 0 deletions eva/expression/comparison_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def evaluate(self, *args, **kwargs):
ExpressionType.COMPARE_NEQ,
ExpressionType.COMPARE_CONTAINS,
ExpressionType.COMPARE_IS_CONTAINED,
ExpressionType.COMPARE_LIKE,
], f"Expression type not supported {self.etype}"

if self.etype == ExpressionType.COMPARE_EQUAL:
Expand All @@ -73,6 +74,8 @@ def evaluate(self, *args, **kwargs):
return Batch.compare_contains(lbatch, rbatch)
elif self.etype == ExpressionType.COMPARE_IS_CONTAINED:
return Batch.compare_is_contained(lbatch, rbatch)
elif self.etype == ExpressionType.COMPARE_LIKE:
return Batch.compare_like(lbatch, rbatch)

def get_symbol(self) -> str:
if self.etype == ExpressionType.COMPARE_EQUAL:
Expand Down
6 changes: 6 additions & 0 deletions eva/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ def compare_is_contained(cls, batch1: Batch, batch2: Batch) -> None:
)
)

@classmethod
def compare_like(cls, batch1: Batch, batch2: Batch) -> None:
col = batch1._frames.iloc[:, 0]
regex = batch2._frames.iloc[:, 0][0]
return cls(pd.DataFrame(col.astype("str").str.match(pat=regex)))

def __str__(self) -> str:
with pd.option_context(
"display.pprint_nest_depth", 1, "display.max_colwidth", 100
Expand Down
4 changes: 2 additions & 2 deletions eva/parser/eva.lark
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ predicate: predicate NOT? IN "(" (select_statement | expressions) ")" ->in_pred
| predicate IS null_notnull ->is_null_predicate
| predicate comparison_operator predicate -> binary_comparison_predicate
| predicate comparison_operator (ALL | ANY | SOME) "(" select_statement ")" ->subquery_comparison_predicate
| predicate NOT? LIKE predicate (STRING_LITERAL)?
| assign_var ->expression_atom_predicate
| expression_atom

Expand All @@ -287,7 +286,7 @@ expression_atom.2: constant ->constant_expression_atom

unary_operator: EXCLAMATION_SYMBOL | BIT_NOT_OP | PLUS | MINUS | NOT

comparison_operator: EQUAL_SYMBOL | GREATER_SYMBOL | LESS_SYMBOL | GREATER_OR_EQUAL_SYMBOL | LESS_OR_EQUAL_SYMBOL | NOT_EQUAL_SYMBOL | CONTAINS_SYMBOL | CONTAINED_IN_SYMBOL
comparison_operator: EQUAL_SYMBOL | GREATER_SYMBOL | LESS_SYMBOL | GREATER_OR_EQUAL_SYMBOL | LESS_OR_EQUAL_SYMBOL | NOT_EQUAL_SYMBOL | CONTAINS_SYMBOL | CONTAINED_IN_SYMBOL | LIKE_SYMBOL

logical_operator: AND | XOR | OR

Expand Down Expand Up @@ -479,6 +478,7 @@ LESS_OR_EQUAL_SYMBOL: ">="
NOT_EQUAL_SYMBOL: "!="
CONTAINS_SYMBOL: "@>"
CONTAINED_IN_SYMBOL: "<@"
LIKE_SYMBOL: "LIKE"

// Operators. Bit

Expand Down
2 changes: 2 additions & 0 deletions eva/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def comparison_operator(self, tree):
return ExpressionType.COMPARE_CONTAINS
elif op == "<@":
return ExpressionType.COMPARE_IS_CONTAINED
elif op == "LIKE":
return ExpressionType.COMPARE_LIKE

def logical_operator(self, tree):
op = str(tree.children[0])
Expand Down
4 changes: 4 additions & 0 deletions test/integration_tests/test_insert_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,7 @@ def test_should_insert_tuples_in_table(self):
),
)
)

query = """SELECT name FROM CSVTable WHERE name LIKE '.*(sad|happy)';"""
batch = execute_query_fetch_all(query)
self.assertEqual(len(batch._frames), 2)
72 changes: 72 additions & 0 deletions test/integration_tests/test_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2018-2022 EVA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from eva.catalog.catalog_manager import CatalogManager
from eva.configuration.constants import EVA_ROOT_DIR
from eva.server.command_handler import execute_query_fetch_all


class LikeTest(unittest.TestCase):
def setUp(self):
# reset the catalog manager before running each test
CatalogManager().reset()
meme1 = f"{EVA_ROOT_DIR}/data/detoxify/meme1.jpg"
meme2 = f"{EVA_ROOT_DIR}/data/detoxify/meme2.jpg"

execute_query_fetch_all(f"LOAD IMAGE '{meme1}' INTO MemeImages;")
execute_query_fetch_all(f"LOAD IMAGE '{meme2}' INTO MemeImages;")

def tearDown(self):
# clean up
execute_query_fetch_all("DROP TABLE IF EXISTS MemeImages;")

def test_like_with_ocr(self):
create_udf_query = """CREATE UDF IF NOT EXISTS OCRExtractor
INPUT (frame NDARRAY UINT8(3, ANYDIM, ANYDIM))
OUTPUT (labels NDARRAY STR(10),
bboxes NDARRAY FLOAT32(ANYDIM, 4),
scores NDARRAY FLOAT32(ANYDIM))
TYPE OCRExtraction
IMPL 'eva/udfs/ocr_extractor.py';
"""
execute_query_fetch_all(create_udf_query)

select_query = (
"""SELECT * FROM MemeImages JOIN LATERAL OCRExtractor(data) AS X(label, x, y) WHERE label LIKE """
+ r'"[A-Za-z\', \[]*CANT[\,\',A-Za-z \]]*"'
)
actual_batch = execute_query_fetch_all(select_query)

self.assertEqual(len(actual_batch._frames), 2)

def test_like_fails_on_non_string_col(self):
create_udf_query = """CREATE UDF IF NOT EXISTS OCRExtractor
INPUT (frame NDARRAY UINT8(3, ANYDIM, ANYDIM))
OUTPUT (labels NDARRAY STR(10),
bboxes NDARRAY FLOAT32(ANYDIM, 4),
scores NDARRAY FLOAT32(ANYDIM))
TYPE OCRExtraction
IMPL 'eva/udfs/ocr_extractor.py';
"""
execute_query_fetch_all(create_udf_query)

select_query = """SELECT * FROM MemeImages JOIN LATERAL OCRExtractor(data) AS X(label, x, y) WHERE x LIKE "[A-Za-z]*CANT";"""
with self.assertRaises(Exception):
execute_query_fetch_all(select_query)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4a84efe

Please sign in to comment.