From a1607de63f12f6aa6be2d71c5ca2816347f2928a Mon Sep 17 00:00:00 2001
From: Vladimir Ovsyannikov
Date: Thu, 7 Mar 2019 12:32:04 +0100
Subject: [PATCH] A few more fixes:

* freeing memory
* copy X, theta back to host
* separate method to get factorization score
* pytest
---
 src/gpu/factorization/als.h              |   6 +-
 src/gpu/factorization/factorization.cu   |  29 ++-
 src/include/solver/factorization.h       |  12 +
 .../h2o4gpu/solvers/factorization.py     | 217 +++++++++++++-----
 src/interface_py/setup.py                |  39 ++--
 src/swig/solver/factorization.i          |   4 +-
 .../open_data/factorization/test_10M.py  |  56 ++++-
 xgboost                                  |   2 +-
 8 files changed, 275 insertions(+), 90 deletions(-)

diff --git a/src/gpu/factorization/als.h b/src/gpu/factorization/als.h
index c4b0e992e..059ba3816 100644
--- a/src/gpu/factorization/als.h
+++ b/src/gpu/factorization/als.h
@@ -2222,9 +2222,7 @@ class ALSFactorization
 #ifdef DEBUG
     printf("update X run %f seconds, gridSize: %d, blockSize %d.\n",
            seconds() - t0, m, f);
 #endif
-    // cudacall(cudaFree(csrRowIndex));
-    // cudacall(cudaFree(csrColIndex));
-    // cudacall(cudaFree(ythetaT));
+    cudacall(cudaFree(ythetaT));
 #ifdef DEBUG
     t0 = seconds();
@@ -2363,8 +2361,6 @@ class ALSFactorization
   cublasHandle_t handle;
   cusparseHandle_t cushandle;
   cusparseMatDescr_t descr;
-  T *ytheta = 0;
-  T *ythetaT = 0;
   T *thetaT;
   T *XT;
 };
diff --git a/src/gpu/factorization/factorization.cu b/src/gpu/factorization/factorization.cu
index c011bcedd..0917770a2 100644
--- a/src/gpu/factorization/factorization.cu
+++ b/src/gpu/factorization/factorization.cu
@@ -2,6 +2,32 @@
 #include "cuda_utils.h"
 #include "solver/factorization.h"
 
+template <typename T> void free_data(T **ptr) {
+  if (ptr)
+    CUDACHECK(cudaFree(*ptr));
+}
+
+void free_data_float(float **ptr) { free_data(ptr); }
+
+void free_data_double(double **ptr) { free_data(ptr); }
+
+void free_data_int(int **ptr) { free_data(ptr); }
+
+template <typename T>
+void copy_fecatorization_result(T *dst, const T **src, const int size) {
+  CUDACHECK(cudaMemcpy(dst, *src, sizeof(T) * size, cudaMemcpyDeviceToHost));
+}
+
+void copy_fecatorization_result_float(float *dst, const float **src,
+                                      const int size) {
+  copy_fecatorization_result(dst, src, size);
+}
+
+void copy_fecatorization_result_double(double *dst, const double **src,
+                                       const int size) {
+  copy_fecatorization_result(dst, src, size);
+}
+
 template <typename T>
 int make_factorization_data(
     const int m, const int n, const int f, const long nnz, const long nnz_test,
@@ -77,7 +103,8 @@ int make_factorization_data(
                        (size_t)(nnz * sizeof(**csrValDevicePtr)),
                        cudaMemcpyHostToDevice));
 
-  if (cooColIndexTestHostPtr && cooRowIndexTestHostPtr && cooValTestHostPtr) {
+  if (cooColIndexTestHostPtr && cooRowIndexTestHostPtr && cooValTestHostPtr &&
+      nnz_test > 0) {
     CUDACHECK(cudaMalloc((void **)cooRowIndexTestDevicePtr,
                          nnz_test * sizeof(**cooRowIndexTestDevicePtr)));
     CUDACHECK(cudaMalloc((void **)cooColIndexTestDevicePtr,
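`make_factorization_data` above uploads the same rating matrix in three layouts at once: CSR (row pointers plus column indices), CSC (column pointers plus row indices) and COO (row/column/value triplets). The scipy sketch below, on a made-up 3x3 matrix, shows how those host arrays relate to each other; the comments map them to the `*HostPtr` parameters following the argument order of the `fit()` call further down (illustration only, not part of the library code).

```python
# Illustration only: the three sparse layouts of one tiny rating matrix.
import numpy as np
import scipy.sparse as sparse

rows = np.array([0, 0, 1, 2], dtype=np.int32)
cols = np.array([0, 2, 1, 2], dtype=np.int32)
vals = np.array([5.0, 3.0, 4.0, 1.0], dtype=np.float32)

X_coo = sparse.coo_matrix((vals, (rows, cols)), shape=(3, 3))
X_csr = X_coo.tocsr()  # indptr -> csrRowIndexHostPtr, indices -> csrColIndexHostPtr, data -> csrValHostPtr
X_csc = X_coo.tocsc()  # indices -> cscRowIndexHostPtr, indptr -> cscColIndexHostPtr, data -> cscValHostPtr

print(X_csr.indptr, X_csr.indices, X_csr.data)  # [0 2 3 4] [0 2 1 2] [5. 3. 4. 1.]
print(X_csc.indptr, X_csc.indices, X_csc.data)  # [0 1 2 4] [0 1 0 2] [5. 4. 3. 1.]
print(X_coo.row, X_coo.col, X_coo.data)         # -> cooRowIndexHostPtr, cooColIndexHostPtr, cooValHostPtr
```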
diff --git a/src/include/solver/factorization.h b/src/include/solver/factorization.h
index 70977ee95..7ddbb0c52 100644
--- a/src/include/solver/factorization.h
+++ b/src/include/solver/factorization.h
@@ -1,5 +1,17 @@
 #ifndef SRC_INCLUDE_SOLVER_FACTORIZATION_H
 
+void free_data_float(float **ptr);
+
+void free_data_double(double **ptr);
+
+void free_data_int(int **ptr);
+
+void copy_fecatorization_result_float(float *dst, const float **src,
+                                      const int size);
+
+void copy_fecatorization_result_double(double *dst, const double **src,
+                                       const int size);
+
 int make_factorization_data_double(
     const int m, const int n, const int f, const long nnz, const long nnz_test,
     const int *csrRowIndexHostPtr, const int *csrColIndexHostPtr,
diff --git a/src/interface_py/h2o4gpu/solvers/factorization.py b/src/interface_py/h2o4gpu/solvers/factorization.py
index b0dcc4ddf..c7dd939ed 100644
--- a/src/interface_py/h2o4gpu/solvers/factorization.py
+++ b/src/interface_py/h2o4gpu/solvers/factorization.py
@@ -11,24 +11,63 @@
 import scipy.sparse
 
 
-class FactorizationH2O(object):
-    '''[summary]
+def _get_sparse_matrixes(X):
+    '''Create csc, csr and coo sparse matrices from a matrix given in any of these formats
 
     Arguments:
-        object {[type]} -- [description]
+        X {array-like, csc, csr or coo sparse matrix}
 
     Returns:
-        [type] -- [description]
+        csc, csr, coo
+    '''
+
+    X_coo = X_csc = X_csr = None
+    if scipy.sparse.isspmatrix_coo(X):
+        X_coo = X
+        X_csr = X_coo.tocsr(True)
+        X_csc = X_coo.tocsc(True)
+    elif scipy.sparse.isspmatrix_csr(X):
+        X_csr = X
+        X_csc = X_csr.tocsc(True)
+        X_coo = X_csr.tocoo(True)
+    elif scipy.sparse.isspmatrix_csc(X):
+        X_csc = X
+        X_csr = X_csc.tocsr(True)
+        X_coo = X_csc.tocoo(True)
+    else:
+        assert False, "only coo, csc and csr sparse matrices are supported"
+    return X_csc, X_csr, X_coo
+
+
+class FactorizationH2O(object):
+    '''Factors a sparse rating matrix X (m by n, with N_z non-zero elements)
+    into an m-by-f and an f-by-n matrix.
+
+    Arguments:
+        f {int} -- decomposition size
+        lambda_ {float} -- lambda regularization
+
+    Keyword Arguments:
+        max_iter {int} -- number of training iterations (default: {100})
+        double_precision {bool} -- use double precision, not yet supported (default: {False})
+        thetaT {array-like} shape (n, f) -- initial theta matrix (default: {None})
+        XT {array-like} shape (m, f) -- initial XT matrix (default: {None})
+
+    Attributes:
+        XT {array-like} shape (m, f) -- user feature matrix
+        thetaT {array-like} shape (n, f) -- transposed theta matrix, item features
     '''
 
     def __init__(self, f, lambda_, max_iter=100, double_precision=False, thetaT=None, XT=None):
-        self.max_iter = max_iter
+        assert not double_precision, 'double precision is not yet supported'
         assert f % 10 == 0, 'f has to be a multiple of 10'
         self.f = f
         self.lambda_ = lambda_
         self.double_precision = double_precision
+        self.dtype = np.float64 if self.double_precision else np.float32
         self.thetaT = thetaT
         self.XT = XT
+        self.max_iter = max_iter
 
     def _load_lib(self):
         from ..libs.lib_utils import GPUlib
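The constructor and docstring above fix the factor shapes: a fitted model holds an (m, f) user-feature matrix `XT` and an (n, f) item-feature matrix `thetaT`, and a rating is approximated by the dot product of the matching rows. A minimal numpy sketch of that reconstruction, with random placeholder factors standing in for the output of `fit()`:

```python
# Sketch of what the ALS factors represent; XT and thetaT here are random
# placeholders, not fitted values.
import numpy as np

m, n, f = 4, 5, 10                                # users, items, decomposition size (f % 10 == 0)
XT = np.random.rand(m, f).astype(np.float32)      # user features, one row per user
thetaT = np.random.rand(n, f).astype(np.float32)  # item features, one row per item

i, j = 2, 3
r_ij = XT[i].dot(thetaT[j])      # predicted rating of item j by user i

X_hat = XT @ thetaT.T            # full (m, n) reconstruction; only sensible for tiny examples
assert np.isclose(X_hat[i, j], r_ij)
```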
@@ -36,37 +75,38 @@ def _load_lib(self):
         gpu_lib = GPUlib().get(1)
         return gpu_lib
 
-    def fit(self, X, X_test=None, X_BATCHES=1, THETA_BATCHES=1, early_stopping_rounds=None, verbose=False):
-        '''[summary]
+    def fit(self, X, y=None, X_test=None, X_BATCHES=1, THETA_BATCHES=1, early_stopping_rounds=None, verbose=False, scores=None):
+        #pylint: disable=unused-argument
+        '''Learn model from rating matrix X
 
         Arguments:
-            X {[type]} -- [description]
+            X {array-like, sparse matrix}, shape (m, n) -- Data matrix to be decomposed
 
         Keyword Arguments:
-            X_test {[type]} -- [description] (default: {None})
-            X_BATCHES {int} -- [description] (default: {1})
-            THETA_BATCHES {int} -- [description] (default: {1})
-            early_stopping_rounds {[type]} -- [description] (default: {None})
-            verbose {bool} -- [description] (default: {False})
+            y -- ignored
+            X_test {array-like, coo sparse matrix}, shape (m, n) -- Data matrix for cross validation
+            X_BATCHES {int} -- number of batches to split XT into (default: {1})
+            THETA_BATCHES {int} -- number of batches to split theta into (default: {1})
+            early_stopping_rounds {int} -- Activates early stopping. The cross validation error needs to decrease
+                at least every early_stopping_rounds round(s) to continue training; requires X_test.
+                Returns the model from the last iteration (not the best one).
+                If early stopping occurs, the model will have three additional fields: best_cv_score,
+                best_train_score and best_iteration.
+            verbose {bool} -- prints training and validation scores (if applicable) on each iteration (default: {False})
+            scores {list} -- list to be filled with (train score, cv score) tuples for every iteration (default: {None})
         '''
+        csc_X, csr_X, coo_X = _get_sparse_matrixes(X)
+
         if early_stopping_rounds is not None:
             assert X_test is not None, 'X_test is mandatory with early stopping'
-        assert scipy.sparse.isspmatrix_csc(
-            X), 'X must be a csc sparse scipy matrix'
         if X_test is not None:
             assert scipy.sparse.isspmatrix_coo(
                 X_test), 'X_test must be a coo sparse scipy matrix'
             assert X.shape == X_test.shape
+            assert X_test.dtype == self.dtype
 
-        dtype = np.float64 if self.double_precision else np.float32
-
-        assert X.dtype == dtype
-        assert X_test.dtype == dtype
-
-        csc_X = X
-        csr_X = csc_X.tocsr(True)
-        coo_X = csc_X.tocoo(True)
+        assert X.dtype == self.dtype
 
         coo_X_test = X_test
 
@@ -75,27 +115,32 @@ def fit(self, X, X_test=None, X_BATCHES=1, THETA_BATCHES=1, early_stopping_round
             make_data = lib.make_factorization_data_double
             run_step = lib.run_factorization_step_double
             factorization_score = lib.factorization_score_double
+            copy_fecatorization_result = lib.copy_fecatorization_result_double
+            free_data = lib.free_data_double
         else:
             make_data = lib.make_factorization_data_float
             run_step = lib.run_factorization_step_float
             factorization_score = lib.factorization_score_float
+            copy_fecatorization_result = lib.copy_fecatorization_result_float
+            free_data = lib.free_data_float
 
         m = coo_X.shape[0]
         n = coo_X.shape[1]
 
         nnz = csc_X.nnz
-        nnz_test = coo_X_test.nnz
+        if coo_X_test is None:
+            nnz_test = 0
+        else:
+            nnz_test = coo_X_test.nnz
 
         if self.thetaT is None:
-            thetaT = np.random.rand(n, self.f).astype(dtype)
+            self.thetaT = np.random.rand(n, self.f).astype(self.dtype)
         else:
-            thetaT = self.thetaT
-            assert thetaT.dtype == dtype
+            assert self.thetaT.dtype == self.dtype
 
         if self.XT is None:
-            XT = np.random.rand(m, self.f).astype(dtype)
+            self.XT = np.random.rand(m, self.f).astype(self.dtype)
         else:
-            XT = self.XT
-            XT.dtype = dtype
+            assert self.XT.dtype == self.dtype
 
         csrRowIndexDevicePtr = None
         csrColIndexDevicePtr = None
@@ -120,17 +165,20 @@ def fit(self, X, X_test=None, X_BATCHES=1, THETA_BATCHES=1, early_stopping_round
             m, n, self.f, nnz, nnz_test, csr_X.indptr, csr_X.indices,
             csr_X.data, csc_X.indices, csc_X.indptr, csc_X.data, coo_X.row,
             coo_X.col, coo_X.data,
-            thetaT, XT, coo_X_test.row,
-            coo_X_test.col, coo_X_test.data, csrRowIndexDevicePtr, csrColIndexDevicePtr,
-            csrValDevicePtr, cscRowIndexDevicePtr, cscColIndexDevicePtr, cscValDevicePtr,
+            self.thetaT, self.XT, coo_X_test.row if coo_X_test is not None else None,
+            coo_X_test.col if coo_X_test is not None else None, coo_X_test.data if coo_X_test is not None else None,
+            csrRowIndexDevicePtr, csrColIndexDevicePtr, csrValDevicePtr, cscRowIndexDevicePtr, cscColIndexDevicePtr, cscValDevicePtr,
             cooRowIndexDevicePtr, cooColIndexDevicePtr, cooValDevicePtr,
             thetaTDevice, XTDevice, cooRowIndexTestDevicePtr,
             cooColIndexTestDevicePtr, cooValTestDevicePtr)
 
         assert status == 0, 'Failure uploading the data'
 
-        best_CV = np.inf
-        best_Iter = -1
+        self.best_train_score = np.inf
+        self.best_cv_score = np.inf
+        self.best_iteration = -1
+        cv_score = train_score = np.inf
+
         for i in range(self.max_iter):
             status = run_step(m,
                               n,
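The next hunk adds per-iteration scoring, the optional `scores` list and the `best_train_score` / `best_cv_score` / `best_iteration` attributes used by early stopping. A usage sketch with synthetic data follows (the shapes, density and hyperparameters are arbitrary placeholders, and a CUDA-enabled h2o4gpu build is assumed):

```python
# Early-stopping usage sketch; the random matrices stand in for real ratings.
import numpy as np
import scipy.sparse as sparse
from h2o4gpu.solvers import FactorizationH2O

rng = np.random.RandomState(0)
X = sparse.random(1000, 500, density=0.01, format='csc',
                  dtype=np.float32, random_state=rng)
X_test = sparse.random(1000, 500, density=0.002, format='coo',
                       dtype=np.float32, random_state=rng)

scores = []
model = FactorizationH2O(f=40, lambda_=0.01, max_iter=1000)
model.fit(X, X_test=X_test, early_stopping_rounds=10, scores=scores, verbose=True)

print(model.best_iteration, model.best_train_score, model.best_cv_score)
print(scores[model.best_iteration])  # (train score, cv score) at the best round
```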
@@ -147,35 +195,82 @@ def fit(self, X, X_test=None, X_BATCHES=1, THETA_BATCHES=1, early_stopping_round
                               XTDevice,
                               X_BATCHES,
                               THETA_BATCHES)
-            result = factorization_score(m,
-                                         n,
-                                         self.f,
-                                         nnz,
-                                         self.lambda_,
-                                         thetaTDevice,
-                                         XTDevice,
-                                         cooRowIndexDevicePtr,
-                                         cooColIndexDevicePtr,
-                                         cooValDevicePtr)
-            train_score = result[0]
-            result = factorization_score(m,
-                                         n,
-                                         self.f,
-                                         nnz_test,
-                                         self.lambda_,
-                                         thetaTDevice,
-                                         XTDevice,
-                                         cooRowIndexTestDevicePtr,
-                                         cooColIndexTestDevicePtr,
-                                         cooValTestDevicePtr)
-            cv_score = result[0]
+            if verbose or scores is not None:
+                result = factorization_score(m,
+                                              n,
+                                              self.f,
+                                              nnz,
+                                              self.lambda_,
+                                              thetaTDevice,
+                                              XTDevice,
+                                              cooRowIndexDevicePtr,
+                                              cooColIndexDevicePtr,
+                                              cooValDevicePtr)
+                train_score = result[0]
+            if X_test is not None and (verbose or early_stopping_rounds is not None or scores is not None):
+                result = factorization_score(m,
+                                              n,
+                                              self.f,
+                                              nnz_test,
+                                              self.lambda_,
+                                              thetaTDevice,
+                                              XTDevice,
+                                              cooRowIndexTestDevicePtr,
+                                              cooColIndexTestDevicePtr,
+                                              cooValTestDevicePtr)
+                cv_score = result[0]
             if verbose:
                 print("iteration {0} train: {1} cv: {2}".format(
                     i, train_score, cv_score))
+            if scores is not None:
+                scores.append((train_score, cv_score))
 
             if early_stopping_rounds is not None:
-                if best_CV > cv_score:
-                    best_CV = cv_score
-                    best_Iter = i
-                if (i - best_Iter) > early_stopping_rounds:
+                if self.best_cv_score > cv_score:
+                    self.best_cv_score = cv_score
+                    self.best_train_score = train_score
+                    self.best_iteration = i
+                if (i - self.best_iteration) > early_stopping_rounds:
+                    if verbose:
+                        print('best iteration:{0} train: {1} cv: {2}'.format(
+                            self.best_iteration, self.best_train_score, self.best_cv_score))
                     break
+
+        lib.free_data_int(csrRowIndexDevicePtr)
+        lib.free_data_int(csrColIndexDevicePtr)
+        free_data(csrValDevicePtr)
+        lib.free_data_int(cscRowIndexDevicePtr)
+        lib.free_data_int(cscColIndexDevicePtr)
+        free_data(cscValDevicePtr)
+        lib.free_data_int(cooRowIndexDevicePtr)
+        lib.free_data_int(cooColIndexDevicePtr)
+        free_data(cooValDevicePtr)
+        lib.free_data_int(cooRowIndexTestDevicePtr)
+        lib.free_data_int(cooColIndexTestDevicePtr)
+        free_data(cooValTestDevicePtr)
+
+        copy_fecatorization_result(self.XT, XTDevice, m * self.f)
+        copy_fecatorization_result(self.thetaT, thetaTDevice, n * self.f)
+
+        free_data(thetaTDevice)
+        free_data(XTDevice)
+
+        return self
+
+    def predict(self, X):
+        '''Predict the non-zero elements of a coo sparse matrix X according to the fitted model
+
+        Arguments:
+            X {array-like, sparse coo matrix} shape (m, n)
+                Data matrix in coo format
+
+        Returns:
+            prediction: coo sparse matrix, shape (m, n)
+        '''
+
+        assert self.XT is not None and self.thetaT is not None, 'predict is invoked on an unfitted model'
+        assert scipy.sparse.isspmatrix_coo(
+            X), 'convert X to coo sparse matrix'
+        assert X.dtype == self.dtype
+        a = np.take(self.XT, X.row, axis=0)
+        b = np.take(self.thetaT, X.col, axis=0)
+        val = np.sum(a * b, axis=1)
+        return scipy.sparse.coo_matrix((val, (X.row, X.col)), shape=X.shape)
diff --git a/src/interface_py/setup.py b/src/interface_py/setup.py
index 4e7757a2b..8c6d936da 100644
--- a/src/interface_py/setup.py
+++ b/src/interface_py/setup.py
@@ -1,3 +1,5 @@
+from setuptools.dist import Distribution
+from setuptools import setup
 """
 :copyright: 2017-2018 H2O.ai, Inc.
:license: Apache License Version 2.0 (see LICENSE for details) @@ -13,6 +15,7 @@ class H2O4GPUBuild(build): """H2O4GPU library compiler""" + def run(self): """Run the compilation""" NVCC = os.popen("which nvcc").read() != "" @@ -64,14 +67,16 @@ def run(self): # install H2O4GPU executables self.copy_tree(self.build_lib, self.install_lib) + # reqs is a list of requirement # e.g. ['django==1.5.1', 'mezzanine==1.4.6'] requirements_file = 'requirements_runtime.txt' -import os if os.environ.get('CONDA_BUILD_STATE') is not None: requirements_file = 'requirements_conda.txt' with open(requirements_file, "r") as fs: - reqs = [r for r in fs.read().splitlines() if (len(r) > 0 and not r.startswith("#"))] + reqs = [r for r in fs.read().splitlines() if ( + len(r) > 0 and not r.startswith("#"))] + def get_packages(directory): paths = [] @@ -82,35 +87,40 @@ def get_packages(directory): paths.append(path[2:]) return paths + packages = get_packages('./') package_data = {} for package in packages: - package_data[package] = ['*'] + package_data[package] = ['*'] -import os -from setuptools import setup -from setuptools.dist import Distribution class BinaryDistribution(Distribution): def is_pure(self): return False + # Read version -about_info={} -with open('__about__.py') as f: exec(f.read(), about_info) +about_info = {} +with open('__about__.py') as f: + exec(f.read(), about_info) lines = [] -lines.append("__version__ = '" + about_info['__build_info__']['base_version'] + "'") -lines.append("__git_revision__ = '" + about_info['__build_info__']['commit'] + "'") -lines.append("__cuda_version__ = '" + about_info['__build_info__']['cuda_version'] + "'") -lines.append("__cuda_nccl__ = '" + about_info['__build_info__']['cuda_nccl'] + "'") -with open('build_info.txt','w') as fp: +lines.append("__version__ = '" + + about_info['__build_info__']['base_version'] + "'") +lines.append("__git_revision__ = '" + + about_info['__build_info__']['commit'] + "'") +lines.append("__cuda_version__ = '" + + about_info['__build_info__']['cuda_version'] + "'") +lines.append("__cuda_nccl__ = '" + + about_info['__build_info__']['cuda_nccl'] + "'") +with open('build_info.txt', 'w') as fp: fp.write('\n'.join(lines)+'\n') # Make the .whl contain required python and OS as we are version and distro specific try: from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + class bdist_wheel(_bdist_wheel): def finalize_options(self): _bdist_wheel.finalize_options(self) @@ -134,5 +144,6 @@ def finalize_options(self): zip_safe=False, description='H2O.ai GPU Edition', install_requires=reqs, - cmdclass={'bdist_wheel': bdist_wheel, 'build': H2O4GPUBuild, 'install': H2O4GPUInstall}, + cmdclass={'bdist_wheel': bdist_wheel, + 'build': H2O4GPUBuild, 'install': H2O4GPUInstall}, ) diff --git a/src/swig/solver/factorization.i b/src/swig/solver/factorization.i index 93ff55aed..cd8156984 100644 --- a/src/swig/solver/factorization.i +++ b/src/swig/solver/factorization.i @@ -36,7 +36,7 @@ %apply (int *IN_ARRAY1) {int* csrRowIndexHostPtr, int* csrColIndexHostPtr, int* cscRowIndexHostPtr, int* cscColIndexHostPtr, int* cooRowIndexHostPtr, int* cooColIndexHostPtr, int* cooRowIndexTestHostPtr, int* cooColIndexTestHostPtr}; -%apply (float *IN_ARRAY1) {float* csrValHostPtr, float* cscValHostPtr, float* cooValHostPtr, float* thetaTHost, float* XTHost, float* cooValTestHostPtr}; -%apply (double *IN_ARRAY1) {double* csrValHostPtr, double* cscValHostPtr, double* cooValHostPtr, double* thetaTHost, double* XTHost, double* cooValTestHostPtr}; +%apply (float *IN_ARRAY1) 
{float* dst, float* csrValHostPtr, float* cscValHostPtr, float* cooValHostPtr, float* thetaTHost, float* XTHost, float* cooValTestHostPtr}; +%apply (double *IN_ARRAY1) {double* dst, double* csrValHostPtr, double* cscValHostPtr, double* cooValHostPtr, double* thetaTHost, double* XTHost, double* cooValTestHostPtr}; %include "../../include/solver/factorization.h" diff --git a/tests/python/open_data/factorization/test_10M.py b/tests/python/open_data/factorization/test_10M.py index 2ffca6dd3..e78264284 100644 --- a/tests/python/open_data/factorization/test_10M.py +++ b/tests/python/open_data/factorization/test_10M.py @@ -2,9 +2,11 @@ import scipy import scipy.sparse import h2o4gpu +from sklearn.metrics import mean_squared_error +from math import sqrt -def test_factorization(): +def _load_train_test(): R_csc_data = np.fromfile( 'open_data/factorization/R_train_csc.data.bin', dtype=np.float32) R_csc_indices = np.fromfile( @@ -18,14 +20,56 @@ def test_factorization(): 'open_data/factorization/R_test_coo.row.bin', dtype=np.int32) R_test_coo_data = np.fromfile( 'open_data/factorization/R_test_coo.data.bin', dtype=np.float32) - - factorization = h2o4gpu.solvers.FactorizationH2O( - 50, 0.1, max_iter=10, double_precision=False) X = scipy.sparse.csc_matrix((R_csc_data, R_csc_indices, R_csc_indptr)) X_test = scipy.sparse.coo_matrix( (R_test_coo_data, (R_test_coo_row, R_test_coo_col)), shape=X.shape) - factorization.fit(X, X_test) + return X, X_test + + +def test_factorization_memory_leak(): + for i in range(100): + X, _ = _load_train_test() + factorization = h2o4gpu.solvers.FactorizationH2O(10, 0.1, max_iter=1) + factorization.fit(X) + + +def test_factorization_fit_predict(): + X, X_test = _load_train_test() + scores = [] + factorization = h2o4gpu.solvers.FactorizationH2O( + 50, 0.1, max_iter=10) + factorization.fit(X, scores=scores) + X_pred = factorization.predict(X.tocoo()) + not_nan = ~np.isnan(X_pred.data) + assert np.allclose(sqrt(mean_squared_error( + X.data[not_nan], X_pred.data[not_nan])), scores[-1][0]) + + +def test_early_stop(): + X, X_test = _load_train_test() + scores = [] + factorization = h2o4gpu.solvers.FactorizationH2O( + 50, 0.01, max_iter=10000) + factorization.fit(X, scores=scores, X_test=X_test, + early_stopping_rounds=10, verbose=True) + best = factorization.best_iteration + for i in range(best, best + 10, 1): + assert scores[best][1] <= scores[i][1] + + +def test_multi_batches(): + X, X_test = _load_train_test() + scores = [] + factorization = h2o4gpu.solvers.FactorizationH2O( + 80, 0.1, max_iter=40) + factorization.fit(X, scores=scores, X_BATCHES=2, THETA_BATCHES=2) + X_pred = factorization.predict(X.tocoo()) + not_nan = ~np.isnan(X_pred.data) + assert np.allclose(sqrt(mean_squared_error( + X.data[not_nan], X_pred.data[not_nan])), scores[-1][0]) if __name__ == '__main__': - test_factorization() + test_early_stop() + test_factorization_fit_predict() + test_multi_batches() diff --git a/xgboost b/xgboost index 9c709bdae..bb91675c7 160000 --- a/xgboost +++ b/xgboost @@ -1 +1 @@ -Subproject commit 9c709bdaec7e45fa496a97bc44388d222a50804d +Subproject commit bb91675c7d7c448126a3da4c6ce7b483e39b89d3
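The new tests pair `fit` with `predict` and check the reported training score against an RMSE recomputed from the predictions. The same round trip can be written against synthetic data (sketch only; assumes a CUDA-enabled h2o4gpu build; shapes and density are arbitrary placeholders):

```python
# Fit/predict round trip mirroring test_factorization_fit_predict, on synthetic data.
from math import sqrt

import numpy as np
import scipy.sparse as sparse
from sklearn.metrics import mean_squared_error
import h2o4gpu

X = sparse.random(2000, 800, density=0.01, format='csc',
                  dtype=np.float32, random_state=42)
X_coo = X.tocoo()

scores = []
model = h2o4gpu.solvers.FactorizationH2O(50, 0.1, max_iter=10)
model.fit(X, scores=scores)

# predict() fills only the non-zero pattern of the COO matrix it is given.
X_pred = model.predict(X_coo)
not_nan = ~np.isnan(X_pred.data)
rmse = sqrt(mean_squared_error(X_coo.data[not_nan], X_pred.data[not_nan]))
assert np.allclose(rmse, scores[-1][0])
```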