diff --git a/evadb/expression/abstract_expression.py b/evadb/expression/abstract_expression.py index 9b72f32e6..1ee2fa01f 100644 --- a/evadb/expression/abstract_expression.py +++ b/evadb/expression/abstract_expression.py @@ -45,6 +45,7 @@ class ExpressionType(IntEnum): ARITHMETIC_SUBTRACT = auto() ARITHMETIC_MULTIPLY = auto() ARITHMETIC_DIVIDE = auto() + ARITHMETIC_MODULUS = auto() FUNCTION_EXPRESSION = auto() diff --git a/evadb/expression/arithmetic_expression.py b/evadb/expression/arithmetic_expression.py index 37c2c5a53..2ca791183 100644 --- a/evadb/expression/arithmetic_expression.py +++ b/evadb/expression/arithmetic_expression.py @@ -36,13 +36,62 @@ def __init__( super().__init__(exp_type, rtype=ExpressionReturnType.FLOAT, children=children) def evaluate(self, *args, **kwargs): - vl = self.get_child(0).evaluate(*args, **kwargs) - vr = self.get_child(1).evaluate(*args, **kwargs) + lbatch = self.get_child(0).evaluate(*args, **kwargs) + rbatch = self.get_child(1).evaluate(*args, **kwargs) - return Batch.combine_batches(vl, vr, self.etype) + assert len(lbatch) == len( + rbatch + ), f"Left and Right batch does not have equal elements: left: {len(lbatch)} right: {len(rbatch)}" + + assert self.etype in [ + ExpressionType.ARITHMETIC_ADD, + ExpressionType.ARITHMETIC_SUBTRACT, + ExpressionType.ARITHMETIC_DIVIDE, + ExpressionType.ARITHMETIC_MULTIPLY, + ExpressionType.ARITHMETIC_MODULUS, + ], f"Expression type not supported {self.etype}" + + if self.etype == ExpressionType.ARITHMETIC_ADD: + return Batch.from_add(lbatch, rbatch) + elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT: + return Batch.from_subtract(lbatch, rbatch) + elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY: + return Batch.from_multiply(lbatch, rbatch) + elif self.etype == ExpressionType.ARITHMETIC_DIVIDE: + return Batch.from_divide(lbatch, rbatch) + elif self.etype == ExpressionType.ARITHMETIC_MODULUS: + return Batch.from_modulus(lbatch, rbatch) + + return Batch.combine_batches(lbatch, rbatch, self.etype) + + def get_symbol(self) -> str: + if self.etype == ExpressionType.ARITHMETIC_ADD: + return "+" + elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT: + return "-" + elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY: + return "*" + elif self.etype == ExpressionType.ARITHMETIC_DIVIDE: + return "/" + elif self.etype == ExpressionType.ARITHMETIC_MODULUS: + return "%" + + def __str__(self) -> str: + expr_str = "(" + if self.get_child(0): + expr_str += f"{self.get_child(0)}" + if self.etype: + expr_str += f" {self.get_symbol()} " + if self.get_child(1): + expr_str += f"{self.get_child(1)}" + expr_str += ")" + return expr_str def __eq__(self, other): is_subtree_equal = super().__eq__(other) if not isinstance(other, ArithmeticExpression): return False return is_subtree_equal and self.etype == other.etype + + def __hash__(self) -> int: + return super().__hash__() diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 43e69cc4f..c1139b7c4 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -76,6 +76,26 @@ def deserialize(cls, data): obj = PickleSerializer.deserialize(data) return cls(frames=obj["frames"]) + @classmethod + def from_add(cls, batch1: Batch, batch2: Batch) -> Batch: + return Batch(pd.DataFrame(batch1.to_numpy() + batch2.to_numpy())) + + @classmethod + def from_subtract(cls, batch1: Batch, batch2: Batch) -> Batch: + return Batch(pd.DataFrame(batch1.to_numpy() - batch2.to_numpy())) + + @classmethod + def from_multiply(cls, batch1: Batch, batch2: Batch) -> Batch: + return Batch(pd.DataFrame(batch1.to_numpy() * batch2.to_numpy())) + + @classmethod + def from_divide(cls, batch1: Batch, batch2: Batch) -> Batch: + return Batch(pd.DataFrame(batch1.to_numpy() / batch2.to_numpy())) + + @classmethod + def from_modulus(cls, batch1: Batch, batch2: Batch) -> Batch: + return Batch(pd.DataFrame(batch1.to_numpy() % batch2.to_numpy())) + @classmethod def from_eq(cls, batch1: Batch, batch2: Batch) -> Batch: return Batch(pd.DataFrame(batch1.to_numpy() == batch2.to_numpy())) diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 4b96bf647..df00ebd5c 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -321,7 +321,13 @@ predicate: predicate NOT? IN "(" (select_statement | expressions) ")" ->in_pred | predicate comparison_operator predicate -> binary_comparison_predicate | predicate comparison_operator (ALL | ANY | SOME) "(" select_statement ")" ->subquery_comparison_predicate | assign_var ->expression_atom_predicate - | expression_atom + | arithmetic_expression + +arithmetic_expression: product + | arithmetic_expression add_sub_operator product -> arithmetic_expression_atom + +product: expression_atom + | product div_mul_mod_operator expression_atom -> arithmetic_expression_atom assign_var.1: LOCAL_ID VAR_ASSIGN expression_atom @@ -332,8 +338,7 @@ expression_atom.2: constant ->constant_expression_atom | unary_operator expression_atom ->unary_expression_atom | "(" expression ("," expression)* ")" ->nested_expression_atom | "(" select_statement ")" ->subquery_expession_atom - | expression_atom bit_operator expression_atom ->bit_expression_atom - | expression_atom math_operator expression_atom + | expression_atom bit_operator expression_atom ->bit_expression_atom unary_operator: EXCLAMATION_SYMBOL | BIT_NOT_OP | PLUS | MINUS | NOT @@ -343,7 +348,8 @@ logical_operator: AND | XOR | OR bit_operator: "<<" | ">>" | "&" | "^" | "|" -math_operator: STAR | DIVIDE | MODULUS | DIV | MOD | PLUS | MINUS | MINUSMINUS +div_mul_mod_operator: DIVIDE | STAR | MODULUS +add_sub_operator: PLUS | MINUS // KEYWORDS @@ -526,7 +532,7 @@ OR_ASSIGN: "|=" STAR: "*" DIVIDE: "/" -MODULUS: "%" +MODULUS: "%" PLUS: "+" MINUSMINUS: "--" MINUS: "-" diff --git a/evadb/parser/lark_visitor/_expressions.py b/evadb/parser/lark_visitor/_expressions.py index ff53ed4e1..91b8d3a7d 100644 --- a/evadb/parser/lark_visitor/_expressions.py +++ b/evadb/parser/lark_visitor/_expressions.py @@ -17,6 +17,7 @@ from evadb.catalog.catalog_type import ColumnType from evadb.expression.abstract_expression import ExpressionType +from evadb.expression.arithmetic_expression import ArithmeticExpression from evadb.expression.comparison_expression import ComparisonExpression from evadb.expression.constant_value_expression import ConstantValueExpression from evadb.expression.logical_expression import LogicalExpression @@ -60,6 +61,28 @@ def constant(self, tree): return self.visit_children(tree) + def arithmetic_expression_atom(self, tree): + left = self.visit(tree.children[0]) + op = self.visit(tree.children[1]) + right = self.visit(tree.children[2]) + return ArithmeticExpression(op, left, right) + + def div_mul_mod_operator(self, tree): + op = str(tree.children[0]) + if op == "*": + return ExpressionType.ARITHMETIC_MULTIPLY + elif op == "/": + return ExpressionType.ARITHMETIC_DIVIDE + elif op == "%": + return ExpressionType.ARITHMETIC_MODULUS + + def add_sub_operator(self, tree): + op = str(tree.children[0]) + if op == "+": + return ExpressionType.ARITHMETIC_ADD + elif op == "-": + return ExpressionType.ARITHMETIC_SUBTRACT + def logical_expression(self, tree): left = self.visit(tree.children[0]) op = self.visit(tree.children[1])