Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] error in conformer model #72

Merged
merged 13 commits into from
Jan 24, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![pypiv](https://img.shields.io/pypi/v/massdash.svg)](https://pypi.python.org/pypi/massdash)
[![continuous-integration](https://github.com/Roestlab/massdash/workflows/continuous-integration/badge.svg)](https://github.com/Roestlab/massdash/actions)
[![pypidownload](https://img.shields.io/pypi/dm/massdash?color=orange)](https://pypistats.org/packages/massdash)
[![biocondav](https://img.shields.io/conda/v/bioconda/massdash?label=bioconda&color=purple)](https://anaconda.org/bioconda/massdash)
[![dockerv](https://img.shields.io/docker/v/singjust/massdash?label=docker&color=green)](https://hub.docker.com/r/singjust/massdash)
[![dockerpull](https://img.shields.io/docker/pulls/singjust/massdash?color=green)](https://hub.docker.com/r/singjust/massdash)
[![continuous-integration](https://github.com/Roestlab/massdash/workflows/continuous-integration/badge.svg)](https://github.com/Roestlab/massdash/actions)
[![readthedocs](https://img.shields.io/readthedocs/massdash)](https://massdash.readthedocs.io/en/latest/index.html)
[![Licence](https://img.shields.io/badge/License-BSD_3--Clause-orange.svg)](https://raw.githubusercontent.com/RoestLab/massdash/main/LICENSE)

Expand Down
3 changes: 2 additions & 1 deletion massdash/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
URL_TEST_OSW = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/openswath/osw/test.osw"
URL_TEST_PQP = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/openswath/lib/test.pqp"
URL_TEST_RAW_MZML = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/raw/test_raw_1.mzML"
URL_TEST_DREAMDIA_REPORT = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/dreamdia/test_dreamdia_report.tsv"
URL_TEST_DREAMDIA_REPORT = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/dreamdia/test_dreamdia_report.tsv"
URL_PRETRAINED_CONFORMER = "https://github.com/Roestlab/massdash/releases/download/v0.0.1-alpha/base_cape.onnx"
2 changes: 1 addition & 1 deletion massdash/loaders/SqMassLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def loadTransitionGroups(self, pep_id: str, charge: int) -> Dict[str, Transition
prec_chrom_ids = t.getPrecursorChromIDs(precursor_id)
precursor_chroms = t.getDataForChromatograms(prec_chrom_ids['chrom_ids'], prec_chrom_ids['native_ids'])

out[t] = TransitionGroup(precursor_chroms, transition_chroms)
out[t] = TransitionGroup(precursor_chroms, transition_chroms, pep_id, charge)
return out

def loadTransitionGroupFeaturesDf(self, pep_id: str, charge: int) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion massdash/loaders/access/MzMLDataAccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def msExperimentToFeatureMap(self, msExperiment: po.MSExperiment, feature: Trans
else:
LOGGER.warn(f"No spectra found for peptide: {feature.sequence}{feature.precursor_charge}. Try adjusting the extraction parameters")

return FeatureMap(results_df, config)
return FeatureMap(results_df, feature.sequence, feature.precursor_charge, config)

def _find_closest_reference_mz(self, given_mz: np.array, reference_mz_values: np.array, peptide_product_annotation_list: np.array) -> np.array:
"""
Expand Down
38 changes: 25 additions & 13 deletions massdash/peakPickers/ConformerPeakPicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# Structs
from ..structs.TransitionGroup import TransitionGroup
from ..structs.TransitionGroupFeature import TransitionGroupFeature
from ..loaders.SpectralLibraryLoader import SpectralLibraryLoader
# Utils
from ..util import check_package
from ..util import LOGGER

onnxruntime, ONNXRUNTIME_AVAILABLE = check_package("onnxruntime")

Expand All @@ -35,7 +37,7 @@ class ConformerPeakPicker:
_convertConformerFeatureToTransitionGroupFeatures: Convert conformer predicted feature to TransitionGroupFeatures.
"""

def __init__(self, transition_group: TransitionGroup, pretrained_model_file: str, window_size: int = 175, prediction_threshold: float = 0.5, prediction_type: str = "logits"):
def __init__(self, library_file: str, pretrained_model_file: str, prediction_threshold: float = 0.5, prediction_type: str = "logits"):
"""
Initialize the ConformerPeakPicker class.

Expand All @@ -46,14 +48,18 @@ def __init__(self, transition_group: TransitionGroup, pretrained_model_file: str
prediction_threshold (float, optional): The prediction threshold for peak picking. Defaults to 0.5.
prediction_type (str, optional): The prediction type for peak picking. Defaults to "logits".
"""
self.transition_group = transition_group
self.pretrained_model_file = pretrained_model_file
self.window_size = window_size
self.prediction_threshold = prediction_threshold
self.prediction_type = prediction_type
self.onnx_session = None
self.library = SpectralLibraryLoader(library_file)

self._validate_model()

## set in load_model
self.onnx_session = None
self.window_size = None

LOGGER.name = __class__.__name__

def _validate_model(self):
"""
Expand All @@ -73,8 +79,14 @@ def load_model(self):
raise ImportError("onnxruntime is required for loading the pretrained Conformer model, but not installed.")
# Load pretrained model
self.onnx_session = onnxruntime.InferenceSession(self.pretrained_model_file)
if len(self.onnx_session.get_inputs()) == 0:
raise ValueError("Pretrained model does not have any inputs.")
elif len(self.onnx_session.get_inputs()[0].shape) != 3:
raise ValueError("First input to model must be a 3D numpy array, current shape: {}".format(len(self.onnx_session.get_inputs()[0].shape)))
else:
self.window_size = self.onnx_session.get_inputs()[0].shape[2]

def pick(self, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
def pick(self, transition_group, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
"""
Perform peak picking.

Expand All @@ -85,19 +97,19 @@ def pick(self, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
List[TransitionGroupFeature]: The list of transition group features.
"""
# Transform data into required input
print("Preprocessing data...")
conformer_preprocessor = ConformerPreprocessor(self.transition_group)
input_data = conformer_preprocessor.preprocess()
print("Loading model...")
LOGGER.info("Loading model...")
self.load_model()
print("Predicting...")
LOGGER.info("Preprocessing data...")
conformer_preprocessor = ConformerPreprocessor(transition_group, self.window_size)
input_data = conformer_preprocessor.preprocess(self.library)
LOGGER.info("Predicting...")
ort_input = {self.onnx_session.get_inputs()[0].name: input_data}
ort_output = self.onnx_session.run(None, ort_input)
print("Getting predicted boundaries...")
LOGGER.info("Getting predicted boundaries...")
peak_info = conformer_preprocessor.find_top_peaks(ort_output[0], ["precursor"], self.prediction_threshold, self.prediction_type)
# Get actual peak boundaries
peak_info = conformer_preprocessor.get_peak_boundaries(peak_info, self.transition_group, self.window_size)
print(f"Peak info: {peak_info}")
peak_info = conformer_preprocessor.get_peak_boundaries(peak_info)
LOGGER.info(f"Peak info: {peak_info}")
return self._convertConformerFeatureToTransitionGroupFeatures(peak_info, max_int_transition)

def _convertConformerFeatureToTransitionGroupFeatures(self, peak_info: dict, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
Expand Down
64 changes: 25 additions & 39 deletions massdash/preprocess/ConformerPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .GenericPreprocessor import GenericPreprocessor
# Structs
from ..structs.TransitionGroup import TransitionGroup
from ..loaders.SpectralLibraryLoader import SpectralLibraryLoader
# Utils
from ..util import check_package

Expand All @@ -36,9 +37,13 @@ class ConformerPreprocessor(GenericPreprocessor):

"""

def __init__(self, transition_group: TransitionGroup):
def __init__(self, transition_group: TransitionGroup, window_size: int=175):
super().__init__(transition_group)

## pad the transition group to the window size
self.transition_group = self.transition_group.adjust_length(window_size)
self.window_size = window_size

@staticmethod
def min_max_scale(data, min: float=None, max: float=None) -> np.ndarray:
"""
Expand Down Expand Up @@ -101,14 +106,14 @@ def sigmoid(x: np.ndarray) -> np.ndarray:
"""
return 1 / (1 + np.exp(-x))

def preprocess(self, window_size: int=175) -> np.ndarray:
def preprocess(self, library: SpectralLibraryLoader) -> np.ndarray:
"""
Preprocesses the data by scaling and transforming it into a numpy array.

Code adapted from CAPE

Args:
window_size (int): The desired window size for trimming the data. Default is 175.
            library (SpectralLibraryLoader): The spectral library loader.

Returns:
np.ndarray: The preprocessed data as a numpy array with shape (1, 21, len(data[0])).
Expand All @@ -122,17 +127,19 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
# Row index 19: library retention time diff
# Row index 20: precursor charge

# initialize empty numpy array
data = np.empty((0, len(self.transition_group.transitionData[0].intensity)), float)
if len(self.transition_group.transitionData) != 6:
raise ValueError(f"Transition group must have 6 transitions, but has {len(self.transition_group.transitionData)}.")

lib_int_data = np.empty((0, len(self.transition_group.transitionData[0].intensity)), float)
# initialize empty numpy array
data = np.empty((0, self.window_size), float)
lib_int_data = np.empty((0, self.window_size), float)

for chrom in self.transition_group.transitionData:
# append ms2 intensity data to data
data = np.append(data, [chrom.intensity], axis=0)

lib_int = self.transition_group.targeted_transition_list[self.transition_group.targeted_transition_list.Annotation==chrom.label]['LibraryIntensity'].values
lib_int = np.repeat(lib_int, len(chrom.intensity))
lib_int = library.get_fragment_library_intensity(self.transition_group.sequence, self.transition_group.precursor_charge, chrom.label)
lib_int = np.repeat(lib_int, self.window_size)
lib_int_data = np.append(lib_int_data, [lib_int], axis=0)

# initialize empty numpy array to store scaled data
Expand All @@ -148,20 +155,7 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
)

## MS1 trace data
        # pad precursor intensity data with zeros to match ms2 intensity data
len_trans = len(self.transition_group.transitionData[0].intensity)
len_prec = len(self.transition_group.precursorData[0].intensity)
if len_prec!=len_trans:
if len_prec < len_trans:
prec_int = np.pad(self.transition_group.precursorData[0].intensity, (0, len_trans-len_prec), 'constant', constant_values=(0, 0))
if len_prec > len_trans:
prec_int = self.transition_group.precursorData[0].intensity
# compute number of points to trim from either side of the middle point
remove_n_points = len_prec - len_trans
# trim precursor intensity data
prec_int = prec_int[remove_n_points//2:-remove_n_points//2]
else:
prec_int = self.transition_group.precursorData[0].intensity
prec_int = self.transition_group.precursorData[0].intensity

# append ms1 intensity data to data
new_data[12] = self.min_max_scale(prec_int)
Expand Down Expand Up @@ -190,18 +184,11 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
new_data[19] = tmp_arr

## Add charge state
new_data[20] = self.transition_group.targeted_transition_list.PrecursorCharge.values[0] * np.ones(len(data[0]))
new_data[20] = self.transition_group.precursor_charge * np.ones(len(data[0]))

## Convert to float32
new_data = new_data.astype(np.float32)

## trim data if does not match window size starting at the centre
if len(new_data[0]) > window_size:
middle_index = len(data[0]) // 2
trim_start = middle_index - (window_size // 2)
trim_end = middle_index + (window_size // 2) + 1
new_data = new_data[:, trim_start:trim_end]

        # convert the shape to be (1, 21, len(data[0]))
new_data = np.expand_dims(new_data, axis=0)

Expand Down Expand Up @@ -297,43 +284,42 @@ def find_top_peaks(self, preds, seq_classes: List[str]='input_precursor', thresh

return peak_info

def get_peak_boundaries(self, peak_info: dict, tr_group: TransitionGroup, window_size: int=175):
def get_peak_boundaries(self, peak_info: dict):
"""
Adjusts the peak boundaries in the peak_info dictionary based on the window size and the dimensions of the input rt_array.
Calculates the actual RT values from the rt_array and appends them to the peak_info dictionary.

Args:
peak_info (dict): A dictionary containing information about the peaks.
tr_group (TransitionGroup): The transition group containing the data.
window_size (int, optional): The size of the window used for trimming the rt_array. Defaults to 175.

Returns:
dict: The updated peak_info dictionary with adjusted peak boundaries and RT values.
"""
rt_array = tr_group.transitionData[0].data
if rt_array.shape[0] != window_size:
            print(f"input_data {rt_array.shape[0]} was trimmed to {window_size}, adjusting peak_info indexes to map to the original data's dimensions")
rt_array = self.transition_group.transitionData[0].data
if rt_array.shape[0] != self.window_size:
            print(f"input_data {rt_array.shape[0]} was trimmed to {self.window_size}, adjusting peak_info indexes to map to the original data's dimensions")
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['max_idx_org'] = peak_info[key][i]['max_idx']
peak_info[key][i]['start_idx_org'] = peak_info[key][i]['start_idx']
peak_info[key][i]['end_idx_org'] = peak_info[key][i]['end_idx']
new_max_idx = peak_info[key][i]['max_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
new_max_idx = peak_info[key][i]['max_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if not new_max_idx < 0:
peak_info[key][i]['max_idx'] = new_max_idx

new_start_idx = peak_info[key][i]['start_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
new_start_idx = peak_info[key][i]['start_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if not new_start_idx < 0:
peak_info[key][i]['start_idx'] = new_start_idx

peak_info[key][i]['end_idx'] = peak_info[key][i]['end_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
peak_info[key][i]['end_idx'] = peak_info[key][i]['end_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)

# get actual RT value from RT array and append to peak_info
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['rt_apex'] = rt_array[peak_info[key][i]['max_idx']]
peak_info[key][i]['rt_start'] = rt_array[peak_info[key][i]['start_idx']]
peak_info[key][i]['rt_end'] = rt_array[peak_info[key][i]['end_idx']]
peak_info[key][i]['int_apex'] = np.max([tg.intensity[peak_info[key][i]['max_idx']] for tg in tr_group.transitionData])
peak_info[key][i]['int_apex'] = np.max([tg.intensity[peak_info[key][i]['max_idx']] for tg in self.transition_group.transitionData])

return peak_info
4 changes: 2 additions & 2 deletions massdash/server/ExtractedIonChromatogramAnalysisServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def main(self):
tr_group.targeted_transition_list = transition_list_ui.target_transition_list
print(f"Pretrained model file: {peak_picking_settings.peak_picker_algo_settings.pretrained_model_file}")

peak_picker = ConformerPeakPicker(tr_group, peak_picking_settings.peak_picker_algo_settings.pretrained_model_file, window_size=peak_picking_settings.peak_picker_algo_settings.conformer_window_size, prediction_threshold=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_threshold, prediction_type=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_type)
peak_picker = ConformerPeakPicker(self.massdash_gui.file_input_settings.osw_file_path, peak_picking_settings.peak_picker_algo_settings.pretrained_model_file, window_size=peak_picking_settings.peak_picker_algo_settings.conformer_window_size, prediction_threshold=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_threshold, prediction_type=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_type)
# get the trantition in tr_group with the max intensity
max_int_transition = np.max([transition.intensity for transition in tr_group.transitionData])
peak_features = peak_picker.pick(max_int_transition)
peak_features = peak_picker.pick(tr_group, max_int_transition)
tr_group_feature_data[file.filename] = peak_features
st.write(f"Performing Conformer Peak Picking... Elapsed time: {elapsed_time()}")
else:
Expand Down
5 changes: 3 additions & 2 deletions massdash/server/OneDimensionPlotterServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ class OneDimensionPlotterServer:
def __init__(self,
feature_map_dict: Dict[str, FeatureMap],
transition_list_ui: TransitionListUISettings, chrom_plot_settings: ChromatogramPlotUISettings,
peak_picking_settings: PeakPickingUISettings,
peak_picking_settings: PeakPickingUISettings, spectral_library_path: str=None,
verbose: bool=False):
self.feature_map_dict = feature_map_dict
self.transition_list_ui = transition_list_ui
self.chrom_plot_settings = chrom_plot_settings
self.peak_picking_settings = peak_picking_settings
self.spectral_library_path = spectral_library_path
self.plot_obj_dict = {}
self.verbose = verbose

Expand All @@ -74,7 +75,7 @@ def generate_chromatogram_plots(self):
tr_group = feature_map.to_chromatograms()
# Perform peak picking if enabled
peak_picker = PeakPickingServer(self.peak_picking_settings, self.chrom_plot_settings)
tr_group_feature_data = peak_picker.perform_peak_picking(tr_group_data={'tmp':tr_group}, transition_list_ui=self.transition_list_ui)
tr_group_feature_data = peak_picker.perform_peak_picking(tr_group_data={'tmp':tr_group}, transition_list_ui=self.transition_list_ui, spec_lib=self.spectral_library_path)
plot_settings_dict = self._get_plot_settings('Retention Time (s)', 'Intensity', file, 'chromatogram')
plot_obj = self._generate_plot(tr_group, plot_settings_dict, tr_group_feature_data['tmp'])
run_plots_list.append(plot_obj)
Expand Down
Loading
Loading