Skip to content

Commit

Permalink
Add command-line options to address running out of GPU memory in posterior generation (#103)
Browse files Browse the repository at this point in the history

* Remove old cudatoolkit version, because of incompatibilities

* Allow configuration of batch size for posterior generation.

* Add option to change number of cells used to estimate
posterior regularization lambda.
  • Loading branch information
alecw authored Jun 14, 2021
1 parent 4d89463 commit 2507742
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 13 deletions.
10 changes: 10 additions & 0 deletions cellbender/remove_background/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,14 @@ def add_subparser_args(subparsers: argparse) -> argparse:
help="Learning rate is multiplied by this amount each time "
"(default: %(default)s)")

subparser.add_argument("--posterior-batch-size", type=int, default=consts.PROP_POSTERIOR_BATCH_SIZE,
dest="posterior_batch_size",
help="Size of batches when creating the posterior. Reduce this to avoid "
"running out of GPU memory creating the posterior (will be slower). "
"(default: %(default)s)")
subparser.add_argument("--cells-posterior-reg-calc", type=int, default=consts.CELLS_POSTERIOR_REG_CALC,
dest="cells_posterior_reg_calc",
help="Number of cells used to estimate posterior regularization lambda. "
"(default: %(default)s)")

return subparsers
2 changes: 2 additions & 0 deletions cellbender/remove_background/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def run_remove_background(args):
try:
dataset_obj.save_to_output_file(args.output_file,
inferred_model,
posterior_batch_size=args.posterior_batch_size,
cells_posterior_reg_calc=args.cells_posterior_reg_calc,
save_plots=True)

logging.info("Completed remove-background.")
Expand Down
7 changes: 5 additions & 2 deletions cellbender/remove_background/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
REG_SCALE_EMPTY_PROB = 1.0
REG_SCALE_CELL_PROB = 10.0

# Number of cells used to esitmate posterior regularization lambda. Memory hungry.
# Number of cells used to estimate posterior regularization lambda. Memory hungry.
CELLS_POSTERIOR_REG_CALC = 100

# Posterior regularization constant's upper and lower bounds.
Expand All @@ -80,4 +80,7 @@

# Minimum number of barcodes we expect in an unfiltered `h5ad` input file.
# Throws a warning if the input has fewer than this number.
MINIMUM_BARCODES_H5AD = 1e5
MINIMUM_BARCODES_H5AD = 1e5

# Batch size used when creating the posterior. Reduce this if running out of GPU memory
# (smaller batches are slower); see https://github.com/broadinstitute/CellBender/issues/67
PROP_POSTERIOR_BATCH_SIZE = 20
6 changes: 5 additions & 1 deletion cellbender/remove_background/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,8 @@ def save_to_output_file(
self,
output_file: str,
inferred_model: 'RemoveBackgroundPyroModel',
posterior_batch_size: int,
cells_posterior_reg_calc: int,
save_plots: bool = False) -> bool:
"""Write the results of an inference procedure to an output file.
Expand Down Expand Up @@ -514,7 +516,9 @@ def save_to_output_file(
# Create posterior.
self.posterior = ProbPosterior(dataset_obj=self,
vi_model=inferred_model,
fpr=self.fpr[0])
fpr=self.fpr[0],
batch_size=posterior_batch_size,
cells_posterior_reg_calc=cells_posterior_reg_calc)

# Encoded values of latent variables.
enc = self.posterior.latents
Expand Down
22 changes: 13 additions & 9 deletions cellbender/remove_background/infer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
# Posterior inference.

import logging
from abc import ABC, abstractmethod
from typing import Tuple, List, Dict, Optional

import numpy as np
import pyro
import pyro.distributions as dist
import torch
import numpy as np
import scipy.sparse as sp
import torch

import cellbender.remove_background.consts as consts

from typing import Tuple, List, Dict, Optional
from abc import ABC, abstractmethod
import logging


class Posterior(ABC):
"""Base class Posterior handles posterior count inference.
Expand Down Expand Up @@ -321,13 +321,17 @@ def __init__(self,
dataset_obj: 'SingleCellRNACountsDataset',
vi_model: 'RemoveBackgroundPyroModel',
fpr: float = 0.01,
float_threshold: float = 0.5):
float_threshold: float = 0.5,
batch_size: int = consts.PROP_POSTERIOR_BATCH_SIZE,
cells_posterior_reg_calc: int = consts.CELLS_POSTERIOR_REG_CALC):
self.vi_model = vi_model
self.use_cuda = vi_model.use_cuda
self.fpr = fpr
self.lambda_multiplier = None
self._encodings = None
self._mean = None
self.batch_size = batch_size
self.cells_posterior_reg_calc = cells_posterior_reg_calc
self.random = np.random.RandomState(seed=1234)
super(ProbPosterior, self).__init__(dataset_obj=dataset_obj,
vi_model=vi_model,
Expand All @@ -348,7 +352,7 @@ def _get_mean(self):

for i in range(lambda_mults.size):

n_cells = min(consts.CELLS_POSTERIOR_REG_CALC, cell_inds.size)
n_cells = min(self.cells_posterior_reg_calc, cell_inds.size)
if n_cells == 0:
raise ValueError('No cells found! Cannot compute expected FPR.')
cell_ind_subset = self.random.choice(cell_inds, size=n_cells, replace=False)
Expand Down Expand Up @@ -376,7 +380,7 @@ def _get_mean(self):
analyzed_bcs_only = True
data_loader = self.dataset_obj.get_dataloader(use_cuda=self.use_cuda,
analyzed_bcs_only=analyzed_bcs_only,
batch_size=20,
batch_size=self.batch_size,
shuffle=False)
barcodes = []
genes = []
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ENV PATH=/home/user/miniconda/bin:$PATH
ENV CONDA_AUTO_UPDATE_CONDA=false

RUN conda install -y pytorch torchvision cudatoolkit=9.2 -c pytorch \
RUN conda install -y pytorch torchvision cudatoolkit -c pytorch \
&& conda install -y -c anaconda pytables \
&& conda clean -ya

Expand Down

0 comments on commit 2507742

Please sign in to comment.