diff --git a/src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs b/src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs new file mode 100644 index 00000000..5c6a4a9e --- /dev/null +++ b/src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs @@ -0,0 +1,170 @@ +using System.Linq.Expressions; + +using NQuery.Symbols; + +namespace NQuery.Tests.Evaluation +{ + public sealed class EagerAndLazyTests + { + private static InvocationResult EvaluateAndCountInvocations(string text) + { + var invocationResult = new InvocationResult(); + var invocationResultVariable = new VariableSymbol("ir", typeof(InvocationResult), invocationResult); + var nullInt32Function = new InvocationResultFunctionSymbol("NULL_INT32", NullInt32Function); + var nonNullInt32Function = new InvocationResultFunctionSymbol("NON_NULL_INT32", NonNullInt32Function); + var dataContext = DataContext.Default + .AddVariables(invocationResultVariable) + .AddFunctions(nullInt32Function, nonNullInt32Function); + var expression = Expression.Create(dataContext, text); + invocationResult.Result = expression.Evaluate(); + return invocationResult; + } + + [Fact] + public void Evaluation_Conversion_Once() + { + var result = EvaluateAndCountInvocations("CAST(NON_NULL_INT32(ir) AS int64)"); + Assert.Equal(42L, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_Unary_Once() + { + var result = EvaluateAndCountInvocations("~NON_NULL_INT32(ir)"); + Assert.Equal(~42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_Binary_EagerOnce() + { + var result = EvaluateAndCountInvocations("NULL_INT32(ir) + NON_NULL_INT32(ir)"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_FunctionInvocation_EagerOnce() + { + var result = EvaluateAndCountInvocations("SUBSTRING('abc', NULL_INT32(ir), NON_NULL_INT32(ir))"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_MethodInvocation_Instance_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_MethodInvocation_Arguments_EagerOnce() + { + var result = EvaluateAndCountInvocations("''.Substring(NULL_INT32(ir), NON_NULL_INT32(ir))"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_PropertyAccess_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_IsNull_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir) IS NOT NULL"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Evaluation_CaseWhen_NonNullFunction_LazyOnce() + { + const string text = @" + CASE + WHEN NON_NULL_INT32(ir) = 42 THEN 42 + ELSE NULL_INT32(ir) + END"; + + var result = EvaluateAndCountInvocations(text); + Assert.Equal(42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + Assert.Equal(0, result.NullInt32FunctionCount); + } + + [Fact] + public void Evaluation_CaseWhen_NonNullNestedFunction_LazyOnce() + { + const string text = @" + CASE + WHEN TO_INT32(NON_NULL_INT32(ir)) = 42 THEN 42 + WHEN TO_INT32(NON_NULL_INT32(ir)) != 42 THEN 0 + ELSE TO_INT32(NULL_INT32(ir)) + END"; + + var result = EvaluateAndCountInvocations(text); + Assert.Equal(42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + Assert.Equal(0, result.NullInt32FunctionCount); + } + + [Fact] + public void Evaluation_CaseWhen_NullFunction_LazyOnce() + { + const string text = @" + CASE + WHEN NULL_INT32(ir) = 0 THEN 42 + ELSE 0 + END"; + + var result = EvaluateAndCountInvocations(text); + Assert.Equal(0, result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + } + + private static int? NullInt32Function(InvocationResult ir) + { + ir.NullInt32FunctionCount++; + return null; + } + + private static int? NonNullInt32Function(InvocationResult ir) + { + ir.NonNullInt32FunctionCount++; + return 42; + } + + private sealed class InvocationResult + { + public object Result { get; set; } + public int NullInt32FunctionCount { get; set; } + public int NonNullInt32FunctionCount { get; set; } + } + + private sealed class InvocationResultFunctionSymbol : FunctionSymbol + { + public InvocationResultFunctionSymbol(string name, Func function) + : base(name, typeof(TResult).GetNonNullableType(), new ParameterSymbol("ir", typeof(InvocationResult))) + { + Function = function; + } + + public override Expression CreateInvocation(IEnumerable arguments) + { + return Expression.Call(Function.Method, arguments); + } + + private Func Function { get; } + } + } +} diff --git a/src/NQuery/Iterators/ExpressionBuilder.cs b/src/NQuery/Iterators/ExpressionBuilder.cs index d47fc49c..9fbf697a 100644 --- a/src/NQuery/Iterators/ExpressionBuilder.cs +++ b/src/NQuery/Iterators/ExpressionBuilder.cs @@ -39,13 +39,19 @@ public static IteratorPredicate BuildIteratorPredicate(BoundExpression predicate return BuildExpression(predicate, typeof(bool), allocation); } - private static TDelegate BuildExpression(BoundExpression expression, Type targetType, RowBufferAllocation allocation) + private static TDelegate BuildExpression(BoundExpression expression, Type targetType, RowBufferAllocation allocation) where TDelegate : Delegate + { + var lambda = BuildExpression(expression, typeof(TDelegate), targetType, allocation); + return (TDelegate)lambda.Compile(); + } + + private static LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType, RowBufferAllocation allocation) { var builder = new ExpressionBuilder(allocation); - return builder.BuildExpression(expression, targetType); + return builder.BuildExpression(expression, delegateType, targetType); } - private ParameterExpression BuildCachedExpression(BoundExpression expression) + private Expression BuildCachedExpression(BoundExpression expression) { var result = BuildExpression(expression); var liftedExpression = BuildLiftedExpression(result); @@ -63,11 +69,6 @@ private static Expression BuildLiftedExpression(Expression result) : Expression.Convert(result, result.Type.GetNullableType()); } - private Expression BuildLiftedExpression(BoundExpression expression) - { - return BuildLiftedExpression(BuildExpression(expression)); - } - private static Expression BuildLoweredExpression(Expression expression) { if (!expression.Type.IsNullableOfT()) @@ -113,7 +114,7 @@ private static Expression BuildInvocation(MethodSymbol methodSymbol, Expression return BuildLiftedExpression( methodSymbol.CreateInvocation( - BuildLoweredExpression(instance), + BuildLoweredExpression(instance), arguments.Select(BuildLoweredExpression) ) ); @@ -144,17 +145,17 @@ private static UnaryExpression BuildNullableTrue() return Expression.Convert(Expression.Constant(true), typeof(bool?)); } - private TDelegate BuildExpression(BoundExpression expression, Type targetType) + private LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType) { var actualExpression = BuildCachedExpression(expression); var coalescedExpression = targetType.CanBeNull() - ? (Expression)actualExpression + ? actualExpression : Expression.Coalesce(actualExpression, Expression.Default(targetType)); var resultExpression = Expression.Convert(coalescedExpression, targetType); var expressions = _assignments.Concat(new[] { resultExpression }); var body = Expression.Block(_locals, expressions); - var lambda = Expression.Lambda(body); - return lambda.Compile(); + var lambda = Expression.Lambda(delegateType, body); + return lambda; } private Expression BuildExpression(BoundExpression expression) @@ -455,7 +456,7 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa if (caseLabelIndex == caseExpression.CaseLabels.Length) return caseExpression.ElseExpression is null ? BuildNullValue(caseExpression.Type) - : BuildLiftedExpression(caseExpression.ElseExpression); + : BuildNestedScopeInvocation(caseExpression.ElseExpression); var caseLabel = caseExpression.CaseLabels[caseLabelIndex]; var condition = caseLabel.Condition; @@ -464,12 +465,21 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa return Expression.Condition( Expression.Equal( - BuildLiftedExpression(condition), + BuildNestedScopeInvocation(condition), BuildNullableTrue() ), - BuildLiftedExpression(result), + BuildNestedScopeInvocation(result), BuildCaseLabel(caseExpression, caseLabelIndex + 1) ); } + + private Expression BuildNestedScopeInvocation(BoundExpression expression) + { + var targetType = expression.Type; + var delegateType = typeof(Func<>).MakeGenericType(targetType); + var lambda = BuildExpression(expression, delegateType, targetType, _rowBufferAllocation); + var invocation = Expression.Invoke(lambda); + return BuildLiftedExpression(invocation); + } } } \ No newline at end of file