Skip to content

Commit

Permalink
Cuppa: Fix bug where RNA classifier probs were shown when RNA data is…
Browse files Browse the repository at this point in the history
… missing

This was due to MissingFeaturesHandler filling in NAs for samples with no RNA data
  • Loading branch information
luan-n-nguyen committed Dec 16, 2024
1 parent 61e9f0d commit 069deaa
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sklearn.compose import make_column_selector

from cuppa.components.passthrough import PassthroughTransformer
from cuppa.components.preprocessing import NaRowFilter
from cuppa.constants import SUB_CLF_NAMES, SIG_QUANTILE_TRANSFORMER_NAME, PREDICT_NA_FILL_VALUE
from cuppa.logger import LoggerMixin

Expand Down Expand Up @@ -110,13 +109,13 @@ def check_features(self) -> None:
n_missing,
self._get_feat_type_counts_string(missing)
))
self.logger.error("Please use " + self.__class__.__name__ + ".fill_missing_cols() to ensure `X` has the required columns")
self.logger.error(f"Please use {self.__class__.__name__}.{self.fill_missing.__name__}() to ensure `X` has the required columns")
raise LookupError

def fill_missing(self) -> pd.DataFrame:

## DNA --------------------------------
X_dna = self.X.reindex(columns=self.required_dna_features, fill_value=self.fill_value)
X_dna = self.X.reindex(columns=self.required_dna_features).fillna(self.fill_value)

## RNA --------------------------------
pattern_rna_features = f"^{SUB_CLF_NAMES.GENE_EXP}|{SUB_CLF_NAMES.ALT_SJ}"
Expand All @@ -130,12 +129,15 @@ def fill_missing(self) -> pd.DataFrame:
columns=self.required_rna_features
)
else:
## Samples without RNA data either have no RNA columns
## We don't want to fill these rows with 0 because this would produce an unwanted probability
is_missing_rna_data = NaRowFilter.detect_na_rows(X_rna, use_first_col=True)
## Samples with no RNA data -> Rows with all NA. We don't want to fill these rows with e.g. 0 because this
## would produce an unwanted probability. We therefore first remove these samples/rows altogether.
has_rna_data = ~np.isnan(X_rna).all(axis=1)
X_rna = X_rna.loc[has_rna_data]

## ... then we can fill NAs only for the samples that have RNA data
X_rna = X_rna \
.loc[~is_missing_rna_data] \
.reindex(columns=self.required_rna_features, fill_value=self.fill_value)
.reindex(columns=self.required_rna_features) \
.fillna(self.fill_value)

X_new = pd.concat([X_dna, X_rna], axis=1)
del X_dna, X_rna
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from cuppa.classifier.cuppa_classifier import CuppaClassifier
from cuppa.classifier.cuppa_prediction import CuppaPrediction, CuppaPredSummary
from cuppa.constants import PREDICT_NA_FILL_VALUE
from cuppa.runners.args import DEFAULT_RUNNER_ARGS
from cuppa.logger import LoggerMixin, initialize_logging
from cuppa.sample_data.cuppa_features import CuppaFeaturesLoader
Expand Down Expand Up @@ -72,7 +73,7 @@ def get_X(self) -> None:
loader = CuppaFeaturesLoader(self.features_path, sample_id=self.sample_id)
X = loader.load()

X = self.cuppa_classifier.fill_missing_cols(X)
X = self.cuppa_classifier.fill_missing_cols(X, PREDICT_NA_FILL_VALUE)

self.X = X

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os.path
import pandas as pd

from cuppa.constants import PREDICT_NA_FILL_VALUE
from cuppa.logger import LoggerMixin
from cuppa.misc.utils import check_required_columns

Expand Down Expand Up @@ -67,11 +66,9 @@ def __init__(
self,
path: str,
sample_id: str | None = None,
na_fill_value: int | float = PREDICT_NA_FILL_VALUE
):
self.path = path
self.sample_id = sample_id
self.na_fill_value = na_fill_value

self.df: pd.DataFrame = None

Expand Down Expand Up @@ -197,4 +194,4 @@ def load(self) -> pd.DataFrame:
self._assign_feature_names()
self._print_stats()

return self.df.fillna(self.na_fill_value).transpose()
return self.df.transpose()
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
import tempfile

import numpy as np
import pytest

from tests.mock_data import MockInputData
Expand Down Expand Up @@ -83,14 +84,18 @@ def test_can_load_multi_sample_from_tsvs(self):

loader = CuppaFeaturesLoader(MockInputData.cohort_dir)
features = loader.load()

na_fill_value = -1e-8
features = features.fillna(na_fill_value)

assert features.shape == (2, 6225)

assert features["gen_pos.1_500000"].tolist() == [0, 1]
assert features["snv96.C>T_TCC"].tolist() == [0, 2]
assert features["event.tmb.snv_count"].tolist() == [0, 8]
assert features["event.sv.SIMPLE_DEL_20KB_1MB"].tolist() == [0, 20]
assert features["event.fusion.TMPRSS2_ERG"].tolist() == [loader.na_fill_value, 1]
assert features["event.fusion.TMPRSS2_ERG"].tolist() == [na_fill_value, 1]
assert features["event.trait.is_male"].tolist() == [0, 1]
assert features["sig.UV (SBS7)"].tolist() == [0, 6.4]
assert features["gene_exp.BRAF"].tolist() == [loader.na_fill_value, 3.434]
assert features["alt_sj.7;140426316;140439612"].tolist() == [loader.na_fill_value, 2]
assert features["gene_exp.BRAF"].tolist() == [na_fill_value, 3.434]
assert features["alt_sj.7;140426316;140439612"].tolist() == [na_fill_value, 2]

0 comments on commit 069deaa

Please sign in to comment.