Skip to content

Commit 17987c8

Browse files
gcattan, pre-commit-ci[bot], and qbarthelemy
authored
Add "full" strategy to NCH (#353)
* Update classification.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update classification.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update test_classification.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
1 parent 93ef8d3 commit 17987c8

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

pyriemann_qiskit/classification.py

+15 -7
Original file line number | Diff line number | Diff line change
@@ -867,11 +867,14 @@ class NearestConvexHull(BaseEstimator, ClassifierMixin, TransformerMixin):
867867
The number of hulls used per class, when subsampling is "random".
868868
n_samples_per_hull : int, default=15
869869
Defines how many samples are used to build a hull. -1 will include
870-
all samples per class.
871-
subsampling : {"min", "random"}, default="min"
872-
Subsampling strategy of training set to estimate distance to hulls.
873-
"min" estimates hull using the n_samples_per_hull closest matrices.
874-
"random" estimates hull using n_samples_per_hull random matrices.
870+
all samples per class. If subsampling is "full", this
871+
parameter is defaulted to -1.
872+
subsampling : {"min", "random", "full"}, default="min"
873+
Subsampling strategy of training set to estimate distance to hulls:
874+
875+
- "min" estimates hull using the n_samples_per_hull closest matrices;
876+
- "random" estimates hull using n_samples_per_hull random matrices;
877+
- "full" estimates the hull using the entire training matrices, as in [1]_.
875878
seed : float, default=None
876879
Optional random seed to use when subsampling is set to `random`.
877880
@@ -900,9 +903,14 @@ def __init__(
900903
self.subsampling = subsampling
901904
self.seed = seed
902905

903-
if subsampling not in ["min", "random"]:
906+
if subsampling not in ["min", "random", "full"]:
904907
raise ValueError(f"Unknown subsampling type {subsampling}.")
905908

909+
if subsampling == "full":
910+
# From code perspective, "full" strategy is the same as min strategy
911+
# without sorting
912+
self.n_samples_per_hull = -1
913+
906914
def fit(self, X, y):
907915
"""Fit (store the training data).
908916
@@ -996,7 +1004,7 @@ def _predict_distances(self, X):
9961004
if self.debug:
9971005
print("Total test samples:", X.shape[0])
9981006

999-
if self.subsampling == "min":
1007+
if self.subsampling == "min" or self.subsampling == "full":
10001008
self._process_sample = self._process_sample_min_hull
10011009
elif self.subsampling == "random":
10021010
self._process_sample = self._process_sample_random_hull

tests/test_classification.py

+21
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,27 @@ def test_qsvm_init(quantum):
6161
assert not hasattr(q, "_provider")
6262

6363

64+
def test_nch_init_full():
65+
"""Test init of NCH classifier with `full` subsampling"""
66+
67+
q = QuanticNCH(subsampling="full").fit(X=np.array([[0], [1]]), y=np.array([0, 1]))
68+
clf = q._classifier
69+
70+
assert type(clf).__name__ == "NearestConvexHull"
71+
72+
# Check "full" subsampling takes all training samples.
73+
assert clf.n_samples_per_hull == -1
74+
75+
# Check it calls the "min" strategy on prediction.
76+
def mocked_process_sample_min_hull(*args):
77+
raise ValueError("Min strategy called")
78+
79+
clf._process_sample_min_hull = mocked_process_sample_min_hull
80+
81+
with pytest.raises(ValueError):
82+
clf._predict_distances(X=[0])
83+
84+
6485
class TestQSVMSplitClasses(BinaryTest):
6586
"""Test _split_classes method of quantum classifiers"""
6687

0 commit comments

Comments
 (0)