A parser for scikit-learn exported text models to execute in the Java runtime.


This project aims to take ML models trained with scikit-learn and exported as text, and make them usable in Java.



  • The tree.DecisionTreeClassifier is supported
    • Supports predict()
    • Supports predict_proba() when export_text() is configured with show_weights=True
  • The ensemble.RandomForestClassifier is supported
    • Supports predict()
    • Supports predict_proba() when export_text() is configured with show_weights=True
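
The reason predict_proba() depends on show_weights=True is that the class probabilities of a decision tree leaf are the normalized per-class weights of the training samples that reached it, and export_text() only prints those weights when asked to. The following is a minimal sketch of that normalization. The class and method names are hypothetical and are not part of this library's API, and the leaf weights used in main are illustrative values.

```java
import java.util.Arrays;

// Hypothetical helper, not part of this library: turns the per-class
// weights printed on a leaf by export_text(show_weights=True) into
// class probabilities.
final class LeafProbabilities {

    // Divides each class weight by the total weight of the leaf.
    static double[] toProbabilities(double[] weights) {
        double total = Arrays.stream(weights).sum();
        if (total <= 0.0) {
            throw new IllegalArgumentException("leaf weights must sum to a positive value");
        }
        return Arrays.stream(weights).map(w -> w / total).toArray();
    }

    // Returns the index of the largest value (the first one on ties).
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] leafWeights = {0.0, 49.0, 5.0}; // illustrative leaf weights
        double[] probabilities = toProbabilities(leafWeights);
        System.out.println(Arrays.toString(probabilities));
        System.out.println(argMax(probabilities)); // prints 1
    }
}
```

Without the weights, a leaf only carries its majority class, which is enough for predict() but not for predict_proba().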


Importing Maven Dependency



As an example, a DecisionTreeClassifier model can be trained on the Iris dataset and exported using export_text() from sklearn.tree, as shown below:

>>> import sys
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'], show_weights=True, max_depth=sys.maxsize)
>>> print(r)

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

The exported text can then be executed in Java. Note that when calling export_text it is recommended that max_depth be set to sys.maxsize so that the tree isn't truncated.
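
To give a feel for what executing the exported text involves, here is a small self-contained sketch that walks the export_text() format directly. It is not this library's implementation: the class ExportTextWalker and its predict method are hypothetical. It also assumes integer class labels and that each nesting level is indented by four characters before the "|---" marker, as in the output above.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical sketch, not this library's parser: evaluates a tree in
// export_text() format against a single sample in one pass over the lines.
final class ExportTextWalker {

    // Matches "<feature> <= <threshold>" or "<feature> >  <threshold>".
    private static final Pattern SPLIT =
            Pattern.compile("^(.+?)\\s+(<=|>)\\s+(-?\\d+(?:\\.\\d+)?)$");

    // Matches "class: <n>", optionally preceded by "weights: [...]".
    private static final Pattern LEAF = Pattern.compile("class:\\s*(-?\\d+)");

    static int predict(String exportedTree, Map<String, Double> sample) {
        int wantedDepth = 0; // depth of the next node to test
        for (String line : exportedTree.split("\n")) {
            int marker = line.indexOf("|---");
            if (marker < 0) {
                continue; // blank or unrelated line
            }
            int depth = marker / 4; // each level adds a 4-character prefix
            if (depth != wantedDepth) {
                continue; // line belongs to a branch that was not taken
            }
            String body = line.substring(marker + 4).trim();
            Matcher leaf = LEAF.matcher(body);
            if (leaf.find()) {
                return Integer.parseInt(leaf.group(1));
            }
            Matcher split = SPLIT.matcher(body);
            if (!split.matches()) {
                // e.g. a "truncated branch" line from a too-small max_depth
                throw new IllegalArgumentException("Unrecognized node: " + body);
            }
            Double value = sample.get(split.group(1));
            if (value == null) {
                throw new IllegalArgumentException("Missing feature: " + split.group(1));
            }
            double threshold = Double.parseDouble(split.group(3));
            boolean taken = split.group(2).equals("<=") ? value <= threshold : value > threshold;
            if (taken) {
                wantedDepth++; // descend into the children of this node
            }
        }
        throw new IllegalStateException("No leaf reached");
    }

    public static void main(String[] args) {
        String tree = String.join("\n",
                "|--- petal width (cm) <= 0.80",
                "|   |--- class: 0",
                "|--- petal width (cm) >  0.80",
                "|   |--- petal width (cm) <= 1.75",
                "|   |   |--- class: 1",
                "|   |--- petal width (cm) >  1.75",
                "|   |   |--- class: 2");
        System.out.println(predict(tree, Map.of("petal width (cm)", 2.0))); // prints 2
    }
}
```

A tree exported with a max_depth that is too small contains "truncated branch" lines instead of real nodes, which a walker like this one can only reject. That is why exporting with max_depth=sys.maxsize matters.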

Java Example

In this example the iris model exported using export_text is parsed, features are created as a Java Map and the decision tree is asked to predict the class.

    Reader tree = getTrainedModel("iris.model");
    final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,

    Features features = Features.of("sepal length (cm)",
                "sepal width (cm)",
                "petal length (cm)",
                "petal width (cm)");
    FeatureVector fv = features.newSample();
    fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
    Integer prediction = decisionTree.predict(fv);


To use a RandomForestClassifier that has been trained on the Iris dataset, each of the estimators
in the classifier needs to be exported using export_text from sklearn.tree, as shown below:

>>> import os
>>> import sys
>>> from sklearn import datasets
>>> from sklearn import tree
>>> from sklearn.ensemble import RandomForestClassifier
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> y = iris.target
>>> clf = RandomForestClassifier(n_estimators=50, n_jobs=8)
>>> model = clf.fit(X, y)
>>> os.makedirs('/tmp/estimators', exist_ok=True)  # the output directory must exist
>>> for i, t in enumerate(clf.estimators_):
...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
...         file1.write(text_representation)

Once all the estimators are exported into /tmp/estimators, you can create a TAR archive, for example:

cd /tmp/estimators
tar -czvf /tmp/iris.tgz .

Then you can use the RandomForestClassifier class to parse the TAR archive.

    import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
    TarArchiveInputStream tree = getArchive("iris.tgz");
    final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
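
For background, scikit-learn's RandomForestClassifier predicts by averaging the class probabilities of its individual trees and choosing the class with the highest mean. How this library combines the parsed estimators is not shown here. The sketch below only illustrates that averaging step, using the hypothetical names ForestAveraging, averageProbabilities and argMax, and made-up per-tree probabilities.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not this library's API: combines per-tree class
// probabilities the way scikit-learn's random forest does, by averaging.
final class ForestAveraging {

    // Returns the element-wise mean of the per-tree probability vectors.
    static double[] averageProbabilities(List<double[]> perTree) {
        if (perTree.isEmpty()) {
            throw new IllegalArgumentException("at least one tree is required");
        }
        double[] mean = new double[perTree.get(0).length];
        for (double[] probabilities : perTree) {
            if (probabilities.length != mean.length) {
                throw new IllegalArgumentException("all trees must report the same classes");
            }
            for (int c = 0; c < mean.length; c++) {
                mean[c] += probabilities[c] / perTree.size();
            }
        }
        return mean;
    }

    // Returns the index of the largest value (the first one on ties).
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up probabilities from three trees for one sample.
        List<double[]> perTree = List.of(
                new double[] {0.0, 0.9, 0.1},
                new double[] {0.0, 0.2, 0.8},
                new double[] {0.0, 0.6, 0.4});
        double[] mean = averageProbabilities(perTree);
        System.out.println(Arrays.toString(mean));
        System.out.println(argMax(mean)); // prints 1
    }
}
```

Averaging probabilities rather than counting hard votes lets a confident tree outweigh a hesitant one, which is one more reason to export every estimator with show_weights=True.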


Testing was done using models exported with scikit-learn version 1.1.3; models exported with newer versions of scikit-learn should work as well.