Skip to content

Commit

Permalink
SVE Implementation for SDDMMCOO with copyrhs op
Browse files Browse the repository at this point in the history
SVE intrinsic code is added to improve the performance of SDDMMCOO Op
when bacst is disabled and reduce_size=1
  • Loading branch information
akote123 committed Jan 17, 2025
1 parent ba73133 commit 3254377
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ if (${BUILD_TYPE} STREQUAL "dev")
if (MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Od")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Od")
elseif ( CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
# Check if the compiler supports ARMv8.2-A or later with SVE
include(CheckCCompilerFlag)
# Try to detect whether the system supports SVE
check_c_compiler_flag("-march=armv8.2-a+sve" SUPPORTS_SVE)
# Output the result
if(SUPPORTS_SVE)
message(STATUS "Hardware supports SVE")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
endif()
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb")
Expand Down
54 changes: 53 additions & 1 deletion src/array/cpu/sddmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>

#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h> // to leverage sve intrinsics
#endif

#include "../selector.h"

namespace dgl {
Expand Down Expand Up @@ -84,7 +88,7 @@ template <
void SDDMMCoo(
const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const bool has_idx = !IsNullArray(coo.data);
const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdType* edges = coo.data.Ptr<IdType>();
Expand Down Expand Up @@ -222,6 +226,54 @@ struct Dot {

} // namespace op

// SDDMMCoo Specialization
#ifdef __ARM_FEATURE_SVE
template <>
void SDDMMCoo <int32_t, float, dgl::aten::cpu::op::CopyRhs<float>, 0, 2> (
const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const int32_t* row = coo.row.Ptr<int32_t>();
const int32_t* col = coo.col.Ptr<int32_t>();
const int32_t* edges = coo.data.Ptr<int32_t>();
const float* X = lhs.Ptr<float>();
const float* Y = rhs.Ptr<float>();
const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
float* O = out.Ptr<float>();
#pragma omp parallel for
for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
const int32_t rid = row[i];
const int32_t cid = col[i];
const int32_t eid = has_idx ? edges[i] : i;
float* out_off = O + eid * dim;
if (!bcast.use_bcast && reduce_size == 1) {
for (int64_t k = 0; k < dim; k += svcntw()) {
svbool_t pgk = svwhilelt_b32(k, dim);
int64_t rhs_base1 = cid * rhs_dim;
svfloat32_t rhs_off_vector = svld1_f32(pgk, &Y[rhs_base1 + k]);
svst1_f32(pgk, &out_off[k], rhs_off_vector);
}
} else {
//with bcast.use_bcast == true, Op::use_lhs == false, and Op::Call
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const float* lhs_off =
dgl::aten::cpu::op::CopyRhs<float>::use_lhs ? X + rid * lhs_dim +
lhs_add * reduce_size
: nullptr;

const float* rhs_off =
dgl::aten::cpu::op::CopyRhs<float>::use_rhs ? Y + cid * rhs_dim +
rhs_add * reduce_size
: nullptr;
out_off[k] = dgl::aten::cpu::op::CopyRhs<float>::Call(lhs_off, rhs_off, bcast.reduce_size);
}
}
}
}
#endif

} // namespace cpu
} // namespace aten
} // namespace dgl
Expand Down

0 comments on commit 3254377

Please sign in to comment.