Skip to content

Commit

Permalink
Add loading CatBoostModel from a byte array to API.. Fix #2539
Browse files Browse the repository at this point in the history
11ec400f27df89ff41b2aee36cc3271774b57dff
  • Loading branch information
andrey-khropov committed Feb 23, 2024
1 parent 56a0b44 commit a13b5ba
Showing 1 changed file with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,34 @@ public static CatBoostModel loadModel(final @NotNull String modelPath, @NotNull
return new CatBoostModel(handles[0]);
}

/**
* Load CatBoost model serialized in an array.
*
* @param serializedModel Byte array containing model.
* @return CatBoost model.
* @throws CatBoostError When failed to load model.
*/
@NotNull
public static CatBoostModel loadModel(final @NotNull byte[] serializedModel) throws CatBoostError {
return loadModel(serializedModel, "bin");
}

/**
* Load CatBoost model serialized in an array.
*
* @param serializedModel Byte array containing model.
* @param modelFormat Model file format (bin or json)
* @return CatBoost model.
* @throws CatBoostError When failed to load model.
*/
@NotNull
public static CatBoostModel loadModel(final @NotNull byte[] serializedModel, @NotNull String modelFormat) throws CatBoostError {
final long[] handles = new long[1];

implLibrary.catBoostLoadModelFromArray(serializedModel, handles, modelFormat);
return new CatBoostModel(handles[0]);
}

/**
* Load CatBoost model from stream.
*
Expand All @@ -276,7 +304,6 @@ public static CatBoostModel loadModel(final InputStream in) throws CatBoostError
*/
@NotNull
public static CatBoostModel loadModel(final InputStream in, @NotNull String modelFormat) throws CatBoostError, IOException {
final long[] handles = new long[1];
final byte[] copyBuffer = new byte[4 * 1024];

int bytesRead;
Expand All @@ -286,8 +313,7 @@ public static CatBoostModel loadModel(final InputStream in, @NotNull String mode
out.write(copyBuffer, 0, bytesRead);
}

implLibrary.catBoostLoadModelFromArray(out.toByteArray(), handles, modelFormat);
return new CatBoostModel(handles[0]);
return loadModel(out.toByteArray(), modelFormat);
}

/**
Expand Down

0 comments on commit a13b5ba

Please sign in to comment.