Skip to content

Commit

Permalink
Ports updated fix for #9128 to dev.
Browse files Browse the repository at this point in the history
- Update TaskLiftingExpressionVisitor to detect blocking introduced by TaskBlockingExpressionVisitor. :trollface:
- Switched Include Compiler to use TaskBlockingExpressionVisitor instead of .Result.
- Reverted previous changes to TaskBlockingExpressionVisitor.
  • Loading branch information
anpete committed Sep 14, 2017
1 parent 7ecc4e5 commit b5dcbb1
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 25 deletions.
26 changes: 18 additions & 8 deletions src/EFCore/Query/EntityQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Clauses.StreamedData;
using Remotion.Linq.Parsing;
using Microsoft.EntityFrameworkCore.Extensions.Internal;

namespace Microsoft.EntityFrameworkCore.Query
{
Expand Down Expand Up @@ -540,16 +541,25 @@ var entityTrackingInfos

private class IncludeRemovingExpressionVisitor : RelinqExpressionVisitor
{
protected override Expression VisitMethodCall(MethodCallExpression node)
=> IncludeCompiler.IsIncludeMethod(node)
? node.Arguments[1]
: base.VisitMethodCall(node);

protected override Expression VisitMember(MemberExpression node)
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
var newExpression = Visit(node.Expression);
if (IncludeCompiler.IsIncludeMethod(methodCallExpression))
{
return methodCallExpression.Arguments[1];
}

if (methodCallExpression.Method
.MethodIsClosedFormOf(TaskBlockingExpressionVisitor.ResultMethodInfo))
{
var newArguments = VisitAndConvert(methodCallExpression.Arguments, "VisitMethodCall");

if (newArguments != methodCallExpression.Arguments)
{
return newArguments[0];
}
}

return newExpression != node.Expression ? newExpression : node;
return base.VisitMethodCall(methodCallExpression);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using JetBrains.Annotations;

namespace Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal
{
Expand All @@ -24,13 +25,24 @@ public override Expression Visit(Expression expression)
var typeInfo = expression.Type.GetTypeInfo();

if (typeInfo.IsGenericType
&& typeInfo.GetGenericTypeDefinition() == typeof(Task<>))
&& (typeInfo.GetGenericTypeDefinition() == typeof(Task<>)))
{
return Expression.Property(expression, nameof(Task<object>.Result));
return Expression.Call(
_resultMethodInfo.MakeGenericMethod(typeInfo.GenericTypeArguments[0]),
expression);
}
}

return expression;
}

private static readonly MethodInfo _resultMethodInfo
= typeof(TaskBlockingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(Result));

internal static MethodInfo ResultMethodInfo => _resultMethodInfo;

[UsedImplicitly]
private static T Result<T>(Task<T> task) => task.GetAwaiter().GetResult();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using JetBrains.Annotations;
using Remotion.Linq.Parsing;
using Microsoft.EntityFrameworkCore.Extensions.Internal;

namespace Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal
{
Expand Down Expand Up @@ -150,6 +151,38 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return base.VisitMember(memberExpression);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method
.MethodIsClosedFormOf(TaskBlockingExpressionVisitor.ResultMethodInfo))
{
_taskExpressions.Add(
Expression.Lambda<Func<Task<object>>>(
Expression.Call(
_toObjectTask.MakeGenericMethod(
methodCallExpression.Method.ReturnType),
methodCallExpression.Arguments[0])));

if (CancellationTokenParameter == null)
{
Visit(methodCallExpression.Arguments[0]);
}

return
Expression.Convert(
Expression.ArrayAccess(
_resultsParameter,
Expression.Constant(_taskExpressions.Count - 1)),
methodCallExpression.Method.ReturnType);
}

return base.VisitMethodCall(methodCallExpression);
}

// Prune nodes

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Parsing;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -94,21 +95,21 @@ var entityParameter

var includeExpression
= blockExpressions.Last().Type == typeof(Task)
? (Expression)Expression.Property(
Expression.Call(
_includeAsyncMethodInfo
.MakeGenericMethod(targetQuerySourceReferenceExpression.Type),
EntityQueryModelVisitor.QueryContextParameter,
targetQuerySourceReferenceExpression,
Expression.NewArrayInit(typeof(object), propertyExpressions),
Expression.Lambda(
Expression.Block(blockExpressions),
? new TaskBlockingExpressionVisitor()
.Visit(
Expression.Call(
_includeAsyncMethodInfo
.MakeGenericMethod(targetQuerySourceReferenceExpression.Type),
EntityQueryModelVisitor.QueryContextParameter,
entityParameter,
_includedParameter,
_cancellationTokenParameter),
_cancellationTokenParameter),
nameof(Task<object>.Result))
targetQuerySourceReferenceExpression,
Expression.NewArrayInit(typeof(object), propertyExpressions),
Expression.Lambda(
Expression.Block(blockExpressions),
EntityQueryModelVisitor.QueryContextParameter,
entityParameter,
_includedParameter,
_cancellationTokenParameter),
_cancellationTokenParameter))
: Expression.Call(
_includeMethodInfo.MakeGenericMethod(targetQuerySourceReferenceExpression.Type),
EntityQueryModelVisitor.QueryContextParameter,
Expand Down

0 comments on commit b5dcbb1

Please sign in to comment.