Skip to content

Commit

Permalink
Use RealNotInt for parameters that accept ints and floats
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Dec 31, 2024
1 parent e128071 commit 0e9ca7e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sksurv/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
from sklearn.tree._utils import _any_isnan_axis0
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
from sklearn.utils.validation import (
_assert_all_finite_element_wise,
_check_n_features,
Expand Down Expand Up @@ -161,16 +161,16 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
"min_samples_split": [
Interval(Integral, 2, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="neither"),
Interval(RealNotInt, 0.0, 1.0, closed="neither"),
],
"min_samples_leaf": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 0.5, closed="right"),
Interval(RealNotInt, 0.0, 0.5, closed="right"),
],
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
"max_features": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="right"),
Interval(RealNotInt, 0.0, 1.0, closed="right"),
StrOptions({"sqrt", "log2"}),
None,
],
Expand Down Expand Up @@ -363,7 +363,7 @@ def _check_params(self, n_samples):

max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

if isinstance(self.min_samples_leaf, (Integral, np.integer)):
if isinstance(self.min_samples_leaf, Integral):
min_samples_leaf = self.min_samples_leaf
else: # float
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
Expand Down Expand Up @@ -397,7 +397,7 @@ def _check_max_features(self):

elif self.max_features is None:
max_features = self.n_features_in_
elif isinstance(self.max_features, (Integral, np.integer)):
elif isinstance(self.max_features, Integral):
max_features = self.max_features
else: # float
if self.max_features > 0.0:
Expand Down

0 comments on commit 0e9ca7e

Please sign in to comment.