From 8f4d4940795dd085a1e54c10eae15707d978e1ca Mon Sep 17 00:00:00 2001 From: Thomas Dallmair Date: Fri, 16 Aug 2024 23:00:58 +0200 Subject: [PATCH 1/3] Compile `CASE WHEN` expressions to dedicated functions Closes #65 --- .../ExpressionTests.Evaluation.cs | 150 ++++++++++++++++++ src/NQuery/Iterators/ExpressionBuilder.cs | 36 +++-- 2 files changed, 173 insertions(+), 13 deletions(-) create mode 100644 src/NQuery.Tests/ExpressionTests.Evaluation.cs diff --git a/src/NQuery.Tests/ExpressionTests.Evaluation.cs b/src/NQuery.Tests/ExpressionTests.Evaluation.cs new file mode 100644 index 00000000..7a912052 --- /dev/null +++ b/src/NQuery.Tests/ExpressionTests.Evaluation.cs @@ -0,0 +1,150 @@ +using System.Linq.Expressions; + +using NQuery.Symbols; + +namespace NQuery.Tests +{ + public partial class ExpressionTests + { + private static InvocationResult EvaluateAndCountInvocations(string text) + { + var invocationResult = new InvocationResult(); + var invocationResultVariable = new VariableSymbol("ir", typeof(InvocationResult), invocationResult); + var nullInt32Function = new InvocationResultFunctionSymbol("NULL_INT32", NullInt32Function); + var nonNullInt32Function = new InvocationResultFunctionSymbol("NON_NULL_INT32", NonNullInt32Function); + var dataContext = DataContext.Default + .AddVariables(invocationResultVariable) + .AddFunctions(nullInt32Function, nonNullInt32Function); + var expression = Expression.Create(dataContext, text); + invocationResult.Result = expression.Evaluate(); + return invocationResult; + } + + [Fact] + public void Expression_Evaluation_Conversion_Once() + { + var result = EvaluateAndCountInvocations("CAST(NON_NULL_INT32(ir) AS int64)"); + Assert.Equal(42L, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_Unary_Once() + { + var result = EvaluateAndCountInvocations("~NON_NULL_INT32(ir)"); + Assert.Equal(~42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_Binary_EagerOnce() + { + var result = EvaluateAndCountInvocations("NULL_INT32(ir) + NON_NULL_INT32(ir)"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_FunctionInvocation_EagerOnce() + { + var result = EvaluateAndCountInvocations("SUBSTRING('abc', NULL_INT32(ir), NON_NULL_INT32(ir))"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_MethodInvocation_Instance_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_MethodInvocation_Arguments_EagerOnce() + { + var result = EvaluateAndCountInvocations("''.Substring(NULL_INT32(ir), NON_NULL_INT32(ir))"); + Assert.Null(result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_PropertyAccess_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_IsNull_Once() + { + var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir) IS NOT NULL"); + Assert.Equal(true, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_CaseWhen_LazyOnce_Simple() + { + var result = EvaluateAndCountInvocations("CASE WHEN NON_NULL_INT32(ir) = 42 THEN 42 ELSE NULL_INT32(ir) END"); + Assert.Equal(42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + Assert.Equal(0, result.NullInt32FunctionCount); + } + + [Fact] + public void Expression_Evaluation_CaseWhen_LazyOnce_Complex() + { + const string text = @" + CASE + WHEN TO_INT32(NON_NULL_INT32(ir)) = 42 THEN 42 + WHEN TO_INT32(NON_NULL_INT32(ir)) != 42 THEN 0 + ELSE TO_INT32(NULL_INT32(ir)) + END"; + + var result = EvaluateAndCountInvocations(text); + Assert.Equal(42, result.Result); + Assert.Equal(1, result.NonNullInt32FunctionCount); + Assert.Equal(0, result.NullInt32FunctionCount); + } + + private static int? NullInt32Function(InvocationResult ir) + { + ir.NullInt32FunctionCount++; + return null; + } + + private static int? NonNullInt32Function(InvocationResult ir) + { + ir.NonNullInt32FunctionCount++; + return 42; + } + + private sealed class InvocationResult + { + public object Result { get; set; } + public int NullInt32FunctionCount { get; set; } + public int NonNullInt32FunctionCount { get; set; } + } + + private sealed class InvocationResultFunctionSymbol : FunctionSymbol + { + public InvocationResultFunctionSymbol(string name, Func function) + : base(name, typeof(TResult).GetNonNullableType(), new ParameterSymbol("ir", typeof(InvocationResult))) + { + Function = function; + } + + public override Expression CreateInvocation(IEnumerable arguments) + { + return Expression.Call(Function.Method, arguments); + } + + private Func Function { get; } + } + } +} diff --git a/src/NQuery/Iterators/ExpressionBuilder.cs b/src/NQuery/Iterators/ExpressionBuilder.cs index d47fc49c..e26bd952 100644 --- a/src/NQuery/Iterators/ExpressionBuilder.cs +++ b/src/NQuery/Iterators/ExpressionBuilder.cs @@ -39,10 +39,16 @@ public static IteratorPredicate BuildIteratorPredicate(BoundExpression predicate return BuildExpression(predicate, typeof(bool), allocation); } - private static TDelegate BuildExpression(BoundExpression expression, Type targetType, RowBufferAllocation allocation) + private static TDelegate BuildExpression(BoundExpression expression, Type targetType, RowBufferAllocation allocation) where TDelegate : Delegate + { + var lambda = BuildExpression(expression, typeof(TDelegate), targetType, allocation); + return (TDelegate)lambda.Compile(); + } + + private static LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType, RowBufferAllocation allocation) { var builder = new ExpressionBuilder(allocation); - return builder.BuildExpression(expression, targetType); + return builder.BuildExpression(expression, delegateType, targetType); } private ParameterExpression BuildCachedExpression(BoundExpression expression) @@ -63,11 +69,6 @@ private static Expression BuildLiftedExpression(Expression result) : Expression.Convert(result, result.Type.GetNullableType()); } - private Expression BuildLiftedExpression(BoundExpression expression) - { - return BuildLiftedExpression(BuildExpression(expression)); - } - private static Expression BuildLoweredExpression(Expression expression) { if (!expression.Type.IsNullableOfT()) @@ -144,7 +145,7 @@ private static UnaryExpression BuildNullableTrue() return Expression.Convert(Expression.Constant(true), typeof(bool?)); } - private TDelegate BuildExpression(BoundExpression expression, Type targetType) + private LambdaExpression BuildExpression(BoundExpression expression, Type delegateType, Type targetType) { var actualExpression = BuildCachedExpression(expression); var coalescedExpression = targetType.CanBeNull() @@ -153,8 +154,8 @@ private TDelegate BuildExpression(BoundExpression expression, Type ta var resultExpression = Expression.Convert(coalescedExpression, targetType); var expressions = _assignments.Concat(new[] { resultExpression }); var body = Expression.Block(_locals, expressions); - var lambda = Expression.Lambda(body); - return lambda.Compile(); + var lambda = Expression.Lambda(delegateType, body); + return lambda; } private Expression BuildExpression(BoundExpression expression) @@ -455,7 +456,7 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa if (caseLabelIndex == caseExpression.CaseLabels.Length) return caseExpression.ElseExpression is null ? BuildNullValue(caseExpression.Type) - : BuildLiftedExpression(caseExpression.ElseExpression); + : BuildNestedScopeInvocation(caseExpression.ElseExpression); var caseLabel = caseExpression.CaseLabels[caseLabelIndex]; var condition = caseLabel.Condition; @@ -464,12 +465,21 @@ private Expression BuildCaseLabel(BoundCaseExpression caseExpression, int caseLa return Expression.Condition( Expression.Equal( - BuildLiftedExpression(condition), + BuildNestedScopeInvocation(condition), BuildNullableTrue() ), - BuildLiftedExpression(result), + BuildNestedScopeInvocation(result), BuildCaseLabel(caseExpression, caseLabelIndex + 1) ); } + + private Expression BuildNestedScopeInvocation(BoundExpression expression) + { + var targetType = expression.Type; + var delegateType = typeof(Func<>).MakeGenericType(targetType); + var lambda = BuildExpression(expression, delegateType, targetType, _rowBufferAllocation); + var invocation = Expression.Invoke(lambda); + return BuildLiftedExpression(invocation); + } } } \ No newline at end of file From c367bc5a0dbda6474f02cd96b7c7715113f14e2b Mon Sep 17 00:00:00 2001 From: Thomas Dallmair Date: Fri, 16 Aug 2024 23:03:43 +0200 Subject: [PATCH 2/3] Tiny code simplification --- src/NQuery/Iterators/ExpressionBuilder.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/NQuery/Iterators/ExpressionBuilder.cs b/src/NQuery/Iterators/ExpressionBuilder.cs index e26bd952..9fbf697a 100644 --- a/src/NQuery/Iterators/ExpressionBuilder.cs +++ b/src/NQuery/Iterators/ExpressionBuilder.cs @@ -51,7 +51,7 @@ private static LambdaExpression BuildExpression(BoundExpression expression, Type return builder.BuildExpression(expression, delegateType, targetType); } - private ParameterExpression BuildCachedExpression(BoundExpression expression) + private Expression BuildCachedExpression(BoundExpression expression) { var result = BuildExpression(expression); var liftedExpression = BuildLiftedExpression(result); @@ -114,7 +114,7 @@ private static Expression BuildInvocation(MethodSymbol methodSymbol, Expression return BuildLiftedExpression( methodSymbol.CreateInvocation( - BuildLoweredExpression(instance), + BuildLoweredExpression(instance), arguments.Select(BuildLoweredExpression) ) ); @@ -149,7 +149,7 @@ private LambdaExpression BuildExpression(BoundExpression expression, Type delega { var actualExpression = BuildCachedExpression(expression); var coalescedExpression = targetType.CanBeNull() - ? (Expression)actualExpression + ? actualExpression : Expression.Coalesce(actualExpression, Expression.Default(targetType)); var resultExpression = Expression.Convert(coalescedExpression, targetType); var expressions = _assignments.Concat(new[] { resultExpression }); From 6f79f5c18e5430e94785efce50cb329b552e0c59 Mon Sep 17 00:00:00 2001 From: Thomas Dallmair Date: Mon, 19 Aug 2024 12:20:16 +0200 Subject: [PATCH 3/3] Move tests and add another test case --- .../EagerAndLazyTests.cs} | 46 +++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) rename src/NQuery.Tests/{ExpressionTests.Evaluation.cs => Evaluation/EagerAndLazyTests.cs} (79%) diff --git a/src/NQuery.Tests/ExpressionTests.Evaluation.cs b/src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs similarity index 79% rename from src/NQuery.Tests/ExpressionTests.Evaluation.cs rename to src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs index 7a912052..5c6a4a9e 100644 --- a/src/NQuery.Tests/ExpressionTests.Evaluation.cs +++ b/src/NQuery.Tests/Evaluation/EagerAndLazyTests.cs @@ -2,9 +2,9 @@ using NQuery.Symbols; -namespace NQuery.Tests +namespace NQuery.Tests.Evaluation { - public partial class ExpressionTests + public sealed class EagerAndLazyTests { private static InvocationResult EvaluateAndCountInvocations(string text) { @@ -21,7 +21,7 @@ private static InvocationResult EvaluateAndCountInvocations(string text) } [Fact] - public void Expression_Evaluation_Conversion_Once() + public void Evaluation_Conversion_Once() { var result = EvaluateAndCountInvocations("CAST(NON_NULL_INT32(ir) AS int64)"); Assert.Equal(42L, result.Result); @@ -29,7 +29,7 @@ public void Expression_Evaluation_Conversion_Once() } [Fact] - public void Expression_Evaluation_Unary_Once() + public void Evaluation_Unary_Once() { var result = EvaluateAndCountInvocations("~NON_NULL_INT32(ir)"); Assert.Equal(~42, result.Result); @@ -37,7 +37,7 @@ public void Expression_Evaluation_Unary_Once() } [Fact] - public void Expression_Evaluation_Binary_EagerOnce() + public void Evaluation_Binary_EagerOnce() { var result = EvaluateAndCountInvocations("NULL_INT32(ir) + NON_NULL_INT32(ir)"); Assert.Null(result.Result); @@ -46,7 +46,7 @@ public void Expression_Evaluation_Binary_EagerOnce() } [Fact] - public void Expression_Evaluation_FunctionInvocation_EagerOnce() + public void Evaluation_FunctionInvocation_EagerOnce() { var result = EvaluateAndCountInvocations("SUBSTRING('abc', NULL_INT32(ir), NON_NULL_INT32(ir))"); Assert.Null(result.Result); @@ -55,7 +55,7 @@ public void Expression_Evaluation_FunctionInvocation_EagerOnce() } [Fact] - public void Expression_Evaluation_MethodInvocation_Instance_Once() + public void Evaluation_MethodInvocation_Instance_Once() { var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); Assert.Equal(true, result.Result); @@ -63,7 +63,7 @@ public void Expression_Evaluation_MethodInvocation_Instance_Once() } [Fact] - public void Expression_Evaluation_MethodInvocation_Arguments_EagerOnce() + public void Evaluation_MethodInvocation_Arguments_EagerOnce() { var result = EvaluateAndCountInvocations("''.Substring(NULL_INT32(ir), NON_NULL_INT32(ir))"); Assert.Null(result.Result); @@ -72,7 +72,7 @@ public void Expression_Evaluation_MethodInvocation_Arguments_EagerOnce() } [Fact] - public void Expression_Evaluation_PropertyAccess_Once() + public void Evaluation_PropertyAccess_Once() { var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir).Equals(42)"); Assert.Equal(true, result.Result); @@ -80,7 +80,7 @@ public void Expression_Evaluation_PropertyAccess_Once() } [Fact] - public void Expression_Evaluation_IsNull_Once() + public void Evaluation_IsNull_Once() { var result = EvaluateAndCountInvocations("NON_NULL_INT32(ir) IS NOT NULL"); Assert.Equal(true, result.Result); @@ -88,16 +88,22 @@ public void Expression_Evaluation_IsNull_Once() } [Fact] - public void Expression_Evaluation_CaseWhen_LazyOnce_Simple() + public void Evaluation_CaseWhen_NonNullFunction_LazyOnce() { - var result = EvaluateAndCountInvocations("CASE WHEN NON_NULL_INT32(ir) = 42 THEN 42 ELSE NULL_INT32(ir) END"); + const string text = @" + CASE + WHEN NON_NULL_INT32(ir) = 42 THEN 42 + ELSE NULL_INT32(ir) + END"; + + var result = EvaluateAndCountInvocations(text); Assert.Equal(42, result.Result); Assert.Equal(1, result.NonNullInt32FunctionCount); Assert.Equal(0, result.NullInt32FunctionCount); } [Fact] - public void Expression_Evaluation_CaseWhen_LazyOnce_Complex() + public void Evaluation_CaseWhen_NonNullNestedFunction_LazyOnce() { const string text = @" CASE @@ -112,6 +118,20 @@ ELSE TO_INT32(NULL_INT32(ir)) Assert.Equal(0, result.NullInt32FunctionCount); } + [Fact] + public void Evaluation_CaseWhen_NullFunction_LazyOnce() + { + const string text = @" + CASE + WHEN NULL_INT32(ir) = 0 THEN 42 + ELSE 0 + END"; + + var result = EvaluateAndCountInvocations(text); + Assert.Equal(0, result.Result); + Assert.Equal(1, result.NullInt32FunctionCount); + } + private static int? NullInt32Function(InvocationResult ir) { ir.NullInt32FunctionCount++;