Skip to content

Commit

Permalink
feat: added arithmetic expression support, closes #1093
Browse files Browse the repository at this point in the history
  • Loading branch information
aayushacharya committed Feb 7, 2024
1 parent e5a9190 commit bdeecaf
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 8 deletions.
1 change: 1 addition & 0 deletions evadb/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ExpressionType(IntEnum):
ARITHMETIC_SUBTRACT = auto()
ARITHMETIC_MULTIPLY = auto()
ARITHMETIC_DIVIDE = auto()
ARITHMETIC_MODULUS = auto()

FUNCTION_EXPRESSION = auto()

Expand Down
55 changes: 52 additions & 3 deletions evadb/expression/arithmetic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,62 @@ def __init__(
super().__init__(exp_type, rtype=ExpressionReturnType.FLOAT, children=children)

def evaluate(self, *args, **kwargs):
vl = self.get_child(0).evaluate(*args, **kwargs)
vr = self.get_child(1).evaluate(*args, **kwargs)
lbatch = self.get_child(0).evaluate(*args, **kwargs)
rbatch = self.get_child(1).evaluate(*args, **kwargs)

return Batch.combine_batches(vl, vr, self.etype)
assert len(lbatch) == len(
rbatch
), f"Left and Right batch does not have equal elements: left: {len(lbatch)} right: {len(rbatch)}"

assert self.etype in [
ExpressionType.ARITHMETIC_ADD,
ExpressionType.ARITHMETIC_SUBTRACT,
ExpressionType.ARITHMETIC_DIVIDE,
ExpressionType.ARITHMETIC_MULTIPLY,
ExpressionType.ARITHMETIC_MODULUS,
], f"Expression type not supported {self.etype}"

if self.etype == ExpressionType.ARITHMETIC_ADD:
return Batch.from_add(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return Batch.from_subtract(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return Batch.from_multiply(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return Batch.from_divide(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_MODULUS:
return Batch.from_modulus(lbatch, rbatch)

return Batch.combine_batches(lbatch, rbatch, self.etype)

def get_symbol(self) -> str:
if self.etype == ExpressionType.ARITHMETIC_ADD:
return "+"
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return "-"
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return "*"
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return "/"
elif self.etype == ExpressionType.ARITHMETIC_MODULUS:
return "%"

def __str__(self) -> str:
expr_str = "("
if self.get_child(0):
expr_str += f"{self.get_child(0)}"
if self.etype:
expr_str += f" {self.get_symbol()} "
if self.get_child(1):
expr_str += f"{self.get_child(1)}"
expr_str += ")"
return expr_str

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, ArithmeticExpression):
return False
return is_subtree_equal and self.etype == other.etype

def __hash__(self) -> int:
return super().__hash__()
20 changes: 20 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ def deserialize(cls, data):
obj = PickleSerializer.deserialize(data)
return cls(frames=obj["frames"])

@classmethod
def from_add(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() + batch2.to_numpy()))

@classmethod
def from_subtract(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() - batch2.to_numpy()))

@classmethod
def from_multiply(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() * batch2.to_numpy()))

@classmethod
def from_divide(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() / batch2.to_numpy()))

@classmethod
def from_modulus(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() % batch2.to_numpy()))

@classmethod
def from_eq(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() == batch2.to_numpy()))
Expand Down
16 changes: 11 additions & 5 deletions evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,13 @@ predicate: predicate NOT? IN "(" (select_statement | expressions) ")" ->in_pred
| predicate comparison_operator predicate -> binary_comparison_predicate
| predicate comparison_operator (ALL | ANY | SOME) "(" select_statement ")" ->subquery_comparison_predicate
| assign_var ->expression_atom_predicate
| expression_atom
| arithmetic_expression

arithmetic_expression: product
| arithmetic_expression add_sub_operator product -> arithmetic_expression_atom

product: expression_atom
| product div_mul_mod_operator expression_atom -> arithmetic_expression_atom

assign_var.1: LOCAL_ID VAR_ASSIGN expression_atom

Expand All @@ -332,8 +338,7 @@ expression_atom.2: constant ->constant_expression_atom
| unary_operator expression_atom ->unary_expression_atom
| "(" expression ("," expression)* ")" ->nested_expression_atom
| "(" select_statement ")" ->subquery_expession_atom
| expression_atom bit_operator expression_atom ->bit_expression_atom
| expression_atom math_operator expression_atom
| expression_atom bit_operator expression_atom ->bit_expression_atom

unary_operator: EXCLAMATION_SYMBOL | BIT_NOT_OP | PLUS | MINUS | NOT

Expand All @@ -343,7 +348,8 @@ logical_operator: AND | XOR | OR

bit_operator: "<<" | ">>" | "&" | "^" | "|"

math_operator: STAR | DIVIDE | MODULUS | DIV | MOD | PLUS | MINUS | MINUSMINUS
div_mul_mod_operator: DIVIDE | STAR | MODULUS
add_sub_operator: PLUS | MINUS

// KEYWORDS

Expand Down Expand Up @@ -526,7 +532,7 @@ OR_ASSIGN: "|="

STAR: "*"
DIVIDE: "/"
MODULUS: "%"
MODULUS: "%"
PLUS: "+"
MINUSMINUS: "--"
MINUS: "-"
Expand Down
23 changes: 23 additions & 0 deletions evadb/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from evadb.catalog.catalog_type import ColumnType
from evadb.expression.abstract_expression import ExpressionType
from evadb.expression.arithmetic_expression import ArithmeticExpression
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.logical_expression import LogicalExpression
Expand Down Expand Up @@ -60,6 +61,28 @@ def constant(self, tree):

return self.visit_children(tree)

def arithmetic_expression_atom(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
right = self.visit(tree.children[2])
return ArithmeticExpression(op, left, right)

def div_mul_mod_operator(self, tree):
op = str(tree.children[0])
if op == "*":
return ExpressionType.ARITHMETIC_MULTIPLY
elif op == "/":
return ExpressionType.ARITHMETIC_DIVIDE
elif op == "%":
return ExpressionType.ARITHMETIC_MODULUS

def add_sub_operator(self, tree):
op = str(tree.children[0])
if op == "+":
return ExpressionType.ARITHMETIC_ADD
elif op == "-":
return ExpressionType.ARITHMETIC_SUBTRACT

def logical_expression(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
Expand Down

0 comments on commit bdeecaf

Please sign in to comment.