diff --git a/qastle/linq_util.py b/qastle/linq_util.py index bb0e2ff..f67f173 100644 --- a/qastle/linq_util.py +++ b/qastle/linq_util.py @@ -27,6 +27,10 @@ class ElementAt(ast.AST): _fields = ['source', 'index'] +class Contains(ast.AST): + _fields = ['source', 'value'] + + class Aggregate(ast.AST): _fields = ['source', 'seed', 'func'] @@ -81,6 +85,7 @@ class Choose(ast.AST): 'First', 'Last', 'ElementAt', + 'Contains', 'Aggregate', 'Count', 'Max', @@ -153,6 +158,10 @@ def visit_Call(self, node): if len(args) != 1: raise SyntaxError('ElementAt() call must have exactly one argument') return ElementAt(source=self.visit(source), index=self.visit(args[0])) + elif function_name == 'Contains': + if len(args) != 1: + raise SyntaxError('Contains() call must have exactly one argument') + return Contains(source=self.visit(source), value=self.visit(args[0])) elif function_name == 'Aggregate': if len(args) != 2: raise SyntaxError('Aggregate() call must have exactly two arguments; found' diff --git a/qastle/transform.py b/qastle/transform.py index f1c2629..8fa763c 100644 --- a/qastle/transform.py +++ b/qastle/transform.py @@ -1,5 +1,6 @@ -from .linq_util import (Where, Select, SelectMany, First, Last, ElementAt, Aggregate, Count, Max, - Min, Sum, All, Any, Concat, Zip, OrderBy, OrderByDescending, Choose) +from .linq_util import (Where, Select, SelectMany, First, Last, ElementAt, Contains, Aggregate, + Count, Max, Min, Sum, All, Any, Concat, Zip, OrderBy, OrderByDescending, + Choose) from .ast_util import wrap_ast, unwrap_ast import lark @@ -191,6 +192,11 @@ def visit_ElementAt(self, node): self.visit(node.source), self.visit(node.index)) + def visit_Contains(self, node): + return self.make_composite_node_string('Contains', + self.visit(node.source), + self.visit(node.value)) + def visit_Aggregate(self, node): return self.make_composite_node_string('Aggregate', self.visit(node.source), @@ -451,6 +457,11 @@ def composite(self, children): raise SyntaxError('ElementAt node must have two fields; found ' + str(len(fields))) return ElementAt(source=fields[0], index=fields[1]) + elif node_type == 'Contains': + if len(fields) != 2: + raise SyntaxError('Contains node must have two fields; found ' + str(len(fields))) + return Contains(source=fields[0], value=fields[1]) + elif node_type == 'Aggregate': if len(fields) != 3: raise SyntaxError('Aggregate node must have three fields; found ' diff --git a/tests/test_ast_language.py b/tests/test_ast_language.py index 1597b55..d97e4dd 100644 --- a/tests/test_ast_language.py +++ b/tests/test_ast_language.py @@ -193,6 +193,13 @@ def test_ElementAt(): '(ElementAt data_source 2)') +def test_Contains(): + first_node = Contains(source=unwrap_ast(ast.parse('data_source')), + value=unwrap_ast(ast.parse('element'))) + assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node), + '(Contains data_source element)') + + def test_Aggregate(): aggregate_node = Aggregate(source=unwrap_ast(ast.parse('data_source')), seed=unwrap_ast(ast.parse('0')), diff --git a/tests/test_linq_util.py b/tests/test_linq_util.py index cef83c9..fe724ae 100644 --- a/tests/test_linq_util.py +++ b/tests/test_linq_util.py @@ -190,6 +190,27 @@ def test_elementat_bad(): insert_linq_nodes(ast.parse('the_source.ElementAt()')) +def test_contains(): + initial_ast = ast.parse("the_source.Contains(element)") + final_ast = insert_linq_nodes(initial_ast) + expected_ast = wrap_ast(Contains(source=unwrap_ast(ast.parse('the_source')), + value=unwrap_ast(ast.parse('element')))) + assert_ast_nodes_are_equal(final_ast, expected_ast) + + +def test_contains_composite(): + initial_ast = ast.parse("the_source.First().Contains(element)") + final_ast = insert_linq_nodes(initial_ast) + expected_ast = wrap_ast(Contains(source=First(source=unwrap_ast(ast.parse('the_source'))), + value=unwrap_ast(ast.parse('element')))) + assert_ast_nodes_are_equal(final_ast, expected_ast) + + +def test_contains_bad(): + with pytest.raises(SyntaxError): + insert_linq_nodes(ast.parse('the_source.Contains()')) + + def test_aggregate(): initial_ast = ast.parse("the_source.Aggregate(0, 'lambda total, next: total + next')") final_ast = insert_linq_nodes(initial_ast)