Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For function calls, pass on the delegate's target instance if required #68

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/NQuery.Tests/Evaluation/EvaluationTest.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
namespace NQuery.Tests.Evaluation
{
public class EvaluationTest
public abstract class EvaluationTest
{
protected static void AssertProduces<T>(string text, T[] expected)
protected static void AssertProduces<T>(string text, T[] expected, DataContext dataContext = null)
{
var expectedColumns = new[] { typeof(T) };
var expectedRows = new object[expected.Length][];

for (var i = 0; i < expected.Length; i++)
expectedRows[i] = new object[] { expected[i] };

AssertProduces(text, expectedColumns, expectedRows);
AssertProduces(text, expectedColumns, expectedRows, dataContext);
}

protected static void AssertProduces<T1, T2>(string text, (T1, T2)[] expected)
protected static void AssertProduces<T1, T2>(string text, (T1, T2)[] expected, DataContext dataContext = null)
{
var expectedColumns = new[] { typeof(T1), typeof(T2) };
var expectedRows = new object[expected.Length][];

for (var i = 0; i < expected.Length; i++)
expectedRows[i] = new object[] { expected[i].Item1, expected[i].Item2 };

AssertProduces(text, expectedColumns, expectedRows);
AssertProduces(text, expectedColumns, expectedRows, dataContext);
}

private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows)
private static void AssertProduces(string text, Type[] expectedColumns, object[][] expectedRows, DataContext dataContext)
{
var dataContext = NorthwindDataContext.Instance;
dataContext ??= NorthwindDataContext.Instance;
var query = Query.Create(dataContext, text);
using var data = query.ExecuteReader();

Expand Down
91 changes: 91 additions & 0 deletions src/NQuery.Tests/Evaluation/FunctionInvocationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using NQuery.Symbols;

namespace NQuery.Tests.Evaluation
{
public sealed class FunctionInvocationTests : EvaluationTest
{
private readonly DataContext _dataContext;

public FunctionInvocationTests()
{
_dataContext = DataContext.Default.AddFunctions(
new FunctionSymbol<int>(nameof(StaticFunction0), StaticFunction0),
new FunctionSymbol<int, int>(nameof(StaticFunction1), StaticFunction1),
new FunctionSymbol<int, int, int>(nameof(StaticFunction2), StaticFunction2),
new FunctionSymbol<int, int, int, int>(nameof(StaticFunction3), StaticFunction3),
new FunctionSymbol<int>(nameof(InstanceFunction0), InstanceFunction0),
new FunctionSymbol<int, int>(nameof(InstanceFunction1), InstanceFunction1),
new FunctionSymbol<int, int, int>(nameof(InstanceFunction2), InstanceFunction2),
new FunctionSymbol<int, int, int, int>(nameof(InstanceFunction3), InstanceFunction3)
);
}

[Theory]
[InlineData(nameof(StaticFunction0) + "()")]
[InlineData(nameof(StaticFunction1) + "(42)")]
[InlineData(nameof(StaticFunction2) + "(40, 2)")]
[InlineData(nameof(StaticFunction3) + "(20, 20, 2)")]
public void Evaluation_FunctionInvocationExpression_StaticFunction(string functionInvocation)
{
var text = "SELECT " + functionInvocation;

var expected = new[] { 42 };

AssertProduces(text, expected, _dataContext);
}

[Theory]
[InlineData(nameof(InstanceFunction0) + "()")]
[InlineData(nameof(InstanceFunction1) + "(42)")]
[InlineData(nameof(InstanceFunction2) + "(40, 2)")]
[InlineData(nameof(InstanceFunction3) + "(20, 20, 2)")]
public void Evaluation_FunctionInvocationExpression_InstanceFunction(string functionInvocation)
{
var text = "SELECT " + functionInvocation;

var expected = new[] { GetHashCode() + 42 };

AssertProduces(text, expected, _dataContext);
}

private static int StaticFunction0()
{
return 42;
}

private static int StaticFunction1(int arg)
{
return arg;
}

private static int StaticFunction2(int arg1, int arg2)
{
return arg1 + arg2;
}

private static int StaticFunction3(int arg1, int arg2, int arg3)
{
return arg1 + arg2 + arg3;
}

private int InstanceFunction0()
{
return GetHashCode() + 42;
}

private int InstanceFunction1(int arg)
{
return GetHashCode() + arg;
}

private int InstanceFunction2(int arg1, int arg2)
{
return GetHashCode() + arg1 + arg2;
}

private int InstanceFunction3(int arg1, int arg2, int arg3)
{
return GetHashCode() + arg1 + arg2 + arg3;
}
}
}
41 changes: 24 additions & 17 deletions src/NQuery/Symbols/FunctionSymbol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ protected FunctionSymbol(string name, Type type, params ParameterSymbol[] parame
{
}

public abstract Expression CreateInvocation(IEnumerable<Expression> arguments);
public Expression CreateInvocation(IEnumerable<Expression> arguments)
{
var function = FunctionDelegate;
var instance = function.Target == null ? null : Expression.Constant(function.Target);
return Expression.Call(instance, function.Method, arguments);
}

public override SymbolKind Kind
{
get { return SymbolKind.Function; }
}

protected abstract Delegate FunctionDelegate { get; }
}

public sealed class FunctionSymbol<TResult> : FunctionSymbol
Expand All @@ -30,12 +37,12 @@ public FunctionSymbol(string name, Func<TResult> function)
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<TResult> Function { get; }
}

public sealed class FunctionSymbol<T, TResult> : FunctionSymbol
Expand All @@ -52,12 +59,12 @@ public FunctionSymbol(string name, string argumentName, Func<T, TResult> functio
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T, TResult> Function { get; }
}

public sealed class FunctionSymbol<T1, T2, TResult> : FunctionSymbol
Expand All @@ -73,12 +80,12 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2,
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T1, T2, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T1, T2, TResult> Function { get; }
}

public sealed class FunctionSymbol<T1, T2, T3, TResult> : FunctionSymbol
Expand All @@ -94,11 +101,11 @@ public FunctionSymbol(string name, string parameterName1, string parameterName2,
Function = function;
}

public override Expression CreateInvocation(IEnumerable<Expression> arguments)
public Func<T1, T2, T3, TResult> Function { get; }

protected override Delegate FunctionDelegate
{
return Expression.Call(Function.Method, arguments);
get { return Function; }
}

public Func<T1, T2, T3, TResult> Function { get; }
}
}
Loading