Skip to content

Commit

Permalink
Removed unnecessary test and fixed initialize
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian Jorgensen committed May 2, 2024
1 parent c1a1b9e commit c25c850
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 18 deletions.
17 changes: 8 additions & 9 deletions src/skmatter/_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ class _FPS(GreedySelector):
Parameters
----------
initialize: int, list of int, ndarray of int, or 'random', default=0
initialize: int, list of int, numpy.ndarray of int, or 'random', default=0
Index of the first selection(s). If 'random', picks a random
value when fit starts. Stored in :py:attr:`self.initialize`.
Expand Down Expand Up @@ -1038,14 +1038,7 @@ def _init_greedy_search(self, X, y, n_to_select):
self.hausdorff_ = np.full(X.shape[self._axis], np.inf)
self.hausdorff_at_select_ = np.full(X.shape[self._axis], np.inf)

if isinstance(self.initialize, np.ndarray):
if all(isinstance(i, numbers.Integral) for i in self.initialize):
for i, val in enumerate(self.initialize):
self.selected_idx_[i] = val
self._update_post_selection(X, y, self.selected_idx_[i])
else:
raise ValueError("Initialize parameter must contain only int")
elif self.initialize == "random":
if self.initialize == "random":
random_state = check_random_state(self.random_state)
initialize = random_state.randint(X.shape[self._axis])
self.selected_idx_[0] = initialize
Expand All @@ -1060,6 +1053,12 @@ def _init_greedy_search(self, X, y, n_to_select):
for i, val in enumerate(self.initialize):
self.selected_idx_[i] = val
self._update_post_selection(X, y, self.selected_idx_[i])
elif isinstance(self.initialize, np.ndarray) and all(
isinstance(i, numbers.Integral) for i in self.initialize
):
for i, val in enumerate(self.initialize):
self.selected_idx_[i] = val
self._update_post_selection(X, y, self.selected_idx_[i])

else:
raise ValueError("Invalid value of the initialize parameter")
Expand Down
9 changes: 0 additions & 9 deletions tests/test_feature_simple_fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,6 @@ def test_initialize(self):
for i in range(4):
self.assertEqual(selector.selected_idx_[i], self.idx[i])

initialize = np.array([1, 5, 3, 0.25])
with self.subTest(initialize=initialize):
with self.assertRaises(ValueError) as cm:
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
selector.fit(self.X)
self.assertEqual(
str(cm.exception), "Initialize parameter must contain only int"
)

with self.assertRaises(ValueError) as cm:
selector = FPS(n_to_select=1, initialize="bad")
selector.fit(self.X)
Expand Down

0 comments on commit c25c850

Please sign in to comment.