Skip to content

Commit

Permalink
Cuppa: Added arg --force_plot to allow cuppa vis plotting when there …
Browse files Browse the repository at this point in the history
…are too many samples
  • Loading branch information
luan-n-nguyen committed Dec 13, 2024
1 parent 1c88a06 commit a5e7e98
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cuppa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ Below are all arguments that can be passed to `CuppaDataPrep`. Superscript numbe
| `-ref_genome_version` | V37 | Valid values: V37 (default), V38 |
| `-threads` | 8 | Number of threads to use. Each thread processes one sample at a time |
| `-write_by_category` | | Flag. Split output of `CuppaDataPrep` over multiple files |
| `-progress_interval` | 100 | Print progress per this number of samples in multi-sample mode |
| `-log_level` | DEBUG | Set log level to one of: ERROR, WARN, INFO, DEBUG or TRACE |
| `-log_debug` | | Flag. Set log level to DEBUG |

Expand Down Expand Up @@ -267,6 +266,7 @@ The below table lists all possible arguments for training and/or predicting.
| `--metadata_path` | Train | [Required] Path to the metadata file with cancer type labels per sample |
| `--cv_predictions_path` | Predict | Path to a CuppaPrediction tsv file containing the cross-validation predictions. Samples found in <br/>this file will have their predictions returned from this file instead of being computed |
| `--compress_tsv_files` | Predict | Compress tsv files with gzip? (will add .gz to the file extension) |
| `--force_plot` | Predict | Force plotting when there are >10 samples in multi-sample mode |
| `--excl_classes` | Train | Comma separated list of cancer subtypes to exclude from training. E.g. 'Breast' or 'Breast,Lung'. Default: '_Other,_Unknown' |
| `--min_samples_with_rna` | Train | Minimum number of samples with RNA in each cancer subtype. If the cancer subtype has fewer samples with RNA than this value, the cancer subtype will be excluded from training. Default: 5 |
| `--fusion_overrides_path` | Train | Path to the fusion overrides tsv file. See section [FusionProbOverrider](#fusionproboverrider) for how this file should be formatted |
Expand Down
8 changes: 8 additions & 0 deletions cuppa/src/main/python/pycuppa/cuppa/runners/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse

import cuppa.visualization.visualization
from cuppa.constants import DEFAULT_FUSION_OVERRIDES_PATH


Expand All @@ -20,6 +21,7 @@ class DEFAULT_RUNNER_ARGS:
cv_folds: int = 10
cache_training: bool = True
n_jobs: int = 1
force_plot: bool = False

log_to_file: bool = False
log_path: str | None = None
Expand Down Expand Up @@ -94,6 +96,11 @@ class RunnerArgs:
help="Path to the fusion overrides tsv file"
)

## Boolean flag (argparse `store_true`, so it takes no value on the command line).
## The help text reports the sample threshold read from CuppaVisPlotter.PLOT_MAX_SAMPLES.
force_plot = dict(
action="store_true",
help="Force plotting when number of samples is >%s" % cuppa.visualization.visualization.CuppaVisPlotter.PLOT_MAX_SAMPLES
)

