Skip to content

Commit

Permalink
Merge pull request oap-project#11 from xwu99/new-oneccl
Browse files Browse the repository at this point in the history
[ML-10] Porting Kmeans and PCA to new oneCCL API
  • Loading branch information
xwu99 authored Feb 5, 2021
2 parents db2414c + 318cae1 commit b60158b
Show file tree
Hide file tree
Showing 16 changed files with 166 additions and 122 deletions.
2 changes: 1 addition & 1 deletion dev/install-build-deps-centos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ cd /tmp
rm -rf oneCCL
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
git checkout beta08
git checkout 2021.1
mkdir -p build && cd build
cmake ..
make -j 2 install
Expand Down
2 changes: 1 addition & 1 deletion dev/install-build-deps-ubuntu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ echo "Building oneCCL ..."
cd /tmp
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
git checkout beta08
git checkout 2021.1
mkdir build && cd build
cmake ..
make -j 2 install
Expand Down
32 changes: 32 additions & 0 deletions mllib-dal/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash

# Build the mllib-dal package with Maven, skipping tests.
#
# Requires JAVA_HOME, DAALROOT, TBBROOT and CCL_ROOT to be set to non-empty
# values; the script prints which one is missing and exits with status 1
# otherwise.

# Check envs for building
for var in JAVA_HOME DAALROOT TBBROOT CCL_ROOT; do
  # ${!var} is bash indirect expansion: the value of the variable named by $var.
  if [[ -z ${!var} ]]; then
    # Print the variable's NAME; expanding the (empty) variable itself would
    # produce a message that does not say what is missing.
    echo "$var not defined!"
    exit 1
  fi
done

# Show the effective build environment to make build logs self-describing.
echo "=== Building Environments ==="
echo "JAVA_HOME=$JAVA_HOME"
echo "DAALROOT=$DAALROOT"
echo "TBBROOT=$TBBROOT"
echo "CCL_ROOT=$CCL_ROOT"
echo "GCC Version: $(gcc -dumpversion)"
echo "============================="

