Skip to content

Commit

Permalink
[mono] Add Vector128 Dot intrinsics for Amd64 (#76239)
Browse files Browse the repository at this point in the history
* Add Amd64 Dot intrinsics and change Amd64 Sum intrinsics

* code style fix
  • Loading branch information
matouskozak authored Oct 7, 2022
1 parent b38fb07 commit 02590ff
Showing 1 changed file with 55 additions and 21 deletions.
76 changes: 55 additions & 21 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,14 @@ extract_first_element (MonoCompile *cfg, MonoClass *klass, MonoTypeEnum element_
}

static MonoInst*
emit_sum_vector (MonoCompile *cfg, MonoClass *klass, MonoMethodSignature *fsig, MonoTypeEnum element_type, MonoInst **args)
emit_sum_vector (MonoCompile *cfg, MonoType *vector_type, MonoTypeEnum element_type, MonoInst *arg)
{
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
MonoClass *arg_class = mono_class_from_mono_type_internal (vector_type);
int size = mono_class_value_size (arg_class, NULL);
if (size != 16) // Works only with Vector128
return NULL;

MonoClass *vector_class = mono_class_from_mono_type_internal (vector_type);
int instc0 = -1;
switch (element_type) {
case MONO_TYPE_R4:
Expand Down Expand Up @@ -605,16 +606,15 @@ emit_sum_vector (MonoCompile *cfg, MonoClass *klass, MonoMethodSignature *fsig,
case MONO_TYPE_I8:
case MONO_TYPE_U8: {
// SSSE3 has no HorizontalAdd for 64-bit integers, so the lower and upper halves are added instead
MonoInst *lower = emit_simd_ins_for_sig (cfg, klass, OP_XLOWER, 0, element_type, fsig, args);
MonoInst *upper = emit_simd_ins_for_sig (cfg, klass, OP_XUPPER, 0, element_type, fsig, args);
MonoInst *lower = emit_simd_ins (cfg, vector_class, OP_XLOWER, arg->dreg, -1);
MonoInst *upper = emit_simd_ins (cfg, vector_class, OP_XUPPER, arg->dreg, -1);

// Sum lower and upper i64
args[0] = lower;
args[1] = upper;
fsig->param_count = 2;
MonoInst* ins = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_IADD, element_type, fsig, args);
MonoInst *ins = emit_simd_ins (cfg, vector_class, OP_XBINOP, lower->dreg, upper->dreg);
ins->inst_c0 = OP_IADD;
ins->inst_c1 = element_type;

return extract_first_element (cfg, klass, element_type, ins->dreg);
return extract_first_element (cfg, vector_class, element_type, ins->dreg);
}
default: {
return NULL;
Expand All @@ -632,16 +632,15 @@ emit_sum_vector (MonoCompile *cfg, MonoClass *klass, MonoMethodSignature *fsig,
int num_rounds = fast_log2[num_elems];

MonoInst *tmp = emit_xzero (cfg, arg_class);
MonoInst *ins = NULL;
args[1] = tmp;
fsig->param_count = 2;
MonoInst *ins = arg;
// Apply HorizontalAdd to the vector log2(num_elems) times
for (int i = 0; i < num_rounds; ++i) {
ins = emit_simd_ins_for_sig (cfg, klass, OP_XOP_X_X_X, instc0, element_type, fsig, args);
args[0] = ins;
ins = emit_simd_ins (cfg, vector_class, OP_XOP_X_X_X, ins->dreg, tmp->dreg);
ins->inst_c0 = instc0;
ins->inst_c1 = element_type;
}

return extract_first_element (cfg, klass, element_type, ins->dreg);
return extract_first_element (cfg, vector_class, element_type, ins->dreg);
}
#endif

Expand Down Expand Up @@ -1295,13 +1294,50 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR_UNSAFE, -1, arg0_type, fsig, args);
}
case SN_Dot: {
#ifdef TARGET_ARM64
if (!is_element_type_primitive (fsig->params [0]))
return NULL;

#ifdef TARGET_ARM64
int instc0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL;
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, arg0_type, fsig, args);

return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#elif defined(TARGET_AMD64)
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
int size = mono_class_value_size (arg_class, NULL);
if (size != 16) // Works only with Vector128
return NULL;

int instc =-1;
if (type_enum_is_float (arg0_type)) {
if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) {
int mask_reg = alloc_ireg (cfg);
switch (arg0_type) {
case MONO_TYPE_R4:
instc = OP_SSE41_DPPS;
MONO_EMIT_NEW_ICONST (cfg, mask_reg, 0xf1); // 0xf1 ... 0b11110001
break;
case MONO_TYPE_R8:
instc = OP_SSE41_DPPD;
MONO_EMIT_NEW_ICONST (cfg, mask_reg, 0x31); // 0x31 ... 0b00110001
break;
default:
return NULL;
}
MonoInst *dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg);
dot->sreg3 = mask_reg;

return extract_first_element (cfg, klass, arg0_type, dot->dreg);
} else {
instc = OP_FMUL;
}
} else {
if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1)
return NULL; // We don't support sum vector for byte, sbyte types yet

instc = OP_IMUL;
}
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc, arg0_type, fsig, args);

return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#else
return NULL;
Expand Down Expand Up @@ -1526,10 +1562,8 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
case SN_Sum: {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
#ifdef TARGET_ARM64
return emit_sum_vector (cfg, fsig->params [0], arg0_type, args [0]);
#elif defined(TARGET_AMD64)
return emit_sum_vector(cfg, klass, fsig, arg0_type, args);
#if defined(TARGET_ARM64) || defined(TARGET_AMD64)
return emit_sum_vector (cfg, fsig->params [0], arg0_type, args [0]);
#else
return NULL;
#endif
Expand Down

0 comments on commit 02590ff

Please sign in to comment.