Skip to content

Commit

Permalink
Ensure Vector256.Dot produces a V256 result (#88712)
Browse files Browse the repository at this point in the history
* Ensure Vector256.Dot produces a V256 result

* Apply formatting patch
  • Loading branch information
tannergooding committed Jul 13, 2023
1 parent e0acb9d commit 620bd3e
Showing 1 changed file with 41 additions and 38 deletions.
79 changes: 41 additions & 38 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4462,29 +4462,30 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
assert(comp->compIsaSupportedDebugOnly(InstructionSet_AVX));

// We will be constructing the following parts:
// idx = CNS_INT int 0xF1
// idx = CNS_INT int 0xFF
// /--* op1 simd16
// +--* op2 simd16
// +--* idx int
// tmp1 = * HWINTRINSIC simd32 T DotProduct
// /--* tmp1 simd32
// * STORE_LCL_VAR simd32
// tmp1 = LCL_VAR simd32
// /--* tmp1 simd32
// tmp1 = * HWINTRINSIC simd16 T GetLower
// tmp2 = LCL_VAR simd32
// /--* tmp2 simd16
// tmp2 = * HWINTRINSIC simd16 T GetUpper
// /--* tmp1 simd16
// +--* tmp2 simd16
// node = * HWINTRINSIC simd16 T Add
// tmp3 = LCL_VAR simd32
// /--* tmp2 simd32
// +--* tmp3 simd32
// +--* CNS_INT int 0x01
// tmp2 = * HWINTRINSIC simd32 T Permute
// /--* tmp1 simd32
// +--* tmp2 simd32
// node = * HWINTRINSIC simd32 T Add

// This is roughly the following managed code:
// var tmp1 = Avx.DotProduct(op1, op2, 0xFF);
// var tmp2 = tmp1.GetUpper();
// return Sse.Add(tmp1, tmp2);
// var tmp2 = Avx.Permute2x128(tmp1, tmp1, 0x4E);
// return Avx.Add(tmp1, tmp2);

idx = comp->gtNewIconNode(0xF1, TYP_INT);
idx = comp->gtNewIconNode(0xFF, TYP_INT);
BlockRange().InsertBefore(node, idx);

tmp1 = comp->gtNewSimdHWIntrinsicNode(simdType, op1, op2, idx, NI_AVX_DotProduct, simdBaseJitType,
Expand All @@ -4500,27 +4501,30 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
tmp2 = comp->gtClone(tmp1);
BlockRange().InsertAfter(tmp1, tmp2);

tmp3 = comp->gtNewSimdGetUpperNode(TYP_SIMD16, tmp2, simdBaseJitType, simdSize);
tmp3 = comp->gtClone(tmp2);
BlockRange().InsertAfter(tmp2, tmp3);
LowerNode(tmp3);

tmp1 = comp->gtNewSimdGetLowerNode(TYP_SIMD16, tmp1, simdBaseJitType, simdSize);
BlockRange().InsertAfter(tmp3, tmp1);
LowerNode(tmp1);
idx = comp->gtNewIconNode(0x01, TYP_INT);
BlockRange().InsertAfter(tmp3, idx);

tmp2 = comp->gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, tmp3, tmp1, simdBaseJitType, 16);
BlockRange().InsertAfter(tmp1, tmp2);
tmp2 = comp->gtNewSimdHWIntrinsicNode(simdType, tmp2, tmp3, idx, NI_AVX_Permute2x128, simdBaseJitType,
simdSize);
BlockRange().InsertAfter(idx, tmp2);
LowerNode(tmp2);

tmp1 = comp->gtNewSimdBinOpNode(GT_ADD, simdType, tmp1, tmp2, simdBaseJitType, simdSize);
BlockRange().InsertAfter(tmp2, tmp1);

// We're producing a vector result, so just return the result directly
LIR::Use use;

if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(tmp2);
use.ReplaceWith(tmp1);
}

BlockRange().Remove(node);
return LowerNode(tmp2);
return LowerNode(tmp1);
}

case TYP_DOUBLE:
Expand Down Expand Up @@ -4999,21 +5003,19 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
// /--* tmp1 simd32
// * STORE_LCL_VAR simd32
// tmp1 = LCL_VAR simd32
// /--* tmp1 simd32
// tmp1 = * HWINTRINSIC simd16 T GetLower
// tmp2 = LCL_VAR simd32
// /--* tmp2 simd32
// tmp3 = * HWINTRINSIC simd16 T GetUpper
// /--* tmp1 simd16
// +--* tmp3 simd16
// tmp1 = * HWINTRINSIC simd16 T Add
// +--* CNS_INT int 0x01
// tmp2 = * HWINTRINSIC simd32 float Permute
// /--* tmp1 simd32
// +--* tmp2 simd32
// tmp1 = * HWINTRINSIC simd32 T Add
// ...

// This is roughly the following managed code:
// ...
// var tmp2 = tmp1;
// tmp3 = tmp2.GetUpper();
// var tmp1 = Isa.Add(tmp1.GetLower(), tmp2);
// var tmp2 = Isa.Permute2x128(tmp1, tmp2, 0x01);
// tmp1 = Isa.Add(tmp1, tmp2);
// ...

assert(simdBaseType != TYP_FLOAT);
Expand All @@ -5026,20 +5028,21 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
tmp2 = comp->gtClone(tmp1);
BlockRange().InsertAfter(tmp1, tmp2);

tmp3 = comp->gtNewSimdGetUpperNode(TYP_SIMD16, tmp2, simdBaseJitType, simdSize);
tmp3 = comp->gtClone(tmp2);
BlockRange().InsertAfter(tmp2, tmp3);
LowerNode(tmp3);

tmp1 = comp->gtNewSimdGetLowerNode(TYP_SIMD16, tmp1, simdBaseJitType, simdSize);
BlockRange().InsertAfter(tmp3, tmp1);
LowerNode(tmp1);
idx = comp->gtNewIconNode(0x01, TYP_INT);
BlockRange().InsertAfter(tmp3, idx);

tmp2 = comp->gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, tmp3, tmp1, simdBaseJitType, 16);
BlockRange().InsertAfter(tmp1, tmp2);
NamedIntrinsic permute2x128 = (simdBaseType == TYP_DOUBLE) ? NI_AVX_Permute2x128 : NI_AVX2_Permute2x128;

tmp2 = comp->gtNewSimdHWIntrinsicNode(simdType, tmp2, tmp3, idx, permute2x128, simdBaseJitType, simdSize);
BlockRange().InsertAfter(idx, tmp2);
LowerNode(tmp2);

node->SetSimdSize(16);
tmp1 = tmp2;
tmp1 = comp->gtNewSimdBinOpNode(GT_ADD, simdType, tmp1, tmp2, simdBaseJitType, simdSize);
BlockRange().InsertAfter(tmp2, tmp1);
LowerNode(tmp1);
}

// We're producing a vector result, so just return the result directly
Expand Down

0 comments on commit 620bd3e

Please sign in to comment.