Skip to content

Commit

Permalink
Pass on delegate target instance if required
Browse files Browse the repository at this point in the history
  • Loading branch information
dallmair authored and terrajobst committed Oct 2, 2024
1 parent 39061c9 commit 6909dd1
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 24 deletions.
14 changes: 7 additions & 7 deletions src/NQuery.Tests/Evaluation/EvaluationTest.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
namespace NQuery.Tests.Evaluation
{
public class EvaluationTest
public abstract class EvaluationTest
{
protected static void AssertProduces<T>(string text, T[] expected)
protected static void AssertProduces<T>(string text, T[] expected, DataContext dataContext = null)
{
var expectedColumns = new[] { typeof(T) };
var expectedRows = new object[expected.Length][];

for (var i = 0; i < expected.Length; i++)
expectedRows[i] = new object[] { expected[i] };

AssertProduces(text, expectedColumns, expectedRows);
AssertProduces(text, expectedColumns, expectedRows, dataContext);
}

protected static void AssertProduces<T1, T2>(string text, (T1, T2)[] expected)
protected static void AssertProduces<T1, T2>(string text, (T1, T2)[] expected, DataContext dataContext = null)
{
var expectedColumns = new[] { typeof(T1), typeof(T2) };
var expectedRows = new object[expected.Length][];

for (var i = 0; i < expected.Length; i++)
expectedRows[i] = new object[] { expected[i].Item1, expected[i].Item2 };

AssertProduces(text, expectedColumns, expectedRows);
AssertProduces(text, expectedColumns, expectedRows, dataContext);
}

private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows)
private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows, DataContext dataContext)
{
var dataContext = NorthwindDataContext.Instance;
dataContext ??= NorthwindDataContext.Instance;
var query = Query.Create(dataContext, text);
using var data = query.ExecuteReader();

Expand Down
91 changes: 91 additions & 0 deletions src/NQuery.Tests/Evaluation/FunctionInvocationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using NQuery.Symbols;

namespace NQuery.Tests.Evaluation
{
public sealed class FunctionInvocationTests : EvaluationTest
{
private readonly DataContext _dataContext;

public FunctionInvocationTests()
{
_dataContext = DataContext.Default.AddFunctions(
new FunctionSymbol<int>(nameof(StaticFunction0), StaticFunction0),
new FunctionSymbol<int, int>(nameof(StaticFunction1), StaticFunction1),
new FunctionSymbol<int, int, int>(nameof(StaticFunction2), StaticFunction2),
new FunctionSymbol<int, int, int, int>(nameof(StaticFunction3), StaticFunction3),
new FunctionSymbol<int>(nameof(InstanceFunction0), InstanceFunction0),
new FunctionSymbol<int, int>(nameof(InstanceFunction1), InstanceFunction1),
new FunctionSymbol<int, int, int>(nameof(InstanceFunction2), InstanceFunction2),
new FunctionSymbol<int, int, int, int>(nameof(InstanceFunction3), InstanceFunction3)
);
}

[Theory]
[InlineData(nameof(StaticFunction0) + "()")]
[InlineData(nameof(StaticFunction1) + "(42)")]
[InlineData(nameof(StaticFunction2) + "(40, 2)")]
[InlineData(nameof(StaticFunction3) + "(20, 20, 2)")]
public void Evaluation_FunctionInvocationExpression_StaticFunction(string functionInvocation)
{
var text = "SELECT " + functionInvocation;

var expected = new[] { 42 };

AssertProduces(text, expected, _dataContext);
}

[Theory]
[InlineData(nameof(InstanceFunction0) + "()")]
[InlineData(nameof(InstanceFunction1) + "(42)")]
[InlineData(nameof(InstanceFunction2) + "(40, 2)")]
[InlineData(nameof(InstanceFunction3) + "(20, 20, 2)")]
public void Evaluation_FunctionInvocationExpression_InstanceFunction(string functionInvocation)
{
var text = "SELECT " + functionInvocation;

var expected = new[] { GetHashCode() + 42 };

AssertProduces(text, expected, _dataContext);
}

private static int StaticFunction0()
{
return 42;
}

private static int StaticFunction1(int arg)
{
return arg;
}

private static int StaticFunction2(int arg1, int arg2)
{
return arg1 + arg2;
}

private static int StaticFunction3(int arg1, int arg2, int arg3)
{
return arg1 + arg2 + arg3;
}

private int InstanceFunction0()
{
return GetHashCode() + 42;
}

private int InstanceFunction1(int arg)
{
return GetHashCode() + arg;
}

private int InstanceFunction2(int arg1, int arg2)
{
return GetHashCode() + arg1 + arg2;
}

private int InstanceFunction3(int arg1, int arg2, int arg3)
{
return GetHashCode() + arg1 + arg2 + arg3;
}
}
}
41 changes: 24 additions & 17 deletions src/NQuery/Symbols/FunctionSymbol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ protected FunctionSymbol(string name, Type type, params ParameterSymbol[] parame
{
}

public abstract Expression CreateInvocation(IEnumerable<Expression> arguments);
public Expression CreateInvocation(IEnumerable<Expression> arguments)
{
var function = FunctionDelegate;
var instance = function.Target == null ? null : Expression.Constant(function.Target);
return Expression.Call(instance, function.Method, arguments);
}

public override SymbolKind Kind
{
get { return SymbolKind.Function; }
}

protected abstract Delegate FunctionDelegate { get; }
}

public sealed class FunctionSymbol<TResult> : FunctionSymbol
Expand All @@ -30,12 +37,12 @@ public FunctionSymbol(string name, Func<TResult> function)
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<TResult> Function { get; }
}

public sealed class FunctionSymbol<T, TResult> : FunctionSymbol
Expand All @@ -52,12 +59,12 @@ public FunctionSymbol(string name, string argumentName, Func<T, TResult> functio
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T, TResult> Function { get; }
}

public sealed class FunctionSymbol<T1, T2, TResult> : FunctionSymbol
Expand All @@ -73,12 +80,12 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2,
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T1, T2, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T1, T2, TResult> Function { get; }
}

public sealed class FunctionSymbol<T1, T2, T3, TResult> : FunctionSymbol
Expand All @@ -94,11 +101,11 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2,
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T1, T2, T3, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T1, T2, T3, TResult> Function { get; }
}
}

0 comments on commit 6909dd1

Please sign in to comment.