Skip to content

Commit

Permalink
[Mono][Arm64] Instrinsify methods for Vector4 on Arm64 (#72124)
Browse files Browse the repository at this point in the history
* Initial change to enable intrinsics for Vector4 on arm64

* Fix return

* Add test app

* Also the makefile change

* Fix return type handling from vector to struct

* Fix wasm build failure

* Handle volatile argument

* Promote SIMD return to value

* Address review feedback and add intrinsify more methods

* Assign a NULL pointer to retval

* Remove unused local var

* Fix build warning
  • Loading branch information
fanyang-mono committed Jul 27, 2022
1 parent a92bd72 commit c21ae04
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 54 deletions.
31 changes: 26 additions & 5 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ ovr_tag_from_mono_vector_class (MonoClass *klass) {
case 8: ret |= INTRIN_vector64; break;
case 16: ret |= INTRIN_vector128; break;
}

if (!strcmp ("Vector4", m_class_get_name (klass)) || !strcmp ("Vector2", m_class_get_name (klass)))
return ret | INTRIN_float32;

MonoType *etype = mono_class_get_context (klass)->class_inst->type_argv [0];
switch (etype->type) {
case MONO_TYPE_I1: case MONO_TYPE_U1: ret |= INTRIN_int8; break;
Expand Down Expand Up @@ -1419,9 +1423,9 @@ convert_full (EmitContext *ctx, LLVMValueRef v, LLVMTypeRef dtype, gboolean is_u

if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind && LLVMGetTypeKind (dtype) == LLVMPointerTypeKind)
return LLVMBuildBitCast (ctx->builder, v, dtype, "");
if (LLVMGetTypeKind (dtype) == LLVMPointerTypeKind)
if (LLVMGetTypeKind (dtype) == LLVMPointerTypeKind && LLVMGetTypeKind (stype) == LLVMIntegerTypeKind)
return LLVMBuildIntToPtr (ctx->builder, v, dtype, "");
if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind)
if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind && LLVMGetTypeKind (dtype) == LLVMIntegerTypeKind)
return LLVMBuildPtrToInt (ctx->builder, v, dtype, "");

