Skip to content

Commit

Permalink
[tmva][pymva] Change default algorithm for AdaBoost to SAMME
Browse files Browse the repository at this point in the history
Since scikit version 1.4 the SAMME.R algorithm is deprecated and it has been removed since version 1.6.
Change default to use SAMME
See https://scikit-learn.org/1.5/modules/generated/sklearn.ensemble.AdaBoostClassifier.html
  • Loading branch information
lmoneta committed Dec 10, 2024
1 parent 3d8b3fc commit 736b402
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tmva/pymva/src/MethodPyAdaBoost.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ MethodPyAdaBoost::MethodPyAdaBoost(const TString &jobName,
fBaseEstimator("None"),
fNestimators(50),
fLearningRate(1.0),
fAlgorithm("SAMME.R"),
fAlgorithm("SAMME"),
fRandomState("None")
{
}
Expand All @@ -79,7 +79,7 @@ MethodPyAdaBoost::MethodPyAdaBoost(DataSetInfo &theData,
fBaseEstimator("None"),
fNestimators(50),
fLearningRate(1.0),
fAlgorithm("SAMME.R"),
fAlgorithm("SAMME"),
fRandomState("None")
{
}
Expand Down Expand Up @@ -116,12 +116,13 @@ void MethodPyAdaBoost::DeclareOptions()
``learning_rate``. There is a trade-off between ``learning_rate`` and\
``n_estimators``.");

DeclareOptionRef(fAlgorithm, "Algorithm", "{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')\
DeclareOptionRef(fAlgorithm, "Algorithm", "{'SAMME', 'SAMME.R'}, optional (default='SAMME')\
If 'SAMME.R' then use the SAMME.R real boosting algorithm.\
``base_estimator`` must support calculation of class probabilities.\
If 'SAMME' then use the SAMME discrete boosting algorithm.\
The SAMME.R algorithm typically converges faster than SAMME,\
achieving a lower test error with fewer boosting iterations.");
achieving a lower test error with fewer boosting iterations.\
'SAME.R' is deprecated since version 1.4 and removed since 1.6");

DeclareOptionRef(fRandomState, "RandomState", "int, RandomState instance or None, optional (default=None)\
If int, random_state is the seed used by the random number generator;\
Expand Down Expand Up @@ -309,11 +310,11 @@ std::vector<Double_t> MethodPyAdaBoost::GetMvaValues(Long64_t firstEvt, Long64_t
Py_DECREF(result);

if (logProgress) {
Log() << kINFO
Log() << kINFO
<< "Elapsed time for evaluation of " << nEvents << " events: "
<< timer.GetElapsedTime() << " " << Endl;
}

return mvaValues;
}

Expand Down

0 comments on commit 736b402

Please sign in to comment.