Skip to content

Commit

Permalink
fixed #419
Browse files Browse the repository at this point in the history
  • Loading branch information
dadhi committed Oct 13, 2024
1 parent b70292a commit 706de20
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 72 deletions.
137 changes: 69 additions & 68 deletions src/FastExpressionCompiler/FastExpressionCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,7 @@ private static bool TryEmitCoalesceOperator(BinaryExpression expr, IReadOnlyList
if (!TryEmit(left, paramExprs, il, ref closure, setup, flags))
return false;

var exprType = expr.Type;
var leftType = left.Type;
if (leftType.IsValueType)
{
Expand All @@ -2489,7 +2490,7 @@ private static bool TryEmitCoalesceOperator(BinaryExpression expr, IReadOnlyList

il.Demit(OpCodes.Brfalse, labelFalse);

if (expr.Type == leftType.GetUnderlyingNullableTypeUnsafe())
if (exprType == Nullable.GetUnderlyingType(leftType))
{
// if the target expression type is of underlying nullable, and the left operand is not null,
// then extract its underlying value
Expand All @@ -2515,16 +2516,16 @@ private static bool TryEmitCoalesceOperator(BinaryExpression expr, IReadOnlyList
if (!TryEmit(right, paramExprs, il, ref closure, setup, flags))
return false;

if (right.Type != expr.Type)
if (right.Type != exprType)
il.TryEmitBoxOf(right.Type);

if (left.Type == expr.Type)
if (left.Type == exprType)
il.DmarkLabel(labelFalse);
else
{
il.Demit(OpCodes.Br, labelDone);
il.DmarkLabel(labelFalse); // todo: @bug? should we insert the boxing for the Nullable value type before the Castclass
il.Demit(OpCodes.Castclass, expr.Type);
il.Demit(OpCodes.Castclass, exprType);
il.DmarkLabel(labelDone);
}
}
Expand Down Expand Up @@ -4178,7 +4179,7 @@ private static bool TryEmitArithmeticAndOrAssign(

if (leftIsNullable)
{
// todo: @perf @simplify avoid the Dup and the Pop for this case
// todo: @perf @simplify avoid the Dup and the Pop for this case via storing and loading local var same as in `TryEmitArithmetic`
if (leftIsByAddress | objExpr != null)
{
var skipPopLeftDuppedInstance = il.DefineLabel();
Expand Down Expand Up @@ -5118,20 +5119,20 @@ private static bool TryEmitComparison(
#endif
ILGenerator il, ref ClosureInfo closure, CompilerFlags setup, ParentFlags parent)
{
var leftOpType = left.Type;
var leftIsNullable = leftOpType.IsNullable();
var rightOpType = right.Type;
var leftType = left.Type;
var leftIsNullable = leftType.IsNullable();
var rightType = right.Type;

// if one operand is `null` then coalesce the types
var rightIsNull = IsNullContainingExpression(right);
var comparingObjectWithNull = rightIsNull & rightOpType == typeof(object);
var comparingObjectWithNull = rightIsNull & rightType == typeof(object);
if (comparingObjectWithNull)
rightOpType = leftOpType;
rightType = leftType;

var leftIsNull = IsNullContainingExpression(left);
comparingObjectWithNull = leftIsNull & leftOpType == typeof(object);
comparingObjectWithNull = leftIsNull & leftType == typeof(object);
if (comparingObjectWithNull)
leftOpType = rightOpType;
leftType = rightType;

var operandParent = parent & ~ParentFlags.IgnoreResult & ~ParentFlags.InstanceAccess;

Expand All @@ -5143,19 +5144,19 @@ private static bool TryEmitComparison(
{
if (!TryEmit(left, paramExprs, il, ref closure, setup, operandParent))
return false;
EmitStoreAndLoadLocalVariableAddress(il, leftOpType);
EmitMethodCall(il, leftOpType.GetNullableHasValueGetterMethod());
EmitStoreAndLoadLocalVariableAddress(il, leftType);
EmitMethodCall(il, leftType.GetNullableHasValueGetterMethod());
if (nodeType == ExpressionType.Equal)
EmitEqualToZeroOrNull(il);
return il.EmitPopIfIgnoreResult(parent);
}

if (leftIsNull && rightOpType.IsNullable())
if (leftIsNull && rightType.IsNullable())
{
if (!TryEmit(right, paramExprs, il, ref closure, setup, operandParent))
return false;
EmitStoreAndLoadLocalVariableAddress(il, rightOpType);
EmitMethodCall(il, rightOpType.GetNullableHasValueGetterMethod());
EmitStoreAndLoadLocalVariableAddress(il, rightType);
EmitMethodCall(il, rightType.GetNullableHasValueGetterMethod());
if (nodeType == ExpressionType.Equal)
EmitEqualToZeroOrNull(il);
return il.EmitPopIfIgnoreResult(parent);
Expand All @@ -5173,12 +5174,12 @@ private static bool TryEmitComparison(

// save the left result to restore it later after the complex expression, see #422
if (rightIsComplexExpression = right.IsComplexExpression())
lVarIndex = EmitStoreLocalVariable(il, leftOpType);
lVarIndex = EmitStoreLocalVariable(il, leftType);
else if (leftIsNullable)
{
lVarIndex = EmitStoreAndLoadLocalVariableAddress(il, leftOpType);
il.Demit(OpCodes.Ldfld, leftOpType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
leftOpType = Nullable.GetUnderlyingType(leftOpType);
lVarIndex = EmitStoreAndLoadLocalVariableAddress(il, leftType);
il.Demit(OpCodes.Ldfld, leftType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
leftType = Nullable.GetUnderlyingType(leftType);
}
}

Expand All @@ -5188,8 +5189,8 @@ private static bool TryEmitComparison(
return false;

if (comparingObjectWithNull ||
(leftOpType != rightOpType && leftOpType.IsClass && rightOpType.IsClass &&
(leftOpType == typeof(object) | rightOpType == typeof(object))))
(leftType != rightType && leftType.IsClass && rightType.IsClass &&
(leftType == typeof(object) | rightType == typeof(object))))
{
if (!isEqualityOp)
return false;
Expand All @@ -5208,40 +5209,40 @@ private static bool TryEmitComparison(
var rVarIndex = -1;
if (rightIsComplexExpression)
{
rVarIndex = EmitStoreLocalVariable(il, rightOpType);
rVarIndex = EmitStoreLocalVariable(il, rightType);
if (!leftIsNullable)
EmitLoadLocalVariable(il, lVarIndex);
else
{
EmitLoadLocalVariableAddress(il, lVarIndex);
il.Demit(OpCodes.Ldfld, leftOpType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
leftOpType = Nullable.GetUnderlyingType(leftOpType);
il.Demit(OpCodes.Ldfld, leftType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
leftType = Nullable.GetUnderlyingType(leftType);
}

if (!rightOpType.IsNullable())
if (!rightType.IsNullable())
EmitLoadLocalVariable(il, rVarIndex);
else
{
EmitLoadLocalVariableAddress(il, rVarIndex);
il.Demit(OpCodes.Ldfld, rightOpType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
rightOpType = Nullable.GetUnderlyingType(rightOpType);
il.Demit(OpCodes.Ldfld, rightType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
rightType = Nullable.GetUnderlyingType(rightType);
}
}
else if (leftIsNull)
{
// here we're handling only non-nullable right, the nullable right with null left is handled above
rVarIndex = EmitStoreLocalVariable(il, rightOpType);
// here we're handling only non-nullable right, the nullable right with null left is handled above
rVarIndex = EmitStoreLocalVariable(il, rightType);
il.Demit(OpCodes.Ldnull);
EmitLoadLocalVariable(il, rVarIndex);
}
else if (rightOpType.IsNullable())
else if (rightType.IsNullable())
{
rVarIndex = EmitStoreAndLoadLocalVariableAddress(il, rightOpType);
il.Demit(OpCodes.Ldfld, rightOpType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
rightOpType = Nullable.GetUnderlyingType(rightOpType);
rVarIndex = EmitStoreAndLoadLocalVariableAddress(il, rightType);
il.Demit(OpCodes.Ldfld, rightType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
rightType = Nullable.GetUnderlyingType(rightType);
}

if (!leftOpType.IsPrimitive && !leftOpType.IsEnum)
if (!leftType.IsPrimitive && !leftType.IsEnum)
{
var methodName
= nodeType == ExpressionType.Equal ? "op_Equality"
Expand All @@ -5254,9 +5255,9 @@ var methodName
if (methodName == null)
return false;

var method = FindBinaryOperandMethod(methodName, leftOpType, leftOpType, rightOpType, typeof(bool));
if (method == null & leftOpType != rightOpType)
method = FindBinaryOperandMethod(methodName, rightOpType, leftOpType, rightOpType, typeof(bool));
var method = FindBinaryOperandMethod(methodName, leftType, leftType, rightType, typeof(bool));
if (method == null & leftType != rightType)
method = FindBinaryOperandMethod(methodName, rightType, leftType, rightType, typeof(bool));
if (method != null)
{
var ok = EmitMethodCall(il, method);
Expand Down Expand Up @@ -5296,17 +5297,17 @@ var methodName
break;
case ExpressionType.GreaterThanOrEqual:
// simplifying by using the LessThen (Clt) and comparing with negative outcome (Ceq 0)
if (leftOpType.IsUnsigned() && rightOpType.IsUnsigned() ||
(leftOpType.IsFloatingPoint() || rightOpType.IsFloatingPoint()))
if (leftType.IsUnsigned() && rightType.IsUnsigned() ||
(leftType.IsFloatingPoint() || rightType.IsFloatingPoint()))
il.Demit(OpCodes.Clt_Un);
else
il.Demit(OpCodes.Clt);
EmitEqualToZeroOrNull(il);
break;
case ExpressionType.LessThanOrEqual:
// simplifying by using the GreaterThen (Cgt) and comparing with negative outcome (Ceq 0)
if (leftOpType.IsUnsigned() && rightOpType.IsUnsigned() ||
(leftOpType.IsFloatingPoint() || rightOpType.IsFloatingPoint()))
if (leftType.IsUnsigned() && rightType.IsUnsigned() ||
(leftType.IsFloatingPoint() || rightType.IsFloatingPoint()))
il.Demit(OpCodes.Cgt_Un);
else
il.Demit(OpCodes.Cgt);
Expand All @@ -5320,7 +5321,7 @@ var methodName
nullableCheck:
if (leftIsNullable)
{
var leftNullableHasValueGetterMethod = left.Type.GetNullableHasValueGetterMethod();
var leftNullableHasValueGetterMethod = left.Type.GetNullableHasValueGetterMethod(); // asking from the left.Type because leftType now is set to the underlying type

EmitLoadLocalVariableAddress(il, lVarIndex);
EmitMethodCall(il, leftNullableHasValueGetterMethod);
Expand Down Expand Up @@ -5399,19 +5400,21 @@ private static bool TryEmitArithmetic(Expression left, Expression right, Express
var leftNoValueLabel = default(Label);
var leftType = left.Type;
var leftIsNullable = leftType.IsNullable();
var leftVar = -1;
var leftValueVar = -1;
if (leftIsNullable)
{
leftNoValueLabel = il.DefineLabel();
if (!TryEmit(left, paramExprs, il, ref closure, setup, flags | ParentFlags.InstanceCall))
return false;

if (!closure.LastEmitIsAddress)
EmitStoreAndLoadLocalVariableAddress(il, leftType);

il.Demit(OpCodes.Dup);
leftVar = EmitStoreAndLoadLocalVariableAddress(il, leftType);
EmitMethodCall(il, leftType.GetNullableHasValueGetterMethod());
il.Demit(OpCodes.Brfalse, leftNoValueLabel);

EmitLoadLocalVariableAddress(il, leftVar);
il.Demit(OpCodes.Ldfld, leftType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
leftValueVar = EmitStoreLocalVariable(il, Nullable.GetUnderlyingType(leftType));
}
else if (!TryEmit(left, paramExprs, il, ref closure, setup, flags))
return false;
Expand All @@ -5424,57 +5427,58 @@ private static bool TryEmitArithmetic(Expression left, Expression right, Express
}
else
{
var rightType = right.Type;

// stores the left value for later to restore it after the complex right emit,
// Stores the left value for later to restore it after the complex right emit,
// it prevents the problems in cases of right being a block, try-catch, etc.
// see `Using_try_finally_as_arithmetic_operand_use_void_block_in_finally`
var leftVar = -1;
if (right.NodeType.IsBlockLikeOrConditional() || right.NodeType == ExpressionType.Invoke)
leftVar = EmitStoreLocalVariable(il, leftType);
var rightType = right.Type;
if (leftValueVar == -1 && right.IsComplexExpression())
leftValueVar = EmitStoreLocalVariable(il, leftType);

var rightVar = -1;
var rightValueVar = -1;
rightIsNullable = rightType.IsNullable();
if (rightIsNullable)
{
rightNoValueLabel = il.DefineLabel();
if (!TryEmit(right, paramExprs, il, ref closure, setup, flags | ParentFlags.InstanceCall))
return false;

if (!closure.LastEmitIsAddress)
EmitStoreAndLoadLocalVariableAddress(il, rightType);
rightVar = EmitStoreAndLoadLocalVariableAddress(il, rightType);

il.Demit(OpCodes.Dup);
EmitMethodCall(il, rightType.GetNullableHasValueGetterMethod());
il.Demit(OpCodes.Brfalse, rightNoValueLabel);

EmitLoadLocalVariableAddress(il, rightVar);
il.Demit(OpCodes.Ldfld, rightType.GetNullableValueUnsafeAkaGetValueOrDefaultMethod());
rightValueVar = EmitStoreLocalVariable(il, Nullable.GetUnderlyingType(rightType));
}
else if (!TryEmit(right, paramExprs, il, ref closure, setup, flags))
return false;

if (leftVar != -1)
// Means that it was complex right and the result of the left operation was stored
// and should be restored now, so the left and right go in order before the arithmetic operation
if (leftValueVar != -1)
{
// restore the left and right in proper order for operation
var rightVar = EmitStoreLocalVariable(il, rightType);
EmitLoadLocalVariable(il, leftVar);
EmitLoadLocalVariable(il, rightVar);
if (rightValueVar == -1)
rightValueVar = EmitStoreLocalVariable(il, rightType);
EmitLoadLocalVariable(il, leftValueVar);
EmitLoadLocalVariable(il, rightValueVar);
}

if (!TryEmitArithmeticOperation(leftType, rightType, nodeType, exprType, il))
return false;
}

if (leftIsNullable | rightIsNullable) // todo: @clarify that the emitted code is correct
if (leftIsNullable | rightIsNullable)
{
var valueLabel = il.DefineLabel();
il.Demit(OpCodes.Br, valueLabel);

if (rightIsNullable)
il.DmarkLabel(rightNoValueLabel);
il.Demit(OpCodes.Pop);

if (leftIsNullable)
il.DmarkLabel(leftNoValueLabel);
il.Demit(OpCodes.Pop);

if (exprType.IsNullable())
{
Expand Down Expand Up @@ -6103,9 +6107,6 @@ internal static bool IsNullable(this Type type) =>
internal static Type GetUnderlyingNullableTypeOrNull(this Type type) =>
(type.IsValueType & type.IsGenericType) && type.GetGenericTypeDefinition() == typeof(Nullable<>) ? type.GetGenericArguments()[0] : null;

[MethodImpl((MethodImplOptions)256)]
internal static Type GetUnderlyingNullableTypeUnsafe(this Type type) => type.GetGenericArguments()[0];

public static string GetArithmeticBinaryOperatorMethodName(this ExpressionType nodeType) =>
nodeType switch
{
Expand Down
Loading

0 comments on commit 706de20

Please sign in to comment.