diff --git a/muon/_core/preproc.py b/muon/_core/preproc.py index 3ab4244..30d5489 100644 --- a/muon/_core/preproc.py +++ b/muon/_core/preproc.py @@ -471,7 +471,7 @@ def neighbors( weights = softmax(ratios, axis=1) neighbordistances = csr_matrix((mdata.n_obs, mdata.n_obs), dtype=np.float64) - largeidx = mdata.n_obs ** 2 > np.iinfo(np.int32).max + largeidx = mdata.n_obs**2 > np.iinfo(np.int32).max if largeidx: # work around scipy bug https://github.com/scipy/scipy/issues/13155 neighbordistances.indptr = neighbordistances.indptr.astype(np.int64) neighbordistances.indices = neighbordistances.indices.astype(np.int64) diff --git a/muon/_core/tools.py b/muon/_core/tools.py index bc868f3..e2c92aa 100644 --- a/muon/_core/tools.py +++ b/muon/_core/tools.py @@ -302,8 +302,13 @@ def mofa( spikeslab_factors: bool = False, n_iterations: int = 1000, convergence_mode: str = "fast", - gpu_mode: bool = False, use_float32: bool = False, + gpu_mode: bool = False, + svi_mode: bool = False, + svi_batch_size: float = 0.5, + svi_learning_rate: float = 1.0, + svi_forgetting_rate: float = 0.5, + svi_start_stochastic: int = 1, smooth_covariate: Optional[str] = None, smooth_warping: bool = False, smooth_kwargs: Optional[Mapping[str, Any]] = None, @@ -361,6 +366,16 @@ def mofa( use reduced precision (float32) gpu_mode : optional if to use GPU mode + svi_mode : optional + if to use Stochastic Variational Inference (SVI) + svi_batch_size : optional + batch size as a fraction (only applicable when svi_mode=True, 0.5 by default) + svi_learning_rate : optional + learning rate (only applicable when svi_mode=True, 1.0 by default) + svi_forgetting_rate : optional + forgetting_rate (only applicable when svi_mode=True, 0.5 by default) + svi_start_stochastic : optional + first iteration to start SVI (only applicable when svi_mode=True, 1 by default) smooth_covariate : optional use a covariate (column in .obs) to learn smooth factors (MEFISTO) smooth_warping : optional @@ -481,6 +496,15 @@ def mofa( save_interrupted=save_interrupted, ) + if svi_mode: + logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Setting up SVI...") + ent.set_stochastic_options( + learning_rate=svi_learning_rate, + forgetting_rate=svi_forgetting_rate, + batch_size=svi_batch_size, + start_stochastic=svi_start_stochastic, + ) + # MEFISTO options smooth_kwargs_default = dict( diff --git a/tests/test_muon_tools.py b/tests/test_muon_tools.py index c422bec..9f01d3b 100644 --- a/tests/test_muon_tools.py +++ b/tests/test_muon_tools.py @@ -36,7 +36,7 @@ def test_mofa_nfactors(self): r2 = [] for i in range(n_factors): yhat = np.dot(self.mdata.obsm["X_mofa"][:, [i]], self.mdata.varm["LFs"][:, [i]].T) - r2.append(1 - np.sum((y - yhat) ** 2) / np.sum(y ** 2)) + r2.append(1 - np.sum((y - yhat) ** 2) / np.sum(y**2)) # Only first 5 factors should have high R2 self.assertTrue(all([i > 0.1 for i in r2[:5]]))