mvn -DskipTests clean package
16 changes: 11 additions & 5 deletions mllib-dal/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,12 @@
<resource>
<directory>${env.CCL_ROOT}/lib</directory>
<includes>
<include>libpmi.so.1</include>
<include>libresizable_pmi.so.1</include>
<!--<include>libpmi.so.1</include>-->
<!--<include>libresizable_pmi.so.1</include>-->
<include>libmpi.so.12.0.0</include>
<include>libfabric.so.1</include>
<include>libccl_atl_ofi.so.1</include>
<include>libccl.so</include>
<!--<include>libccl_atl_ofi.so.1</include>-->
</includes>
</resource>
<resource>
Expand Down Expand Up @@ -271,9 +273,13 @@
<destinationFile>${project.build.testOutputDirectory}/lib/libtbbmalloc.so.2</destinationFile>
</fileSet>
<fileSet>
<sourceFile>${project.build.testOutputDirectory}/lib/libccl_atl_ofi.so.1</sourceFile>
<destinationFile>${project.build.testOutputDirectory}/lib/libccl_atl_ofi.so</destinationFile>
<sourceFile>${project.build.testOutputDirectory}/lib/libmpi.so.12.0.0</sourceFile>
<destinationFile>${project.build.testOutputDirectory}/lib/libmpi.so.12</destinationFile>
</fileSet>
<!--<fileSet>-->
<!--<sourceFile>${project.build.testOutputDirectory}/lib/libccl_atl_ofi.so.1</sourceFile>-->
<!--<destinationFile>${project.build.testOutputDirectory}/lib/libccl_atl_ofi.so</destinationFile>-->
<!--</fileSet>-->
</fileSets>
</configuration>
</execution>
Expand Down
13 changes: 4 additions & 9 deletions mllib-dal/src/assembly/assembly.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,21 @@
</file>
<!-- Include oneCCL libraries into JAR -->
<file>
<source>${env.CCL_ROOT}/lib/libpmi.so.1</source>
<source>${env.CCL_ROOT}/lib/libfabric.so.1</source>
<outputDirectory>lib</outputDirectory>
</file>
<file>
<source>${env.CCL_ROOT}/lib/libresizable_pmi.so.1</source>
<source>${env.CCL_ROOT}/lib/libmpi.so.12.0.0</source>
<outputDirectory>lib</outputDirectory>
<destName>libmpi.so.12</destName>
</file>
<file>
<source>${env.CCL_ROOT}/lib//libfabric.so.1</source>
<source>${env.CCL_ROOT}/lib/libccl.so</source>
<outputDirectory>lib</outputDirectory>
</file>
<file>
<source>${env.CCL_ROOT}/lib/prov/libsockets-fi.so</source>
<outputDirectory>lib</outputDirectory>
</file>
<file>
<!-- Should rename to XXX.so for ATL to load -->
<source>${env.CCL_ROOT}/lib/libccl_atl_ofi.so.1</source>
<outputDirectory>lib</outputDirectory>
<destName>libccl_atl_ofi.so</destName>
</file>
</files>
</assembly>
11 changes: 5 additions & 6 deletions mllib-dal/src/main/java/org/apache/spark/ml/util/LibLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,23 @@ public static synchronized void loadLibraries() throws IOException {
/**
* Load oneCCL libs in dependency order
*/
public static synchronized void loadLibCCL() throws IOException {
loadFromJar(subDir, "libpmi.so.1");
loadFromJar(subDir, "libresizable_pmi.so.1");
private static synchronized void loadLibCCL() throws IOException {
loadFromJar(subDir, "libfabric.so.1");
loadFromJar(subDir, "libmpi.so.12");
loadFromJar(subDir, "libccl.so");
loadFromJar(subDir, "libsockets-fi.so");
loadFromJar(subDir, "libccl_atl_ofi.so");
}

/**
* Load MLlibDAL lib. It depends on TBB libs that are loaded by oneDAL,
* so this function should be called after oneDAL's loadLibrary
*/
public static synchronized void loadLibMLlibDAL() throws IOException {
private static synchronized void loadLibMLlibDAL() throws IOException {
// oneDAL Java API doesn't load correct libtbb version for oneAPI Beta 10
// Rename in pom.xml and assembly.xml to workaround.
// See https://github.com/oneapi-src/oneDAL/issues/1254 -->
LibUtils.loadLibrary();

loadFromJar(subDir, "libMLlibDAL.so");
}

Expand Down
35 changes: 14 additions & 21 deletions mllib-dal/src/main/native/KMeansDALImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
* limitations under the License.
*******************************************************************************/

#include <ccl.h>
#include <oneapi/ccl.hpp>
#include <daal.h>
#include <iostream>
#include <chrono>

#include "service.h"
#include "org_apache_spark_ml_clustering_KMeansDALImpl.h"
#include <iostream>
#include <chrono>
#include "OneCCL.h"

using namespace std;
using namespace daal;
Expand All @@ -30,7 +31,8 @@ const int ccl_root = 0;

typedef double algorithmFPType; /* Algorithm floating-point type */

static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData, const NumericTablePtr & initialCentroids,
static NumericTablePtr kmeans_compute(int rankId, ccl::communicator &comm,
const NumericTablePtr & pData, const NumericTablePtr & initialCentroids,
size_t nClusters, size_t nBlocks, algorithmFPType &ret_cost)
{
const bool isRoot = (rankId == ccl_root);
Expand All @@ -43,17 +45,13 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
CentroidsArchLength = inputArch.getSizeOfArchive();
}

ccl_request_t request;

/* Get partial results from the root node */
ccl_bcast(&CentroidsArchLength, sizeof(size_t), ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&CentroidsArchLength, sizeof(size_t), ccl::datatype::uint8, ccl_root, comm).wait();

ByteBuffer nodeCentroids(CentroidsArchLength);
if (isRoot) inputArch.copyArchiveToArray(&nodeCentroids[0], CentroidsArchLength);

ccl_bcast(&nodeCentroids[0], CentroidsArchLength, ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&nodeCentroids[0], CentroidsArchLength, ccl::datatype::uint8, ccl_root, comm).wait();

/* Deserialize centroids data */
OutputDataArchive outArch(nodeCentroids.size() ? &nodeCentroids[0] : NULL, CentroidsArchLength);
Expand All @@ -79,7 +77,7 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
ByteBuffer serializedData;

/* Serialized data is of equal size on each node if each node called compute() equal number of times */
size_t* recvCounts = new size_t[nBlocks];
vector<size_t> recvCounts(nBlocks);
for (size_t i = 0; i < nBlocks; i++)
{
recvCounts[i] = perNodeArchLength;
Expand All @@ -90,10 +88,7 @@ static NumericTablePtr kmeans_compute(int rankId, const NumericTablePtr & pData,
dataArch.copyArchiveToArray(&nodeResults[0], perNodeArchLength);

/* Transfer partial results to step 2 on the root node */
ccl_allgatherv(&nodeResults[0], perNodeArchLength, &serializedData[0], recvCounts, ccl_dtype_char, NULL, NULL, NULL, &request);
ccl_wait(request);

delete [] recvCounts;
ccl::allgatherv(&nodeResults[0], perNodeArchLength, &serializedData[0], recvCounts, ccl::datatype::uint8, comm).wait();

if (isRoot)
{
Expand Down Expand Up @@ -168,8 +163,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_clustering_KMeansDALImpl_cKMean
jint executor_num, jint executor_cores,
jobject resultObj) {

size_t rankId;
ccl_get_comm_rank(NULL, &rankId);
ccl::communicator &comm = getComm();
size_t rankId = comm.rank();

NumericTablePtr pData = *((NumericTablePtr *)pNumTabData);
NumericTablePtr centroids = *((NumericTablePtr *)pNumTabCenters);
Expand All @@ -189,16 +184,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_ml_clustering_KMeansDALImpl_cKMean
for (it = 0; it < iteration_num && !converged; it++) {
auto t1 = std::chrono::high_resolution_clock::now();

newCentroids = kmeans_compute(rankId, pData, centroids, cluster_num, executor_num, totalCost);
newCentroids = kmeans_compute(rankId, comm, pData, centroids, cluster_num, executor_num, totalCost);

if (rankId == ccl_root) {
converged = areAllCentersConverged(centroids, newCentroids, tolerance);
}

// Sync converged status
ccl_request_t request;
ccl_bcast(&converged, 1, ccl_dtype_char, ccl_root, NULL, NULL, NULL, &request);
ccl_wait(request);
ccl::broadcast(&converged, 1, ccl::datatype::uint8, ccl_root, comm).wait();

centroids = newCentroids;

Expand Down
2 changes: 1 addition & 1 deletion mllib-dal/src/main/native/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ INCS := -I $(JAVA_HOME)/include \

# Use static link if possible, TBB is only available as dynamic libs

LIBS := -L${CCL_ROOT}/lib -l:libccl.a \
LIBS := -L${CCL_ROOT}/lib -lccl \
-L$(DAALROOT)/lib/intel64 -l:libdaal_core.a -l:libdaal_thread.a \
-L$(TBBROOT)/lib -ltbb -ltbbmalloc
# TODO: Add signal chaining support, should fix linking, package so and loading
Expand Down
56 changes: 34 additions & 22 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
#include <iostream>
#include <ccl.h>
#include <oneapi/ccl.hpp>
#include "org_apache_spark_ml_util_OneCCL__.h"

// todo: fill initial comm_size and rank_id
size_t comm_size;
size_t rank_id;

std::vector<ccl::communicator> g_comms;

// Returns a reference to the first communicator stored in the global g_comms.
// g_comms receives its entry in c_init and has it removed in c_cleanup, so
// calling this while g_comms is empty is undefined behavior.
ccl::communicator &getComm() {
    return g_comms.front();
}

JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init
(JNIEnv *env, jobject obj, jobject param) {
(JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jobject param) {

std::cout << "oneCCL (native): init" << std::endl;

ccl_init();
ccl::init();

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");
const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

auto kvs_attr = ccl::create_kvs_attr();
kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);

size_t comm_size;
size_t rank_id;
ccl::shared_ptr_class<ccl::kvs> kvs;
kvs = ccl::create_main_kvs(kvs_attr);

ccl_get_comm_size(NULL, &comm_size);
ccl_get_comm_rank(NULL, &rank_id);
g_comms.push_back(ccl::create_communicator(size, rank, kvs));

rank_id = getComm().rank();
comm_size = getComm().size();

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");

env->SetLongField(param, fid_comm_size, comm_size);
env->SetLongField(param, fid_rank_id, rank_id);
env->ReleaseStringUTFChars(ip_port, str);

return 1;
}
Expand All @@ -33,9 +52,10 @@ JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1init
JNIEXPORT void JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1cleanup
(JNIEnv *env, jobject obj) {

g_comms.pop_back();

std::cout << "oneCCL (native): cleanup" << std::endl;

ccl_finalize();
}

/*
Expand All @@ -44,12 +64,9 @@ JNIEXPORT void JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_c_1cleanup
* Signature: ()Z
*/
JNIEXPORT jboolean JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_isRoot
(JNIEnv *env, jobject obj) {
(JNIEnv *env, jobject obj) {

size_t rank_id;
ccl_get_comm_rank(NULL, &rank_id);

return (rank_id == 0);
return getComm().rank() == 0;
}

/*
Expand All @@ -59,12 +76,7 @@ JNIEXPORT jboolean JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_isRoot
*/
JNIEXPORT jint JNICALL Java_org_apache_spark_ml_util_OneCCL_00024_rankID
(JNIEnv *env, jobject obj) {

size_t rank_id;
ccl_get_comm_rank(NULL, &rank_id);

return rank_id;

return getComm().rank();
}

/*
Expand Down
5 changes: 5 additions & 0 deletions mllib-dal/src/main/native/OneCCL.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include <oneapi/ccl.hpp>

// Accessor for the oneCCL communicator shared by the native sources.
// Defined in OneCCL.cpp; the returned reference is passed to collective calls
// such as ccl::broadcast / ccl::allgatherv in KMeansDALImpl.cpp.
ccl::communicator &getComm();
Loading

0 comments on commit b60158b

Please sign in to comment.