diff --git a/cuppa/README.md b/cuppa/README.md index 6873880c27..fcfb1b5fac 100644 --- a/cuppa/README.md +++ b/cuppa/README.md @@ -153,7 +153,6 @@ Below are all arguments that can be passed to `CuppaDataPrep`. Superscript numbe | `-ref_genome_version` | V37 | Valid values: V37 (default), V38 | | `-threads` | 8 | Number of threads to use. Each thread processes one sample at a time | | `-write_by_category` | | Flag. Split output of `CuppaDataPrep` over multiple files | -| `-progress_interval` | 100 | Print progress per this number of samples in multi-sample mode | | `-log_level` | DEBUG | Set log level to one of: ERROR, WARN, INFO, DEBUG or TRACE | | `-log_debug` | | Flag. Set log level to DEBUG | @@ -267,6 +266,7 @@ The below table lists all possible arguments for training and/or predicting. | `--metadata_path` | Train | [Required] Path to the metadata file with cancer type labels per sample | | `--cv_predictions_path` | Predict | Path to a CuppaPrediction tsv file containing the cross-validation predictions. Samples found in
this file will have their predictions returned from this file instead of being computed | | `--compress_tsv_files` | Predict | Compress tsv files with gzip? (will add .gz to the file extension) | +| `--force_plot` | Predict | Force plotting when there are >10 samples in multi-sample mode | | `--excl_classes` | Train | Comma separated list of cancer subtypes to exclude from training. E.g. 'Breast' or 'Breast,Lung'. Default: '_Other,_Unknown' | | `--min_samples_with_rna` | Train | Minimum number of samples with RNA in each cancer subtype. If the cancer subtype has fewer samples with RNA than this value, the cancer subtype will be excluded from training. Default: 5 | | `--fusion_overrides_path` | Train | Path to the fusion overrides tsv file. See section [FusionProbOverrider](#fusionproboverrider) for how this file should be formatted | diff --git a/cuppa/src/main/python/pycuppa/cuppa/runners/args.py b/cuppa/src/main/python/pycuppa/cuppa/runners/args.py index 4969a807ff..d6c22a21e8 100644 --- a/cuppa/src/main/python/pycuppa/cuppa/runners/args.py +++ b/cuppa/src/main/python/pycuppa/cuppa/runners/args.py @@ -2,6 +2,7 @@ import argparse +import cuppa.visualization.visualization from cuppa.constants import DEFAULT_FUSION_OVERRIDES_PATH @@ -20,6 +21,7 @@ class DEFAULT_RUNNER_ARGS: cv_folds: int = 10 cache_training: bool = True n_jobs: int = 1 + force_plot: bool = False log_to_file: bool = False log_path: str | None = None @@ -94,6 +96,11 @@ class RunnerArgs: help="Path to the fusion overrides tsv file" ) + force_plot = dict( + action="store_true", + help="Force plotting when number of samples is >%s" % cuppa.visualization.visualization.CuppaVisPlotter.PLOT_MAX_SAMPLES + ) + ## Cross-validation / training ================================ cv_predictions_path = dict( help="Path to a CuppaPrediction tsv file containing the cross-validation predictions." @@ -166,6 +173,7 @@ def get_kwargs_predict(self) -> dict: "clf_group", "cv_predictions_path", "compress_tsv_files", + "force_plot", "log_to_file", "log_path", "log_format" diff --git a/cuppa/src/main/python/pycuppa/cuppa/runners/prediction_runner.py b/cuppa/src/main/python/pycuppa/cuppa/runners/prediction_runner.py index 0ea371dd9c..b495b70fc2 100644 --- a/cuppa/src/main/python/pycuppa/cuppa/runners/prediction_runner.py +++ b/cuppa/src/main/python/pycuppa/cuppa/runners/prediction_runner.py @@ -23,6 +23,7 @@ def __init__( output_dir: str, sample_id: str | None = None, compress_tsv_files: bool = False, + force_plot: bool = False, cv_predictions_path: str = None, cv_predictions: CuppaPrediction | None = None, clf_group: str = DEFAULT_RUNNER_ARGS.clf_group, @@ -35,6 +36,7 @@ def __init__( self.output_dir = output_dir self.sample_id = sample_id self.compress_tsv_files = compress_tsv_files + self.force_plot = force_plot self.classifier_path = classifier_path self.cv_predictions_path = cv_predictions_path @@ -180,4 +182,10 @@ def run(self) -> None: self.pred_summ.to_tsv(self.pred_summ_path, verbose=True) self.vis_data.to_tsv(self.vis_data_path, verbose=True) - CuppaVisPlotter.from_tsv(path=self.vis_data_path, plot_path=self.plot_path, verbose=True).plot() + plotter = CuppaVisPlotter.from_tsv( + path=self.vis_data_path, + plot_path=self.plot_path, + force_plot=self.force_plot, + verbose=True + ) + plotter.plot() diff --git a/cuppa/src/main/python/pycuppa/cuppa/visualization/visualization.py b/cuppa/src/main/python/pycuppa/cuppa/visualization/visualization.py index 947a6c91ba..ab1679060f 100644 --- a/cuppa/src/main/python/pycuppa/cuppa/visualization/visualization.py +++ b/cuppa/src/main/python/pycuppa/cuppa/visualization/visualization.py @@ -270,17 +270,16 @@ def __init__( self, vis_data: CuppaVisData, plot_path: str, + force_plot: bool = False, verbose: bool = True ): self.vis_data = vis_data self.plot_path = os.path.expanduser(plot_path) + self.force_plot = force_plot self.verbose = verbose self.vis_data_path: str | None = None - self._check_number_of_samples() - self._check_plot_path_extension() - @classmethod def from_tsv(cls, path: str, **kwargs) -> CuppaVisPlotter: ## This method exists to avoid writing a temporary vis data file if a vis data file already exists @@ -289,18 +288,36 @@ def from_tsv(cls, path: str, **kwargs) -> CuppaVisPlotter: plotter.vis_data_path = path return plotter - def _check_number_of_samples(self): - sample_ids = self.vis_data["sample_id"].dropna().unique() + @property + def sample_ids(self): + return self.vis_data["sample_id"].dropna().unique() + + PLOT_MAX_SAMPLES = 10 + + + def _check_should_plot(self) -> bool: + + n_samples = len(self.sample_ids) - max_samples = 25 - if len(sample_ids) > max_samples: - self.logger.error("Plotting predictions for >", max_samples, " is not supported") - raise RuntimeError + if n_samples <= self.PLOT_MAX_SAMPLES: + return True + + elif self.force_plot: + self.logger.info(f"Forcing plotting predictions for {n_samples} (>{self.PLOT_MAX_SAMPLES}) samples") + return True + + else: + self.logger.warning(f"Skipping plotting predictions for {n_samples} (>{self.PLOT_MAX_SAMPLES}) samples. Please use arg --force_plot to plot anyway") + return False def _check_plot_path_extension(self): if not self.plot_path.endswith((".pdf", ".png")): self.logger.error("`plot_path` must end with .pdf or .png") - raise ValueError + raise Exception + + if len(self.sample_ids) > 1 and not self.plot_path.endswith(".pdf"): + self.logger.error("`plot_path` must end with .pdf for multi-sample plotting") + raise Exception @property def _tmp_vis_data_path(self) -> str: @@ -321,6 +338,11 @@ def _remove_tmp_vis_data(self): def plot(self) -> None: try: + self._check_plot_path_extension() + + if not self._check_should_plot(): + return + if self.vis_data_path is None or not os.path.exists(self.vis_data_path): self._write_tmp_vis_data() self.vis_data_path = self._tmp_vis_data_path