if (mono_arch_is_soft_float ()) {
Expand Down Expand Up @@ -4088,6 +4092,7 @@ emit_entry_bb (EmitContext *ctx, LLVMBuilderRef builder)
// FIXME: Enabling this fails on windows
case LLVMArgVtypeAddr:
case LLVMArgVtypeByRef:
case LLVMArgAsFpArgs:
{
if (MONO_CLASS_IS_SIMD (ctx->cfg, mono_class_from_mono_type_internal (ainfo->type)))
/* Treat these as normal values */
Expand Down Expand Up @@ -4789,6 +4794,9 @@ process_call (EmitContext *ctx, MonoBasicBlock *bb, LLVMBuilderRef *builder_ref,
if (!addresses [call->inst.dreg])
addresses [call->inst.dreg] = build_alloca_address (ctx, sig->ret);
LLVMBuildStore (builder, lcall, convert_full (ctx, addresses [call->inst.dreg]->value, pointer_type (LLVMTypeOf (lcall)), FALSE));

load_name = "process_call_fp_struct";
should_promote_to_value = is_simd;
break;
case LLVMArgVtypeByVal:
/*
Expand Down Expand Up @@ -5989,10 +5997,23 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
case LLVMArgAsIArgs:
case LLVMArgFpStruct: {
LLVMTypeRef ret_type = LLVMGetReturnType (LLVMGetElementType (LLVMTypeOf (method)));
LLVMValueRef retval;
LLVMValueRef retval, elem;
gboolean is_simd = MONO_CLASS_IS_SIMD (ctx->cfg, mono_class_from_mono_type_internal (sig->ret));

g_assert (addresses [ins->sreg1]);
retval = LLVMBuildLoad2 (builder, ret_type, convert (ctx, addresses [ins->sreg1]->value, pointer_type (ret_type)), "");
if (is_simd) {
g_assert (lhs);
retval = LLVMConstNull(ret_type);

int len = LLVMGetVectorSize (LLVMTypeOf (lhs));
for (int i = 0; i < len; i++)
{
elem = LLVMBuildExtractElement (builder, lhs, const_int32 (i), "extract_elem");
retval = LLVMBuildInsertValue (builder, retval, elem, i, "insert_val_struct");
}
} else{
g_assert (addresses [ins->sreg1]);
retval = LLVMBuildLoad2 (builder, ret_type, convert (ctx, addresses [ins->sreg1]->value, pointer_type (ret_type)), "");
}
LLVMBuildRet (builder, retval);
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/mono/mono/mini/mini-runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -4351,7 +4351,7 @@ init_class (MonoClass *klass)

const char *name = m_class_get_name (klass);

#ifdef TARGET_AMD64
#if defined(TARGET_AMD64) || defined(TARGET_ARM64)
/*
* Some of the intrinsics used by the VectorX classes are only implemented on amd64.
* The JIT can't handle SIMD types with != 16 size yet.
Expand Down
103 changes: 55 additions & 48 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,20 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
break;
case SN_Multiply:
case SN_op_Multiply:
if (fsig->params [1]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [1]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, args [0]->dreg, ins->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
} else if (fsig->params [0]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [0]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, ins->dreg, args [1]->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
if (strcmp ("Vector4", m_class_get_name (klass)) && strcmp ("Vector2", m_class_get_name (klass))) {
if (fsig->params [1]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [1]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, args [0]->dreg, ins->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
} else if (fsig->params [0]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [0]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, ins->dreg, args [1]->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
}
}
instc0 = OP_FMUL;
break;
Expand Down Expand Up @@ -512,8 +514,15 @@ emit_sum_vector (MonoCompile *cfg, MonoType *vector_type, MonoTypeEnum element_t
{
MonoClass *vector_class = mono_class_from_mono_type_internal (vector_type);
int vector_size = mono_class_value_size (vector_class, NULL);
MonoClass *element_class = mono_class_from_mono_type_internal (get_vector_t_elem_type (vector_type));
int element_size = mono_class_value_size (element_class, NULL);
int element_size;
if (!strcmp ("Vector4", m_class_get_name (vector_class)))
element_size = vector_size / 4;
else if (!strcmp ("Vector2", m_class_get_name (vector_class)))
element_size = vector_size / 2;
else {
MonoClass *element_class = mono_class_from_mono_type_internal (get_vector_t_elem_type (vector_type));
element_size = mono_class_value_size (element_class, NULL);
}
gboolean has_single_element = vector_size == element_size;

// If there's just one element we need to extract it instead of summing the whole array
Expand Down Expand Up @@ -783,7 +792,7 @@ emit_vector_create_elementwise (
return ins;
}

#if defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)
#if defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)

static int
type_to_xinsert_op (MonoTypeEnum type)
Expand Down Expand Up @@ -1549,20 +1558,20 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
return NULL;
}

#endif // defined(TARGET_AMD64) || defined(TARGET_ARM64)

#ifdef TARGET_AMD64

// System.Numerics.Vector2/Vector3/Vector4
static guint16 vector2_methods[] = {
SN_ctor,
SN_Abs,
SN_Add,
SN_CopyTo,
SN_Divide,
SN_Dot,
SN_GetElement,
SN_Max,
SN_Min,
SN_Multiply,
SN_SquareRoot,
SN_Subtract,
SN_WithElement,
SN_get_Item,
SN_get_One,
Expand Down Expand Up @@ -1715,6 +1724,10 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
ins->inst_c1 = MONO_TYPE_R4;
return ins;
}
case SN_Add:
case SN_Divide:
case SN_Multiply:
case SN_Subtract:
case SN_op_Addition:
case SN_op_Division:
case SN_op_Multiply:
Expand All @@ -1723,34 +1736,13 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
case SN_Min:
if (!(!fsig->hasthis && fsig->param_count == 2 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type)))
return NULL;
ins = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, args [1]->dreg);
ins->inst_c1 = etype->type;

switch (id) {
case SN_op_Addition:
ins->inst_c0 = OP_FADD;
break;
case SN_op_Division:
ins->inst_c0 = OP_FDIV;
break;
case SN_op_Multiply:
ins->inst_c0 = OP_FMUL;
break;
case SN_op_Subtraction:
ins->inst_c0 = OP_FSUB;
break;
case SN_Max:
ins->inst_c0 = OP_FMAX;
break;
case SN_Min:
ins->inst_c0 = OP_FMIN;
break;
default:
g_assert_not_reached ();
break;
}
return ins;
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, MONO_TYPE_R4, id);
case SN_Dot: {
#ifdef TARGET_ARM64
int instc0 = OP_FMUL;
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, MONO_TYPE_R4, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, pairwise_multiply);
#elif defined(TARGET_AMD64)
if (!(mini_get_cpu_features (cfg) & MONO_CPU_X86_SSE41))
return NULL;

Expand All @@ -1766,6 +1758,9 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
ins->inst_c1 = MONO_TYPE_R4;
MONO_ADD_INS (cfg->cbb, ins);
return ins;
#else
return NULL;
#endif
}
case SN_Abs: {
// MAX(x,0-x)
Expand Down Expand Up @@ -1793,9 +1788,15 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return ins;
}
case SN_SquareRoot: {
#ifdef TARGET_ARM64
return emit_simd_ins_for_sig (cfg, klass, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FSQRT, MONO_TYPE_R4, fsig, args);
#elif defined(TARGET_AMD64)
ins = emit_simd_ins (cfg, klass, OP_XOP_X_X, args [0]->dreg, -1);
ins->inst_c0 = (IntrinsicId)INTRINS_SSE_SQRT_PS;
return ins;
#else
return NULL;
#endif
}
case SN_CopyTo:
// FIXME:
Expand All @@ -1807,9 +1808,9 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return NULL;
}

#endif /* TARGET_AMD64 */
#endif // defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)

#if defined(TARGET_AMD64)
#ifdef TARGET_AMD64

static guint16 vector_methods [] = {
SN_ConvertToDouble,
Expand Down Expand Up @@ -4029,6 +4030,12 @@ arch_emit_simd_intrinsics (const char *class_ns, const char *class_name, MonoCom
return emit_vector64_vector128_t (cfg, cmethod, fsig, args);
}

if (!strcmp (class_ns, "System.Numerics")) {
//if (!strcmp ("Vector2", class_name) || !strcmp ("Vector4", class_name) || !strcmp ("Vector3", class_name))
if (!strcmp ("Vector4", class_name))
return emit_vector_2_3_4 (cfg, cmethod, fsig, args);
}

return NULL;
}
#elif TARGET_AMD64
Expand Down

0 comments on commit c21ae04

Please sign in to comment.