diff --git a/CMakeLists.txt b/CMakeLists.txt
index ce6b348c3ea3..3a810274f670 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -122,6 +122,17 @@ if (${BUILD_TYPE} STREQUAL "dev")
   if (MSVC)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Od")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Od")
+  elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
+    # Check if the compiler supports ARMv8.2-A or later with SVE
+    include(CheckCCompilerFlag)
+    # Note: this tests compiler support for the flag, not the target hardware
+    check_c_compiler_flag("-march=armv8.2-a+sve" SUPPORTS_SVE)
+    # Output the result
+    if(SUPPORTS_SVE)
+      message(STATUS "Compiler supports SVE")
+      set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
+      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
+    endif()
   else()
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb")
diff --git a/src/array/cpu/sddmm.h b/src/array/cpu/sddmm.h
index 9e372bfc3ac2..8a9d795c121f 100644
--- a/src/array/cpu/sddmm.h
+++ b/src/array/cpu/sddmm.h
@@ -10,6 +10,10 @@
 #include <dgl/array.h>
 #include <dgl/bcast.h>
 
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>  // to leverage SVE intrinsics
+#endif
+
 #include "../selector.h"
 
 namespace dgl {
@@ -84,7 +88,7 @@ template <
 void SDDMMCoo(
     const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
     NDArray out) {
-  const bool has_idx = !IsNullArray(coo.data);
+  const bool has_idx = !IsNullArray(coo.data);
   const IdType* row = coo.row.Ptr<IdType>();
   const IdType* col = coo.col.Ptr<IdType>();
   const IdType* edges = coo.data.Ptr<IdType>();
@@ -222,6 +226,54 @@ struct Dot {
 
 }  // namespace op
 
+// SDDMMCoo specialization: int32 indices, float32 data, CopyRhs op
+#ifdef __ARM_FEATURE_SVE
+template <>
+void SDDMMCoo<int32_t, float, dgl::aten::cpu::op::CopyRhs<float>, 0, 2>(
+    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
+    NDArray out) {
+  const bool has_idx = !IsNullArray(coo.data);
+  const int32_t* row = coo.row.Ptr<int32_t>();
+  const int32_t* col = coo.col.Ptr<int32_t>();
+  const int32_t* edges = coo.data.Ptr<int32_t>();
+  const float* X = lhs.Ptr<float>();
+  const float* Y = rhs.Ptr<float>();
+  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
+                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
+  float* O = out.Ptr<float>();
+#pragma omp parallel for
+  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
+    const int32_t rid = row[i];
+    const int32_t cid = col[i];
+    const int32_t eid = has_idx ? edges[i] : i;
+    float* out_off = O + eid * dim;
+    if (!bcast.use_bcast && reduce_size == 1) {
+      // Fast path: predicated SVE copy of one rhs row; the predicate handles the tail.
+      for (int64_t k = 0; k < dim; k += svcntw()) {
+        svbool_t pgk = svwhilelt_b32(k, dim);
+        int64_t rhs_base1 = cid * rhs_dim;
+        svfloat32_t rhs_off_vector = svld1_f32(pgk, &Y[rhs_base1 + k]);
+        svst1_f32(pgk, &out_off[k], rhs_off_vector);
+      }
+    } else {
+      // Scalar fallback for bcast.use_bcast == true; CopyRhs::use_lhs is false, so Op::Call copies rhs.
+      for (int64_t k = 0; k < dim; ++k) {
+        const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
+        const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
+        const float* lhs_off = dgl::aten::cpu::op::CopyRhs<float>::use_lhs
+            ? X + rid * lhs_dim + lhs_add * reduce_size
+            : nullptr;
+        const float* rhs_off = dgl::aten::cpu::op::CopyRhs<float>::use_rhs
+            ? Y + cid * rhs_dim + rhs_add * reduce_size
+            : nullptr;
+        out_off[k] = dgl::aten::cpu::op::CopyRhs<float>::Call(
+            lhs_off, rhs_off, bcast.reduce_size);
+      }
+    }
+  }
+}
+#endif
+
 }  // namespace cpu
 }  // namespace aten
 }  // namespace dgl
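
For reviewers unfamiliar with the SVE idiom in the fast path: svwhilelt_b32(k, dim) builds a per-lane predicate that is true while k + lane < dim, and svcntw() returns the number of 32-bit lanes in one SVE vector, so the loop is vector-length-agnostic and needs no scalar tail loop. A minimal standalone sketch of the same pattern (hypothetical file and function names, not part of the patch; build with -march=armv8.2-a+sve on AArch64):

    // sve_copy_sketch.cc -- hypothetical illustration, not part of this patch.
    #include <arm_sve.h>

    #include <cstdint>

    // Predicated copy: dst[0..n) = src[0..n) with no scalar tail loop.
    void PredicatedCopyF32(float* dst, const float* src, int64_t n) {
      for (int64_t k = 0; k < n; k += svcntw()) {  // step = 32-bit lane count
        svbool_t pg = svwhilelt_b32(k, n);       // lanes active while index < n
        svfloat32_t v = svld1_f32(pg, src + k);  // inactive lanes read as zero
        svst1_f32(pg, dst + k, v);               // inactive lanes are not stored
      }
    }

This is the same shape as the !bcast.use_bcast && reduce_size == 1 branch above, where the CopyRhs op degenerates to copying one rhs feature row per edge.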