Skip to content

Commit

Permalink
Merge pull request #6 from TarikExner/SOMEstimator_seed
Browse files Browse the repository at this point in the history
Fix SOMEstimator seed, contributed by @TarikExner
  • Loading branch information
berombau authored Apr 16, 2024
2 parents 5b96955 + 343f228 commit 58ed09a
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 18 deletions.
31 changes: 14 additions & 17 deletions src/FlowSOM/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.base import check_is_fitted

from flowsom.io import read_csv, read_FCS
from flowsom.models.base_flowsom_estimator import BaseFlowSOMEstimator
from flowsom.models.flowsom_estimator import FlowSOMEstimator
from flowsom.tl import get_channels, get_markers

Expand All @@ -24,37 +25,31 @@ class FlowSOM:
def __init__(
self,
inp,
n_clusters,
cols_to_use=None,
model=FlowSOMEstimator,
xdim=10,
ydim=10,
rlen=10,
mst=1,
alpha=(0.05, 0.01),
n_clusters: int,
cols_to_use: np.ndarray | None = None,
model: type[BaseFlowSOMEstimator] = FlowSOMEstimator,
xdim: int = 10,
ydim: int = 10,
rlen: int = 10,
mst: int = 1,
alpha: tuple[float, float] = (0.05, 0.01),
seed: int | None = None,
mad_allowed=4,
**kwargs,
):
"""Initialize the FlowSOM AnnData object.
:param inp: An AnnData or filepath to an FCS file
:param n_clusters: The number of clusters
:type n_clusters: int
:param xdim: The x dimension of the SOM
:type xdim: int
:param ydim: The y dimension of the SOM
:type ydim: int
:param rlen: Number of times to loop over the training data for each MST
:type rlen: int
:param mst: Number of times to loop over the training data for each MST
:type mst: int
:param alpha: The learning rate
:type alpha: tuple
:param seed: The random seed to use
:param cols_to_use: The columns to use for clustering
:type cols_to_use: np.array
:param mad_allowed: Number of median absolute deviations allowed
:type mad_allowed: int
:param model: The model to use
:type model: FlowSOMEstimator
:param kwargs: Additional keyword arguments. See documentation of the cluster_model and metacluster_model for more information.
:type kwargs: dict
"""
Expand All @@ -66,6 +61,7 @@ def __init__(
self.rlen = rlen
self.mst = mst
self.alpha = alpha
self.seed = seed
# metacluster model params
self.n_clusters = n_clusters

Expand All @@ -75,6 +71,7 @@ def __init__(
rlen=rlen,
mst=mst,
alpha=alpha,
seed=seed,
n_clusters=n_clusters,
**kwargs,
)
Expand Down
4 changes: 4 additions & 0 deletions src/FlowSOM/models/som_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def fit(
# Initialize the grid
grid = [(x, y) for x in range(xdim) for y in range(ydim)]
n_codes = len(grid)

if self.seed is not None:
np.random.seed(self.seed)

if codes is None:
if init:
codes = self.initf(X, xdim, ydim)
Expand Down
18 changes: 18 additions & 0 deletions tests/models/test_FlowSOMModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,21 @@ def test_clustering_v_measure(X_and_y):
y_pred = som.fit_predict(X)
score = v_measure_score(y_true, y_pred)
assert score > 0.7


def test_reproducibility_no_seed(X):
fsom_1 = FlowSOMEstimator(n_clusters=10)
fsom_2 = FlowSOMEstimator(n_clusters=10)
y_pred_1 = fsom_1.fit_predict(X)
y_pred_2 = fsom_2.fit_predict(X)

assert not all(y_pred_1 == y_pred_2)


def test_reproducibility_seed(X):
fsom_1 = FlowSOMEstimator(n_clusters=10, seed=0)
fsom_2 = FlowSOMEstimator(n_clusters=10, seed=0)
y_pred_1 = fsom_1.fit_predict(X)
y_pred_2 = fsom_2.fit_predict(X)

assert all(y_pred_1 == y_pred_2)
20 changes: 19 additions & 1 deletion tests/models/test_SOMModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,26 @@ def test_clustering(X):


def test_clustering_v_measure(X_and_y):
som = SOMEstimator()
som = SOMEstimator(seed=1)
X, y_true = X_and_y
y_pred = som.fit_predict(X)
score = v_measure_score(y_true, y_pred)
assert score > 0.7


def test_reproducibility_no_seed(X):
som_1 = SOMEstimator(seed=None)
som_2 = SOMEstimator(seed=None)
codes_1 = som_1.fit(X).codes.flatten()
codes_2 = som_2.fit(X).codes.flatten()

assert not all(codes_1 == codes_2)


def test_reproducibility_seed(X):
som_1 = SOMEstimator(seed=1)
som_2 = SOMEstimator(seed=1)
codes_1 = som_1.fit(X).codes.flatten()
codes_2 = som_2.fit(X).codes.flatten()

assert all(codes_1 == codes_2)

0 comments on commit 58ed09a

Please sign in to comment.