Port oneCCL to new C++ APIs
xwu99 committed Jan 14, 2021
1 parent e19e9c9 commit ac216f4
Showing 4 changed files with 51 additions and 53 deletions.
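The port follows one pattern throughout the diff: each explicit ccl_request_t / ccl_wait pair from the C API becomes a single C++ call that returns an event, and ccl_dtype_char buffers become ccl::datatype::uint8. A minimal sketch of that mapping (not part of the commit; for brevity it assumes a single-rank communicator built from a locally created main KVS, whereas the real code wires size, rank, and the KVS through the JNI layer):

#include <oneapi/ccl.hpp>
#include <vector>

int main() {
    ccl::init();                                          // was: ccl_init()

    // Single-rank communicator for illustration only.
    auto kvs  = ccl::create_main_kvs();
    auto comm = ccl::create_communicator(1, 0, kvs);

    size_t archLength = 42;
    // was: ccl_bcast(&archLength, sizeof(size_t), ccl_dtype_char, 0, NULL, NULL, NULL, &request);
    //      ccl_wait(request);
    ccl::broadcast(&archLength, sizeof(size_t), ccl::datatype::uint8, 0, comm).wait();

    std::vector<char> send(archLength, 1), recv(archLength * comm.size());
    std::vector<size_t> recvCounts(comm.size(), archLength);
    // was: ccl_allgatherv(send, archLength, recv, recvCounts, ccl_dtype_char, NULL, NULL, NULL, &request);
    //      ccl_wait(request);
    ccl::allgatherv(send.data(), archLength, recv.data(), recvCounts, ccl::datatype::uint8, comm).wait();

    return 0;                                             // no ccl_finalize(); teardown is automatic
}
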
35 changes: 14 additions & 21 deletions mllib-dal/src/main/native/KMeansDALImpl.cpp
@@ -14,13 +14,14 @@
* limitations under the License.
*******************************************************************************/

#include <ccl.h>
#include <oneapi/ccl.hpp>
#include <daal.h>
#include <iostream>
#include <chrono>

#include "service.h"
#include "org_apache_spark_ml_clustering_KMeansDALImpl.h"
#include <iostream>
#include <chrono>
#include "OneCCL.h"

using namespace std;
using namespace daal;
@@ -30,7 +31,8 @@ const int ccl_root = 0;

typedef double algorithmFPType; /* Algorithm floating-point type */

static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData, const NumericTablePtr & initialCentroids,
static NumericTablePtr kmeans_compute(int rankId, ccl::communicator &comm,
const NumericTablePtr & pData, const NumericTablePtr & initialCentroids,
size_t nClusters, size_t nBlocks, algorithmFPType &ret_cost)
{
const bool isRoot = (rankId == ccl_root);
@@ -43,17 +45,13 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
CentroidsArchLength = inputArch.getSizeOfArchive();
}

ccl_request_t request;

/* Get partial results from the root node */
ccl_bcast(&CentroidsArchLength, sizeof(size_t), ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&CentroidsArchLength, sizeof(size_t), ccl::datatype::uint8, ccl_root, comm).wait();

ByteBuffer nodeCentroids(CentroidsArchLength);
if (isRoot) inputArch.copyArchiveToArray(&nodeCentroids[0], CentroidsArchLength);

ccl_bcast(&nodeCentroids[0], CentroidsArchLength, ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&nodeCentroids[0], CentroidsArchLength, ccl::datatype::uint8, ccl_root, comm).wait();

/* Deserialize centroids data */
OutputDataArchive outArch(nodeCentroids.size() ? &nodeCentroids[0] : NULL, CentroidsArchLength);
@@ -79,7 +77,7 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
ByteBuffer serializedData;

/* Serialized data is of equal size on each node if each node called compute() equal number of times */
size_t* recvCounts = new size_t[nBlocks];
vector<size_t> recvCounts(nBlocks);
for (size_t i = 0; i < nBlocks; i++)
{
recvCounts[i] = perNodeArchLength;
@@ -90,10 +88,7 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
dataArch.copyArchiveToArray(&nodeResults[0], perNodeArchLength);

/* Transfer partial results to step 2 on the root node */
ccl_allgatherv(&nodeResults[0], perNodeArchLength, &serializedData[0], recvCounts, ccl_dtype_char, NULL, NULL, NULL, &request);
ccl_wait(request);

delete [] recvCounts;
ccl::allgatherv(&nodeResults[0], perNodeArchLength, &serializedData[0], recvCounts, ccl::datatype::uint8, comm).wait();

if (isRoot)
{
@@ -168,8 +163,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_clustering_KMeansDALImpl_cKMean
jint executor_num, jint executor_cores,
jobject resultObj) {

size_t rankId;
ccl_get_comm_rank(NULL, &rankId);
ccl::communicator *comm = getComm();
size_t rankId = comm->rank();

NumericTablePtr pData = *((NumericTablePtr *)pNumTabData);
NumericTablePtr centroids = *((NumericTablePtr *)pNumTabCenters);
@@ -189,16 +184,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_clustering_KMeansDALImpl_cKMean
for (it = 0; it < iteration_num && !converged; it++) {
auto t1 = std::chrono::high_resolution_clock::now();

newCentroids = kmeans_compute(rankId, pData, centroids, cluster_num, executor_num, totalCost);
newCentroids = kmeans_compute(rankId, *comm, pData, centroids, cluster_num, executor_num, totalCost);

if (rankId == ccl_root) {
converged = areAllCentersConverged(centroids, newCentroids, tolerance);
}

// Sync converged status
ccl_request_t request;
ccl_bcast(&converged, 1, ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&converged, 1, ccl::datatype::uint8, ccl_root, *comm).wait();

centroids = newCentroids;

46 changes: 25 additions & 21 deletions mllib-dal/src/main/native/OneCCL.cpp
@@ -1,23 +1,36 @@
#include <iostream>
#include <ccl.h>
#include <oneapi/ccl.hpp>
#include "org_apache_spark_ml_util_OneCCL__.h"

// todo: fill initial comm_size and rank_id
size_t comm_size;
size_t rank_id;

ccl::communicator *getComm() {
ccl::shared_ptr_class<ccl::kvs> kvs;
static ccl::communicator b = ccl::create_communicator(comm_size, rank_id, kvs);
return &b;
}

JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init
(JNIEnv *env, jobject obj, jobject param) {

std::cout << "oneCCL (native): init" << std::endl;

ccl_init();
ccl::init();

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");
ccl::shared_ptr_class<ccl::kvs> kvs;
ccl::kvs::address_type main_addr;
kvs = ccl::create_kvs(main_addr);

auto comm = getComm();

size_t comm_size;
size_t rank_id;
rank_id = comm->rank();
comm_size = comm->size();

ccl_get_comm_size(NULL, &comm_size);
ccl_get_comm_rank(NULL, &rank_id);
jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");

env->SetLongField(param, fid_comm_size, comm_size);
env->SetLongField(param, fid_rank_id, rank_id);
@@ -35,7 +48,6 @@ JNIEXPORT void JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1cleanup

std::cout << "oneCCL (native): cleanup" << std::endl;

ccl_finalize();
}

/*
@@ -44,12 +56,9 @@ JNIEXPORT void JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1cleanup
* Signature: ()Z
*/
JNIEXPORT jboolean JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_isRoot
(JNIEnv *env, jobject obj) {

size_t rank_id;
ccl_get_comm_rank(NULL, &rank_id);
(JNIEnv *env, jobject obj) {

return (rank_id == 0);
return getComm()->rank() == 0;
}

/*
@@ -59,12 +68,7 @@ JNIEXPORT jboolean JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_isRoot
*/
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_rankID
(JNIEnv *env, jobject obj) {

size_t rank_id;
ccl_get_comm_rank(NULL, &rank_id);

return rank_id;

return getComm()->rank();
}

/*
5 changes: 5 additions & 0 deletions mllib-dal/src/main/native/OneCCL.h
@@ -0,0 +1,5 @@
#pragma once

#include <oneapi/ccl.hpp>

ccl::communicator *getComm();
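
The getComm() introduced above still carries a TODO: comm_size and rank_id are never filled, and the KVS it builds uses an uninitialized address. A sketch of how the communicator could be constructed once those values are supplied, assuming the root's KVS address is exchanged out of band (e.g., through the JNI/Spark layer); the signature below is hypothetical, not the commit's:

#include <oneapi/ccl.hpp>

// Sketch only: rank 0 creates the main KVS; other ranks attach to it through
// the address rank 0 published out of band.
ccl::communicator *getComm(size_t comm_size, size_t rank_id,
                           const ccl::kvs::address_type &main_addr) {
    static ccl::shared_ptr_class<ccl::kvs> kvs =
        (rank_id == 0) ? ccl::create_main_kvs()
                       : ccl::create_kvs(main_addr);
    static ccl::communicator comm =
        ccl::create_communicator(comm_size, rank_id, kvs);
    return &comm;
}
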
18 changes: 7 additions & 11 deletions mllib-dal/src/main/native/PCADALImpl.cpp
@@ -1,4 +1,3 @@
#include <ccl.h>
#include <daal.h>

#include "service.h"
@@ -7,6 +6,7 @@
#include <iostream>

#include "org_apache_spark_ml_feature_PCADALImpl.h"
#include "OneCCL.h"

using namespace std;
using namespace daal;
@@ -24,8 +24,9 @@ typedef double algorithmFPType; /* Algorithm floating-point type */
JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_feature_PCADALImpl_cPCATrainDAL(
JNIEnv *env, jobject obj, jlong pNumTabData, jint k, jint executor_num, jint executor_cores,
jobject resultObj) {
size_t rankId;
ccl_get_comm_rank(NULL, &rankId);

ccl::communicator *comm = getComm();
size_t rankId = comm->rank();

const size_t nBlocks = executor_num;
const int comm_size = executor_num;
@@ -59,9 +60,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_feature_PCADALImpl_cPCATrainDAL
byte* nodeResults = new byte[perNodeArchLength];
dataArch.copyArchiveToArray(nodeResults, perNodeArchLength);

ccl_request_t request;

size_t* recv_counts = new size_t[comm_size * perNodeArchLength];
vector<size_t> recv_counts(comm_size * perNodeArchLength);
for (int i = 0; i < comm_size; i++) recv_counts[i] = perNodeArchLength;

cout << "PCA (native): ccl_allgatherv receiving " << perNodeArchLength * nBlocks << " bytes" << endl;
@@ -71,17 +70,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_feature_PCADALImpl_cPCATrainDAL
/* Transfer partial results to step 2 on the root node */
// MPI_Gather(nodeResults, perNodeArchLength, MPI_CHAR, serializedData.get(),
// perNodeArchLength, MPI_CHAR, ccl_root, MPI_COMM_WORLD);
ccl_allgatherv(nodeResults, perNodeArchLength, serializedData.get(), recv_counts,
ccl_dtype_char, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::allgatherv(nodeResults, perNodeArchLength, serializedData.get(), recv_counts,
ccl::datatype::uint8, *comm).wait();

auto t2 = std::chrono::high_resolution_clock::now();

auto duration = std::chrono::duration_cast<std::chrono::seconds>( t2 - t1 ).count();
std::cout << "PCA (native): ccl_allgatherv took " << duration << " secs" << std::endl;

delete[] nodeResults;

if (rankId == ccl_root) {
auto t1 = std::chrono::high_resolution_clock::now();

