Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data split mode to DMatrix MetaInfo #8568

Merged
merged 27 commits into from
Dec 25, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
40252ee
Add data split mode to DMatrix MetaInfo
rongou Dec 7, 2022
7c35c40
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 8, 2022
26ed1a9
remove dsplit training param
rongou Dec 8, 2022
d3fda24
fix dmatrix validation
rongou Dec 8, 2022
8e797f7
fix python
rongou Dec 8, 2022
e12f361
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 12, 2022
8f7ac3e
fix dsplit for local mode
rongou Dec 12, 2022
fa7a670
fix java bulid
rongou Dec 12, 2022
afc5fa0
fix R package
rongou Dec 12, 2022
31b7112
fix demo
rongou Dec 12, 2022
32d7fcc
fix line too long
rongou Dec 12, 2022
c857cd9
fix r doc
rongou Dec 12, 2022
aa0c26c
update roxgen
rongou Dec 12, 2022
cbd1a42
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 13, 2022
d7830cb
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 15, 2022
c9ee1d6
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 15, 2022
6782dd9
add XGDMatrixCreateFromFileV2
rongou Dec 15, 2022
86226e0
add a test for v2
rongou Dec 16, 2022
914df2a
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 16, 2022
bde1e4c
add need_split to json config
rongou Dec 16, 2022
55f8aa4
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 19, 2022
9002705
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 20, 2022
c80a3ae
change to uri
rongou Dec 20, 2022
58ae574
remove need_split as a parameter
rongou Dec 20, 2022
f6148a3
fix python
rongou Dec 20, 2022
da7d545
fix dask test
rongou Dec 20, 2022
417dc18
Merge remote-tracking branch 'upstream/master' into data-split-param
rongou Dec 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ Imports:
methods,
data.table (>= 1.9.6),
jsonlite (>= 1.0),
RoxygenNote: 7.2.2
RoxygenNote: 7.2.3
Encoding: UTF-8
SystemRequirements: GNU make, C++14
5 changes: 3 additions & 2 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
#' It is useful when a 0 or some other extreme value represents missing values in data.
#' @param silent whether to suppress printing an informational message after loading from a file.
#' @param dsplit data split mode.
#' @param nthread Number of threads used for creating DMatrix.
#' @param ... the \code{info} data could be passed directly as parameters, without creating an \code{info} list.
#'
Expand All @@ -23,14 +24,14 @@
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
#' @export
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, dsplit = 0, nthread = NULL, ...) {
cnames <- NULL
if (typeof(data) == "character") {
if (length(data) > 1)
stop("'data' has class 'character' and length ", length(data),
".\n 'data' accepts either a numeric matrix or a single filename.")
data <- path.expand(data)
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent), as.integer(dsplit))
} else if (is.matrix(data)) {
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
cnames <- colnames(data)
Expand Down
3 changes: 3 additions & 0 deletions R-package/man/xgb.DMatrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
extern SEXP XGCheckNullPtr_R(SEXP);
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
Expand Down Expand Up @@ -77,7 +77,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 5},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 3},
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
Expand Down
5 changes: 3 additions & 2 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ XGB_DLL SEXP XGBGetGlobalConfig_R() {
return mkString(json_str);
}

XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent, SEXP dsplit) {
SEXP ret;
R_API_BEGIN();
DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent), &handle));
CHECK_CALL(
XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent), asInteger(dsplit), &handle));
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();
Expand Down
3 changes: 2 additions & 1 deletion R-package/src/xgboost_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ XGB_DLL SEXP XGBGetGlobalConfig_R();
* \brief load a data matrix
* \param fname name of the content
* \param silent whether print messages
* \param dsplit data split mode
* \return a loaded data matrix
*/
XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent, SEXP dsplit);

