Skip to content

Commit

Permalink
Backport PR scverse#2605 on branch 1.2.x ([feat] add support for datamodules in autotune) (scverse#2606)
Browse files Browse the repository at this point in the history

Backport PR scverse#2605: [feat] add support for datamodules in autotune

Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and martinkim0 authored Mar 7, 2024
1 parent 994363a commit c3d719e
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/release_notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
{class}`scvi.module.base.EmbeddingModuleMixin` {pr}`2576`.
- Add option to generate synthetic spatial coordinates in {func}`scvi.data.synthetic_iid` with
argument `generate_coordinates` {pr}`2603`.
- Add experimental support for using custom {class}`lightning.pytorch.core.LightningDataModule`s
in {func}`scvi.autotune.run_autotune` {pr}`2605`.

#### Changed

Expand Down
58 changes: 35 additions & 23 deletions scvi/autotune/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from os.path import join
from typing import Any, Literal

from anndata import AnnData
from lightning.pytorch import LightningDataModule
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from mudata import MuData
from ray.tune import ResultGrid, Tuner
from ray.tune.schedulers import TrialScheduler
from ray.tune.search import SearchAlgorithm
Expand All @@ -27,9 +30,9 @@ class AutotuneExperiment:
model_cls
Model class on which to tune hyperparameters. Must implement a constructor and a ``train``
method.
adata
data
:class:`~anndata.AnnData` or :class:`~mudata.MuData` that has been setup with
``model_cls``.
``model_cls`` or a :class:`~lightning.pytorch.core.LightningDataModule` (``EXPERIMENTAL``).
metrics
Either a single metric or a list of metrics to track during the experiment. If a list is
provided, the primary metric will be the first element in the list.
Expand Down Expand Up @@ -95,7 +98,7 @@ class AutotuneExperiment:
def __init__(
self,
model_cls: BaseModelClass,
adata: AnnOrMuData,
data: AnnOrMuData | LightningDataModule,
metrics: str | list[str],
mode: Literal["min", "max"],
search_space: dict[str, dict[Literal["model_args", "train_args"], dict[str, Any]]],
Expand All @@ -110,7 +113,7 @@ def __init__(
searcher_kwargs: dict | None = None,
) -> None:
self.model_cls = model_cls
self.adata = adata
self.data = data
self.metrics = metrics
self.mode = mode
self.search_space = search_space
Expand Down Expand Up @@ -145,34 +148,39 @@ def model_cls(self, value: BaseModelClass) -> None:
self._model_cls = value

@property
def adata(self) -> AnnOrMuData:
""":class:`~anndata.AnnData` or :class:`~mudata.MuData` for the experiment."""
return self._adata
def data(self) -> AnnOrMuData | LightningDataModule:
"""Data on which to tune hyperparameters."""
return self._data

@adata.setter
def adata(self, value: AnnOrMuData) -> None:
@data.setter
def data(self, value: AnnOrMuData | LightningDataModule) -> None:
from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME

if hasattr(self, "_adata"):
raise AttributeError("Cannot reassign `adata`")
if hasattr(self, "_data"):
raise AttributeError("Cannot reassign `data`")

data_manager = self.model_cls._get_most_recent_anndata_manager(value, required=True)
self._adata = value
self._setup_method_name = data_manager._registry.get(_SETUP_METHOD_NAME, "setup_anndata")
self._setup_method_args = data_manager._get_setup_method_args().get(_SETUP_ARGS_KEY, {})
self._data = value
if isinstance(value, (AnnData, MuData)):
data_manager = self.model_cls._get_most_recent_anndata_manager(value, required=True)
self._setup_method_name = data_manager._registry.get(
_SETUP_METHOD_NAME, "setup_anndata"
)
self._setup_method_args = data_manager._get_setup_method_args().get(
_SETUP_ARGS_KEY, {}
)

@property
def setup_method_name(self) -> str:
"""Either ``"setup_anndata"`` or ``"setup_mudata"``."""
if not hasattr(self, "_setup_method_name"):
raise AttributeError("`setup_method_name` not yet available.")
raise AttributeError("`setup_method_name` not available.")
return self._setup_method_name

@property
def setup_method_args(self) -> dict[str, Any]:
"""Keyword arguments for the setup method."""
if not hasattr(self, "_setup_method_args"):
raise AttributeError("`setup_method_args` not yet available.")
raise AttributeError("`setup_method_args` not available.")
return self._setup_method_args

@property
Expand Down Expand Up @@ -523,9 +531,13 @@ def _trainable(
}

settings.seed = experiment.seed
getattr(experiment.model_cls, experiment.setup_method_name)(
experiment.adata,
**experiment.setup_method_args,
)
model = experiment.model_cls(experiment.adata, **model_args)
model.train(**train_args)
if isinstance(experiment.data, (AnnData, MuData)):
getattr(experiment.model_cls, experiment.setup_method_name)(
experiment.data,
**experiment.setup_method_args,
)
model = experiment.model_cls(experiment.data, **model_args)
model.train(**train_args)
else:
model = experiment.model_cls(**model_args)
model.train(data_module=experiment.data, **train_args)
10 changes: 6 additions & 4 deletions scvi/autotune/_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from typing import Any, Literal

from lightning.pytorch import LightningDataModule

from scvi._types import AnnOrMuData
from scvi.autotune._experiment import AutotuneExperiment
from scvi.model.base import BaseModelClass
Expand All @@ -12,7 +14,7 @@

def run_autotune(
model_cls: BaseModelClass,
adata: AnnOrMuData,
data: AnnOrMuData | LightningDataModule,
metrics: str | list[str],
mode: Literal["min", "max"],
search_space: dict[str, dict[Literal["model_args", "train_args"], dict[str, Any]]],
Expand All @@ -32,9 +34,9 @@ def run_autotune(
----------
model_cls
Model class on which to tune hyperparameters.
adata
data
:class:`~anndata.AnnData` or :class:`~mudata.MuData` that has been setup with
`model_cls``.
``model_cls`` or a :class:`~lightning.pytorch.core.LightningDataModule` (``EXPERIMENTAL``).
metrics
Either a single metric or a list of metrics to track during the experiment. If a list is
provided, the primary metric will be the first element in the list.
Expand Down Expand Up @@ -107,7 +109,7 @@ def run_autotune(

experiment = AutotuneExperiment(
model_cls,
adata,
data,
metrics,
mode,
search_space,
Expand Down
6 changes: 3 additions & 3 deletions scvi/train/_trainrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def __call__(self):
self._update_history()

# data splitter only gets these attrs after fit
self.model.train_indices = self.data_splitter.train_idx
self.model.test_indices = self.data_splitter.test_idx
self.model.validation_indices = self.data_splitter.val_idx
self.model.train_indices = getattr(self.data_splitter, "train_idx", None)
self.model.test_indices = getattr(self.data_splitter, "test_idx", None)
self.model.validation_indices = getattr(self.data_splitter, "val_idx", None)

self.model.module.eval()
self.model.is_trained_ = True
Expand Down
8 changes: 4 additions & 4 deletions tests/autotune/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def test_experiment_init(save_path: str):
with raises(AttributeError):
experiment.id = "new_id"

assert hasattr(experiment, "adata")
assert experiment.adata is not None
assert experiment.adata is adata
assert hasattr(experiment, "data")
assert experiment.data is not None
assert experiment.data is adata
with raises(AttributeError):
experiment.adata = "new_adata"
experiment.data = "new_adata"

assert hasattr(experiment, "setup_method_name")
assert experiment.setup_method_name is not None
Expand Down
35 changes: 35 additions & 0 deletions tests/autotune/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,38 @@ def test_run_autotune_scvi_basic(save_path: str):
assert isinstance(experiment, AutotuneExperiment)
assert hasattr(experiment, "result_grid")
assert isinstance(experiment.result_grid, ResultGrid)


def test_run_autotune_scvi_no_anndata(save_path: str, n_batches: int = 3):
    """Run autotune with a ``LightningDataModule`` instead of an ``AnnData``."""
    from scvi.dataloaders import DataSplitter

    settings.logging_dir = save_path

    # Register a synthetic AnnData only to borrow its manager for the splitter.
    anndata = synthetic_iid(n_batches=n_batches)
    SCVI.setup_anndata(anndata, batch_key="batch")
    adata_manager = SCVI._get_most_recent_anndata_manager(anndata)

    # The datamodule path cannot read these from an AnnData registry, so they
    # must be attached to the module directly.
    datamodule = DataSplitter(adata_manager)
    datamodule.n_vars = anndata.n_vars
    datamodule.n_batch = n_batches

    search_space = {
        "model_args": {
            "n_hidden": tune.choice([1, 2]),
        },
        "train_args": {
            "max_epochs": 1,
        },
    }
    exp = run_autotune(
        SCVI,
        datamodule,
        metrics=["elbo_validation"],
        mode="min",
        search_space=search_space,
        num_samples=1,
        seed=0,
        scheduler="asha",
        searcher="hyperopt",
    )

    assert isinstance(exp, AutotuneExperiment)
    assert hasattr(exp, "result_grid")
    assert isinstance(exp.result_grid, ResultGrid)

0 comments on commit c3d719e

Please sign in to comment.