## Cross-validation / training ================================
cv_predictions_path = dict(
help="Path to a CuppaPrediction tsv file containing the cross-validation predictions."
Expand Down Expand Up @@ -166,6 +173,7 @@ def get_kwargs_predict(self) -> dict:
"clf_group",
"cv_predictions_path",
"compress_tsv_files",
"force_plot",
"log_to_file",
"log_path",
"log_format"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
output_dir: str,
sample_id: str | None = None,
compress_tsv_files: bool = False,
force_plot: bool = False,
cv_predictions_path: str = None,
cv_predictions: CuppaPrediction | None = None,
clf_group: str = DEFAULT_RUNNER_ARGS.clf_group,
Expand All @@ -35,6 +36,7 @@ def __init__(
self.output_dir = output_dir
self.sample_id = sample_id
self.compress_tsv_files = compress_tsv_files
self.force_plot = force_plot
self.classifier_path = classifier_path

self.cv_predictions_path = cv_predictions_path
Expand Down Expand Up @@ -180,4 +182,10 @@ def run(self) -> None:
self.pred_summ.to_tsv(self.pred_summ_path, verbose=True)
self.vis_data.to_tsv(self.vis_data_path, verbose=True)

CuppaVisPlotter.from_tsv(path=self.vis_data_path, plot_path=self.plot_path, verbose=True).plot()
plotter = CuppaVisPlotter.from_tsv(
path=self.vis_data_path,
plot_path=self.plot_path,
force_plot=self.force_plot,
verbose=True
)
plotter.plot()
42 changes: 32 additions & 10 deletions cuppa/src/main/python/pycuppa/cuppa/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,16 @@ def __init__(
self,
vis_data: CuppaVisData,
plot_path: str,
force_plot: bool = False,
verbose: bool = True
):
self.vis_data = vis_data
self.plot_path = os.path.expanduser(plot_path)
self.force_plot = force_plot
self.verbose = verbose

self.vis_data_path: str | None = None

self._check_number_of_samples()
self._check_plot_path_extension()

@classmethod
def from_tsv(cls, path: str, **kwargs) -> CuppaVisPlotter:
## This method exists to avoid writing a temporary vis data file if a vis data file already exists
Expand All @@ -289,18 +288,36 @@ def from_tsv(cls, path: str, **kwargs) -> CuppaVisPlotter:
plotter.vis_data_path = path
return plotter

def _check_number_of_samples(self):
sample_ids = self.vis_data["sample_id"].dropna().unique()
@property
def sample_ids(self):
    """Return the unique, non-missing values of the `sample_id` column of the vis data."""
    sample_id_column = self.vis_data["sample_id"]
    non_missing_ids = sample_id_column.dropna()
    return non_missing_ids.unique()

PLOT_MAX_SAMPLES = 10


def _check_should_plot(self) -> bool:
    """
    Decide whether plotting should go ahead for the current vis data.

    Plotting proceeds when the number of unique sample ids is at most
    `PLOT_MAX_SAMPLES`, or when `force_plot` is set. Otherwise a warning is
    logged and plotting is skipped.

    Returns
    -------
    bool
        True if plotting should proceed, False if it should be skipped.
    """
    ## Too many samples is no longer an error: the caller skips plotting (or forces it),
    ## so this method never raises on the sample count.
    n_samples = len(self.sample_ids)

    if n_samples <= self.PLOT_MAX_SAMPLES:
        return True

    if self.force_plot:
        self.logger.info(f"Forcing plotting predictions for {n_samples} (>{self.PLOT_MAX_SAMPLES}) samples")
        return True

    self.logger.warning(f"Skipping plotting predictions for {n_samples} (>{self.PLOT_MAX_SAMPLES}) samples. Please use arg --force_plot to plot anyway")
    return False

def _check_plot_path_extension(self) -> None:
    """
    Validate the file extension of `plot_path`.

    Raises
    ------
    ValueError
        If `plot_path` does not end with .pdf or .png, or if there is more than
        one sample id and `plot_path` does not end with .pdf.
    """
    ## ValueError is a subclass of Exception, so callers catching either type keep working.
    ## The message is attached to the exception as well as logged.
    if not self.plot_path.endswith((".pdf", ".png")):
        message = "`plot_path` must end with .pdf or .png"
        self.logger.error(message)
        raise ValueError(message)

    if len(self.sample_ids) > 1 and not self.plot_path.endswith(".pdf"):
        message = "`plot_path` must end with .pdf for multi-sample plotting"
        self.logger.error(message)
        raise ValueError(message)

@property
def _tmp_vis_data_path(self) -> str:
Expand All @@ -321,6 +338,11 @@ def _remove_tmp_vis_data(self):
def plot(self) -> None:

try:
self._check_plot_path_extension()

if not self._check_should_plot():
return

if self.vis_data_path is None or not os.path.exists(self.vis_data_path):
self._write_tmp_vis_data()
self.vis_data_path = self._tmp_vis_data_path
Expand Down

0 comments on commit a5e7e98

Please sign in to comment.