Skip to content

Commit

Permalink
Merge pull request #64 from iris-hep/feature_55_contains
Browse files Browse the repository at this point in the history
add `Contains` node
  • Loading branch information
masonproffitt authored Oct 18, 2021
2 parents 878a6a9 + e6e5825 commit 71d80a2
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
9 changes: 9 additions & 0 deletions qastle/linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class ElementAt(ast.AST):
_fields = ['source', 'index']


class Contains(ast.AST):
_fields = ['source', 'value']


class Aggregate(ast.AST):
_fields = ['source', 'seed', 'func']

Expand Down Expand Up @@ -81,6 +85,7 @@ class Choose(ast.AST):
'First',
'Last',
'ElementAt',
'Contains',
'Aggregate',
'Count',
'Max',
Expand Down Expand Up @@ -153,6 +158,10 @@ def visit_Call(self, node):
if len(args) != 1:
raise SyntaxError('ElementAt() call must have exactly one argument')
return ElementAt(source=self.visit(source), index=self.visit(args[0]))
elif function_name == 'Contains':
if len(args) != 1:
raise SyntaxError('Contains() call must have exactly one argument')
return Contains(source=self.visit(source), value=self.visit(args[0]))
elif function_name == 'Aggregate':
if len(args) != 2:
raise SyntaxError('Aggregate() call must have exactly two arguments; found'
Expand Down
15 changes: 13 additions & 2 deletions qastle/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .linq_util import (Where, Select, SelectMany, First, Last, ElementAt, Aggregate, Count, Max,
Min, Sum, All, Any, Concat, Zip, OrderBy, OrderByDescending, Choose)
from .linq_util import (Where, Select, SelectMany, First, Last, ElementAt, Contains, Aggregate,
Count, Max, Min, Sum, All, Any, Concat, Zip, OrderBy, OrderByDescending,
Choose)
from .ast_util import wrap_ast, unwrap_ast

import lark
Expand Down Expand Up @@ -191,6 +192,11 @@ def visit_ElementAt(self, node):
self.visit(node.source),
self.visit(node.index))

def visit_Contains(self, node):
return self.make_composite_node_string('Contains',
self.visit(node.source),
self.visit(node.value))

def visit_Aggregate(self, node):
return self.make_composite_node_string('Aggregate',
self.visit(node.source),
Expand Down Expand Up @@ -451,6 +457,11 @@ def composite(self, children):
raise SyntaxError('ElementAt node must have two fields; found ' + str(len(fields)))
return ElementAt(source=fields[0], index=fields[1])

elif node_type == 'Contains':
if len(fields) != 2:
raise SyntaxError('Contains node must have two fields; found ' + str(len(fields)))
return Contains(source=fields[0], value=fields[1])

elif node_type == 'Aggregate':
if len(fields) != 3:
raise SyntaxError('Aggregate node must have three fields; found '
Expand Down
7 changes: 7 additions & 0 deletions tests/test_ast_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ def test_ElementAt():
'(ElementAt data_source 2)')


def test_Contains():
first_node = Contains(source=unwrap_ast(ast.parse('data_source')),
value=unwrap_ast(ast.parse('element')))
assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node),
'(Contains data_source element)')


def test_Aggregate():
aggregate_node = Aggregate(source=unwrap_ast(ast.parse('data_source')),
seed=unwrap_ast(ast.parse('0')),
Expand Down
21 changes: 21 additions & 0 deletions tests/test_linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ def test_elementat_bad():
insert_linq_nodes(ast.parse('the_source.ElementAt()'))


def test_contains():
initial_ast = ast.parse("the_source.Contains(element)")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Contains(source=unwrap_ast(ast.parse('the_source')),
value=unwrap_ast(ast.parse('element'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_contains_composite():
initial_ast = ast.parse("the_source.First().Contains(element)")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Contains(source=First(source=unwrap_ast(ast.parse('the_source'))),
value=unwrap_ast(ast.parse('element'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_contains_bad():
with pytest.raises(SyntaxError):
insert_linq_nodes(ast.parse('the_source.Contains()'))


def test_aggregate():
initial_ast = ast.parse("the_source.Aggregate(0, 'lambda total, next: total + next')")
final_ast = insert_linq_nodes(initial_ast)
Expand Down

0 comments on commit 71d80a2

Please sign in to comment.