/*!
* \brief create matrix content from dense matrix
Expand Down
5 changes: 3 additions & 2 deletions demo/c-api/basic/c-api-demo.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ if (err != 0) { \

int main() {
int silent = 0;
int dsplit = 0;
int use_gpu = 0; // set to 1 to use the GPU for training

// load the data
DMatrixHandle dtrain, dtest;
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, dsplit, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, dsplit, &dtest));

// create the booster
BoosterHandle booster;
Expand Down
6 changes: 3 additions & 3 deletions doc/tutorials/c_api_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ In your application, wrap all C API function calls with the macro as follows:
.. code-block:: c

DMatrixHandle train;
safe_xgboost(XGDMatrixCreateFromFile("/path/to/training/dataset/", silent, &train));
safe_xgboost(XGDMatrixCreateFromFile("/path/to/training/dataset/", silent, dsplit, &train));

b. In a C++ application: modify the macro ``safe_xgboost`` to throw an exception upon an error.

Expand All @@ -114,7 +114,7 @@ c. Assertion technique: It works both in C/ C++. If expression evaluates to 0 (f
.. code-block:: c

DMatrixHandle dmat;
assert( XGDMatrixCreateFromFile("training_data.libsvm", 0, &dmat) == 0);
assert( XGDMatrixCreateFromFile("training_data.libsvm", 0, 0, &dmat) == 0);


2. Always remember to free the allocated space by BoosterHandle & DMatrixHandle appropriately:
Expand Down Expand Up @@ -169,7 +169,7 @@ Sample examples along with Code snippet to use C API functions

DMatrixHandle data; // handle to DMatrix
// Load the dat from file & store it in data variable of DMatrixHandle datatype
safe_xgboost(XGDMatrixCreateFromFile("/path/to/file/filename", silent, &data));
safe_xgboost(XGDMatrixCreateFromFile("/path/to/file/filename", silent, dsplit, &data));


2. You can also create a ``DMatrix`` object from a 2D Matrix using the :cpp:func:`XGDMatrixCreateFromMat`
Expand Down
1 change: 0 additions & 1 deletion doc/tutorials/saving_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ Will print out something similar to (not actual output as it's too long for demo
"learner_train_param": {
"booster": "gbtree",
"disable_default_eval_metric": "0",
"dsplit": "auto",
"objective": "reg:squarederror"
},
"metrics": [],
Expand Down
3 changes: 2 additions & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
* \brief load a data matrix
* \param fname the name of the file
* \param silent whether print messages during loading
* \param dsplit data split mode
* \param out a loaded data matrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out);
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, int dsplit, DMatrixHandle *out);
rongou marked this conversation as resolved.
Show resolved Hide resolved
/**
* @example c-api-demo.c
*/
Expand Down
2 changes: 2 additions & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class MetaInfo {
uint64_t num_nonzero_{0}; // NOLINT
/*! \brief label of each instance */
linalg::Tensor<float, 2> labels;
/*! \brief data split mode */
DataSplitMode data_split_mode{DataSplitMode::kNone};
/*!
* \brief the index of begin and end of a group
* needed when the learning task is ranking.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,53 @@ public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostErro
handle = out[0];
}

/**
* data split mode
*/
public enum DataSplitMode {
AUTO(0),
COL(1),
ROW(2),
NONE(3);

private final int value;

DataSplitMode(int value) {
this.value = value;
}

public int getValue() {
return value;
}
}

/**
* Create DMatrix by loading libsvm file from dataPath
*
* @param dataPath The path to the data.
* @param dataSplitMode Data split mode.
* @throws XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
public DMatrix(String dataPath, DataSplitMode dataSplitMode) throws XGBoostError {
if (dataPath == null) {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
XGBoostJNI.checkCall(
XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, dataSplitMode.getValue(), out));
handle = out[0];
}

/**
* Create DMatrix by loading libsvm file from dataPath
*
* @param dataPath The path to the data.
* @throws XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
this(dataPath, DataSplitMode.AUTO);
}

/**
* Create DMatrix from Sparse matrix in CSR/CSC format.
* @param headers The row index of the matrix.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static void checkCall(int ret) throws XGBoostError {

public final static native String XGBGetLastError();

public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromFile(String fname, int silent, int dsplit, long[] out);

final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
Expand Down
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jint jdsplit, jlongArray jout) {
DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
int ret = XGDMatrixCreateFromFile(fname, jsilent, jdsplit, &result);
JVM_CHECK_CALL(ret);
if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
Expand Down
2 changes: 1 addition & 1 deletion jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping
from enum import IntEnum, unique
from functools import wraps
from inspect import Parameter, signature
from typing import (
Expand Down Expand Up @@ -624,6 +625,15 @@ def inner_f(*args: Any, **kwargs: Any) -> _T:
_deprecate_positional_args = require_keyword_args(False)


@unique
class DataSplitMode(IntEnum):
"""Supported data split mode for DMatrix."""
AUTO = 0
COL = 1
ROW = 2
NONE = 3


class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.

Expand Down Expand Up @@ -651,6 +661,7 @@ def __init__(
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.AUTO,
) -> None:
"""Parameters
----------
Expand Down Expand Up @@ -744,6 +755,7 @@ def __init__(
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
data_split_mode=data_split_mode,
)
assert handle is not None
self.handle = handle
Expand Down
6 changes: 5 additions & 1 deletion python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .core import (
_LIB,
DataIter,
DataSplitMode,
DMatrix,
_check_call,
_cuda_array_interface,
Expand Down Expand Up @@ -865,12 +866,14 @@ def _from_uri(
missing: Optional[FloatCompatible],
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.AUTO,
) -> DispatchedDataBackendReturnType:
_warn_unused_missing(data, missing)
handle = ctypes.c_void_p()
data = os.fspath(os.path.expanduser(data))
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
ctypes.c_int(1),
ctypes.c_int(data_split_mode),
ctypes.byref(handle)))
return handle, feature_names, feature_types

Expand Down Expand Up @@ -938,6 +941,7 @@ def dispatch_data_backend(
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.AUTO,
) -> DispatchedDataBackendReturnType:
'''Dispatch data for DMatrix.'''
if not _is_cudf_ser(data) and not _is_pandas_series(data):
Expand All @@ -953,7 +957,7 @@ def dispatch_data_backend(
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types)
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types)
if _is_tuple(data):
Expand Down
19 changes: 15 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,25 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
API_END();
}

XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, int dsplit, DMatrixHandle *out) {
API_BEGIN();
auto data_split_mode = DataSplitMode::kNone;
auto data_split_mode = static_cast<DataSplitMode>(dsplit);
if (collective::IsFederated()) {
CHECK(data_split_mode == DataSplitMode::kAuto || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; dsplit can only be 'auto' or 'none' in federated mode";
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
data_split_mode = DataSplitMode::kNone;
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
data_split_mode = DataSplitMode::kRow;
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is currently not supported in distributed mode";
if (data_split_mode != DataSplitMode::kNone) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
data_split_mode = DataSplitMode::kRow;
}
} else {
CHECK(data_split_mode == DataSplitMode::kAuto || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; dsplit can only be 'auto' or 'none' in local mode";
data_split_mode = DataSplitMode::kNone;
}
xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out);
Expand Down
1 change: 1 addition & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat;
return sliced;
} else {
dmat->Info().data_split_mode = data_split_mode;
return dmat;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
out->Info() = this->Info().Copy();
out->Info().num_nonzero_ = h_offset.back();
}
out->Info().data_split_mode = DataSplitMode::kCol;
return out;
}

Expand Down
Loading