diff --git a/src/NQuery.Tests/Evaluation/EvaluationTest.cs b/src/NQuery.Tests/Evaluation/EvaluationTest.cs index 32b1f3f3..cc882ac7 100644 --- a/src/NQuery.Tests/Evaluation/EvaluationTest.cs +++ b/src/NQuery.Tests/Evaluation/EvaluationTest.cs @@ -1,8 +1,8 @@ namespace NQuery.Tests.Evaluation { - public class EvaluationTest + public abstract class EvaluationTest { - protected static void AssertProduces(string text, T[] expected) + protected static void AssertProduces(string text, T[] expected, DataContext dataContext = null) { var expectedColumns = new[] { typeof(T) }; var expectedRows = new object[expected.Length][]; @@ -10,10 +10,10 @@ protected static void AssertProduces(string text, T[] expected) for (var i = 0; i < expected.Length; i++) expectedRows[i] = new object[] { expected[i] }; - AssertProduces(text, expectedColumns, expectedRows); + AssertProduces(text, expectedColumns, expectedRows, dataContext); } - protected static void AssertProduces(string text, (T1, T2)[] expected) + protected static void AssertProduces(string text, (T1, T2)[] expected, DataContext dataContext = null) { var expectedColumns = new[] { typeof(T1), typeof(T2) }; var expectedRows = new object[expected.Length][]; @@ -21,12 +21,12 @@ protected static void AssertProduces(string text, (T1, T2)[] expected) for (var i = 0; i < expected.Length; i++) expectedRows[i] = new object[] { expected[i].Item1, expected[i].Item2 }; - AssertProduces(text, expectedColumns, expectedRows); + AssertProduces(text, expectedColumns, expectedRows, dataContext); } - private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows) + private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows, DataContext dataContext) { - var dataContext = NorthwindDataContext.Instance; + dataContext ??= NorthwindDataContext.Instance; var query = Query.Create(dataContext, text); using var data = query.ExecuteReader(); diff --git a/src/NQuery.Tests/Evaluation/FunctionInvocationTests.cs b/src/NQuery.Tests/Evaluation/FunctionInvocationTests.cs new file mode 100644 index 00000000..487066b3 --- /dev/null +++ b/src/NQuery.Tests/Evaluation/FunctionInvocationTests.cs @@ -0,0 +1,91 @@ +using NQuery.Symbols; + +namespace NQuery.Tests.Evaluation +{ + public sealed class FunctionInvocationTests : EvaluationTest + { + private readonly DataContext _dataContext; + + public FunctionInvocationTests() + { + _dataContext = DataContext.Default.AddFunctions( + new FunctionSymbol(nameof(StaticFunction0), StaticFunction0), + new FunctionSymbol(nameof(StaticFunction1), StaticFunction1), + new FunctionSymbol(nameof(StaticFunction2), StaticFunction2), + new FunctionSymbol(nameof(StaticFunction3), StaticFunction3), + new FunctionSymbol(nameof(InstanceFunction0), InstanceFunction0), + new FunctionSymbol(nameof(InstanceFunction1), InstanceFunction1), + new FunctionSymbol(nameof(InstanceFunction2), InstanceFunction2), + new FunctionSymbol(nameof(InstanceFunction3), InstanceFunction3) + ); + } + + [Theory] + [InlineData(nameof(StaticFunction0) + "()")] + [InlineData(nameof(StaticFunction1) + "(42)")] + [InlineData(nameof(StaticFunction2) + "(40, 2)")] + [InlineData(nameof(StaticFunction3) + "(20, 20, 2)")] + public void Evaluation_FunctionInvocationExpression_StaticFunction(string functionInvocation) + { + var text = "SELECT " + functionInvocation; + + var expected = new[] { 42 }; + + AssertProduces(text, expected, _dataContext); + } + + [Theory] + [InlineData(nameof(InstanceFunction0) + "()")] + [InlineData(nameof(InstanceFunction1) + "(42)")] + [InlineData(nameof(InstanceFunction2) + "(40, 2)")] + [InlineData(nameof(InstanceFunction3) + "(20, 20, 2)")] + public void Evaluation_FunctionInvocationExpression_InstanceFunction(string functionInvocation) + { + var text = "SELECT " + functionInvocation; + + var expected = new[] { GetHashCode() + 42 }; + + AssertProduces(text, expected, _dataContext); + } + + private static int StaticFunction0() + { + return 42; + } + + private static int StaticFunction1(int arg) + { + return arg; + } + + private static int StaticFunction2(int arg1, int arg2) + { + return arg1 + arg2; + } + + private static int StaticFunction3(int arg1, int arg2, int arg3) + { + return arg1 + arg2 + arg3; + } + + private int InstanceFunction0() + { + return GetHashCode() + 42; + } + + private int InstanceFunction1(int arg) + { + return GetHashCode() + arg; + } + + private int InstanceFunction2(int arg1, int arg2) + { + return GetHashCode() + arg1 + arg2; + } + + private int InstanceFunction3(int arg1, int arg2, int arg3) + { + return GetHashCode() + arg1 + arg2 + arg3; + } + } +} diff --git a/src/NQuery/Symbols/FunctionSymbol.cs b/src/NQuery/Symbols/FunctionSymbol.cs index 28a4c5a6..260db174 100644 --- a/src/NQuery/Symbols/FunctionSymbol.cs +++ b/src/NQuery/Symbols/FunctionSymbol.cs @@ -14,12 +14,19 @@ protected FunctionSymbol(string name, Type type, params ParameterSymbol[] parame { } - public abstract Expression CreateInvocation(IEnumerable arguments); + public Expression CreateInvocation(IEnumerable arguments) + { + var function = FunctionDelegate; + var instance = function.Target == null ? null : Expression.Constant(function.Target); + return Expression.Call(instance, function.Method, arguments); + } public override SymbolKind Kind { get { return SymbolKind.Function; } } + + protected abstract Delegate FunctionDelegate { get; } } public sealed class FunctionSymbol : FunctionSymbol @@ -30,12 +37,12 @@ public FunctionSymbol(string name, Func function) Function = function; } - public override Expression CreateInvocation(IEnumerable arguments) + public Func Function { get; } + + protected override Delegate FunctionDelegate { - return Expression.Call(Function.Method, arguments); + get { return Function; } } - - public Func Function { get; } } public sealed class FunctionSymbol : FunctionSymbol @@ -52,12 +59,12 @@ public FunctionSymbol(string name, string argumentName, Func functio Function = function; } - public override Expression CreateInvocation(IEnumerable arguments) + public Func Function { get; } + + protected override Delegate FunctionDelegate { - return Expression.Call(Function.Method, arguments); + get { return Function; } } - - public Func Function { get; } } public sealed class FunctionSymbol : FunctionSymbol @@ -73,12 +80,12 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2, Function = function; } - public override Expression CreateInvocation(IEnumerable arguments) + public Func Function { get; } + + protected override Delegate FunctionDelegate { - return Expression.Call(Function.Method, arguments); + get { return Function; } } - - public Func Function { get; } } public sealed class FunctionSymbol : FunctionSymbol @@ -94,11 +101,11 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2, Function = function; } - public override Expression CreateInvocation(IEnumerable arguments) + public Func Function { get; } + + protected override Delegate FunctionDelegate { - return Expression.Call(Function.Method, arguments); + get { return Function; } } - - public Func Function { get; } } } \ No newline at end of file