Skip to content
This repository has been archived by the owner on Nov 1, 2020. It is now read-only.

Commit

Permalink
Wasm: add support for overflow checks on signed and unsigned ints mul…
Browse files Browse the repository at this point in the history
…tiply (#8259)

* wasm-ovf-unsigned-int

* refactor for stack kind tests and add signed check

* use llvm intrinsics
  • Loading branch information
yowl authored Aug 26, 2020
1 parent 266ae09 commit 34af164
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 18 deletions.
93 changes: 75 additions & 18 deletions src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2782,15 +2782,15 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
// then
builder.PositionAtEnd(notFatBranch);
ExceptionRegion currentTryRegion = GetCurrentTryRegion();
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs, ref nextInstrBlock);
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// else
builder.PositionAtEnd(fatBranch);
var fnWithDict = builder.BuildCast(LLVMOpcode.LLVMBitCast, fn, LLVMTypeRef.CreatePointer(GetLLVMSignatureForMethod(runtimeDeterminedMethod.Signature, true), 0), "fnWithDict");
var dictDereffed = builder.BuildLoad(builder.BuildLoad( dict, "l1"), "l2");
llvmArgs.Insert(needsReturnSlot ? 2 : 1, dictDereffed);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs, ref nextInstrBlock);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// endif
Expand All @@ -2806,7 +2806,7 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}
else
{
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs, ref nextInstrBlock);
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs.ToArray(), ref nextInstrBlock);
}

if (!returnType.IsVoid)
Expand All @@ -2829,23 +2829,21 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}

LLVMValueRef CallOrInvoke(bool fromLandingPad, LLVMBuilderRef builder, ExceptionRegion currentTryRegion,
LLVMValueRef fn, List<LLVMValueRef> llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
LLVMValueRef fn, LLVMValueRef[] llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
{
LLVMValueRef retVal;
if (currentTryRegion == null || fromLandingPad) // not handling exceptions that occur in the LLVM landing pad determining the EH handler
{
retVal = builder.BuildCall(fn, llvmArgs.ToArray(), string.Empty);
retVal = builder.BuildCall(fn, llvmArgs, string.Empty);
}
else
{
nextInstrBlock = _currentFunclet.AppendBasicBlock(String.Format("Try{0:X}", _currentOffset));

retVal = builder.BuildInvoke(fn, llvmArgs.ToArray(),
retVal = builder.BuildInvoke(fn, llvmArgs,
nextInstrBlock, GetOrCreateLandingPad(currentTryRegion), string.Empty);

_curBasicBlock = nextInstrBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
AddInternalBasicBlock(nextInstrBlock);
builder.PositionAtEnd(_curBasicBlock);
}
return retVal;
Expand Down Expand Up @@ -3902,10 +3900,26 @@ private void ImportBinaryOperation(ILOpcode opcode)

result = _builder.BuildSub(left, right, "sub");
break;
// TODO: Overflow checks
case ILOpcode.mul_ovf:
case ILOpcode.mul_ovf_un:
result = _builder.BuildMul(left, right, "mul");
Debug.Assert(CanPerformUnsignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
result = BuildMulOverflowCheck(left, right, "umul", LLVMTypeRef.Int32);
}
else
{
result = BuildMulOverflowCheck(left, right, "umul", LLVMTypeRef.Int64);
}
break;
case ILOpcode.mul_ovf:
if (Is32BitStackValue(op1.Kind))
{
result = BuildMulOverflowCheck(left, right, "smul", LLVMTypeRef.Int32);
}
else
{
result = BuildMulOverflowCheck(left, right, "smul", LLVMTypeRef.Int64);
}
break;

default:
Expand All @@ -3922,6 +3936,32 @@ private void ImportBinaryOperation(ILOpcode opcode)
PushExpression(kind, "binop", result, type);
}

LLVMValueRef BuildMulOverflowCheck(LLVMValueRef left, LLVMValueRef right, string mulOp, LLVMTypeRef intType)
{
LLVMValueRef mulFunction = GetOrCreateLLVMFunction("llvm." + mulOp + ".with.overflow." + (intType == LLVMTypeRef.Int32 ? "i32" : "i64"), LLVMTypeRef.CreateFunction(
LLVMTypeRef.CreateStruct(new[] { intType, LLVMTypeRef.Int1}, false), new[] { intType, intType }));
LLVMValueRef mulRes = _builder.BuildCall(mulFunction, new[] {left, right});
var overflow = _builder.BuildExtractValue(mulRes, 1);
LLVMBasicBlockRef overflowBlock = _currentFunclet.AppendBasicBlock("ovf");
LLVMBasicBlockRef noOverflowBlock = _currentFunclet.AppendBasicBlock("no_ovf");
_builder.BuildCondBr(overflow, overflowBlock, noOverflowBlock);

_builder.PositionAtEnd(overflowBlock);
CallOrInvokeThrowException(_builder, "ThrowHelpers", "ThrowOverflowException");

_builder.PositionAtEnd(noOverflowBlock);
LLVMValueRef result = _builder.BuildExtractValue(mulRes, 0);
AddInternalBasicBlock(noOverflowBlock);
return result;
}

void AddInternalBasicBlock(LLVMBasicBlockRef basicBlock)
{
_curBasicBlock = basicBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
}

bool CanPerformSignedOverflowOperations(StackValueKind kind)
{
return kind == StackValueKind.Int32 || kind == StackValueKind.Int64;
Expand Down Expand Up @@ -3995,7 +4035,7 @@ void BuildAddOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValue
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), left, right }, ref nextInstrBlock);
}

void BuildSubOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValueRef left, LLVMValueRef right, LLVMTypeRef sizeTypeRef, LLVMValueRef maxValue, LLVMValueRef minValue, bool signed)
Expand Down Expand Up @@ -4028,7 +4068,7 @@ void BuildSubOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValue
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), left, right }, ref nextInstrBlock);
}

private void BuildOverflowCheck(LLVMBuilderRef builder, LLVMValueRef compOperand, LLVMIntPredicate predicate,
Expand Down Expand Up @@ -4542,7 +4582,7 @@ private void ThrowIfNull(LLVMValueRef entry)
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new List<LLVMValueRef> { GetShadowStack(), entry }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new LLVMValueRef[] { GetShadowStack(), entry }, ref nextInstrBlock);
}

private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCheckFunction)
Expand Down Expand Up @@ -4585,16 +4625,33 @@ private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCh
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), value }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), value }, ref nextInstrBlock);
}

private void ThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName, LLVMValueRef throwingFunction)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
}

/// <summary>
/// Calls or invokes the call to throwing the exception so it can be caught in the caller
/// </summary>
private void CallOrInvokeThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, builder, GetCurrentTryRegion(), fn, new LLVMValueRef[] {GetShadowStack()}, ref nextInstrBlock);
builder.BuildUnreachable();
}

LLVMValueRef GetHelperLlvmMethod(string helperClass, string helperMethodName)
{
MetadataType helperType = _compilation.TypeSystemContext.SystemModule.GetKnownType("Internal.Runtime.CompilerHelpers", helperClass);
MethodDesc helperMethod = helperType.GetKnownMethod(helperMethodName, null);
LLVMValueRef fn = LLVMFunctionForMethod(helperMethod, helperMethod, null, false, null, null, out bool hasHiddenParam, out LLVMValueRef dictPtrPtrStore, out LLVMValueRef fatFunctionPtr);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
return fn;
}

private LLVMValueRef GetInstanceFieldAddress(StackEntry objectEntry, FieldDesc field)
Expand Down
Loading

0 comments on commit 34af164

Please sign in to comment.