Skip to content

Commit

Permalink
Merge pull request #72 from Roestlab/patch/conformer
Browse files Browse the repository at this point in the history
refactor conformer model for raw extraction
  • Loading branch information
jcharkow authored Jan 24, 2024
2 parents cd9a05f + 03ccb2c commit 9320cc6
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 78 deletions.
3 changes: 2 additions & 1 deletion massdash/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
URL_TEST_OSW = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/openswath/osw/test.osw"
URL_TEST_PQP = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/openswath/lib/test.pqp"
URL_TEST_RAW_MZML = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/raw/test_raw_1.mzML"
URL_TEST_DREAMDIA_REPORT = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/dreamdia/test_dreamdia_report.tsv"
URL_TEST_DREAMDIA_REPORT = "https://github.com/Roestlab/massdash/raw/dev/test/test_data/example_dia/dreamdia/test_dreamdia_report.tsv"
URL_PRETRAINED_CONFORMER = "https://github.com/Roestlab/massdash/releases/download/v0.0.1-alpha/base_cape.onnx"
2 changes: 1 addition & 1 deletion massdash/loaders/SqMassLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def loadTransitionGroups(self, pep_id: str, charge: int) -> Dict[str, Transition
prec_chrom_ids = t.getPrecursorChromIDs(precursor_id)
precursor_chroms = t.getDataForChromatograms(prec_chrom_ids['chrom_ids'], prec_chrom_ids['native_ids'])

out[t] = TransitionGroup(precursor_chroms, transition_chroms)
out[t] = TransitionGroup(precursor_chroms, transition_chroms, pep_id, charge)
return out

def loadTransitionGroupFeaturesDf(self, pep_id: str, charge: int) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion massdash/loaders/access/MzMLDataAccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def msExperimentToFeatureMap(self, msExperiment: po.MSExperiment, feature: Trans
else:
LOGGER.warn(f"No spectra found for peptide: {feature.sequence}{feature.precursor_charge}. Try adjusting the extraction parameters")

return FeatureMap(results_df, config)
return FeatureMap(results_df, feature.sequence, feature.precursor_charge, config)

def _find_closest_reference_mz(self, given_mz: np.array, reference_mz_values: np.array, peptide_product_annotation_list: np.array) -> np.array:
"""
Expand Down
38 changes: 25 additions & 13 deletions massdash/peakPickers/ConformerPeakPicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# Structs
from ..structs.TransitionGroup import TransitionGroup
from ..structs.TransitionGroupFeature import TransitionGroupFeature
from ..loaders.SpectralLibraryLoader import SpectralLibraryLoader
# Utils
from ..util import check_package
from ..util import LOGGER

onnxruntime, ONNXRUNTIME_AVAILABLE = check_package("onnxruntime")

Expand All @@ -35,7 +37,7 @@ class ConformerPeakPicker:
_convertConformerFeatureToTransitionGroupFeatures: Convert conformer predicted feature to TransitionGroupFeatures.
"""

def __init__(self, transition_group: TransitionGroup, pretrained_model_file: str, window_size: int = 175, prediction_threshold: float = 0.5, prediction_type: str = "logits"):
def __init__(self, library_file: str, pretrained_model_file: str, prediction_threshold: float = 0.5, prediction_type: str = "logits"):
"""
Initialize the ConformerPeakPicker class.
Expand All @@ -46,14 +48,18 @@ def __init__(self, transition_group: TransitionGroup, pretrained_model_file: str
prediction_threshold (float, optional): The prediction threshold for peak picking. Defaults to 0.5.
prediction_type (str, optional): The prediction type for peak picking. Defaults to "logits".
"""
self.transition_group = transition_group
self.pretrained_model_file = pretrained_model_file
self.window_size = window_size
self.prediction_threshold = prediction_threshold
self.prediction_type = prediction_type
self.onnx_session = None
self.library = SpectralLibraryLoader(library_file)

self._validate_model()

## set in load_model
self.onnx_session = None
self.window_size = None

LOGGER.name = __class__.__name__

def _validate_model(self):
"""
Expand All @@ -73,8 +79,14 @@ def load_model(self):
raise ImportError("onnxruntime is required for loading the pretrained Conformer model, but not installed.")
# Load pretrained model
self.onnx_session = onnxruntime.InferenceSession(self.pretrained_model_file)
if len(self.onnx_session.get_inputs()) == 0:
raise ValueError("Pretrained model does not have any inputs.")
elif len(self.onnx_session.get_inputs()[0].shape) != 3:
raise ValueError("First input to model must be a 3D numpy array, current shape: {}".format(len(self.onnx_session.get_inputs()[0].shape)))
else:
self.window_size = self.onnx_session.get_inputs()[0].shape[2]

def pick(self, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
def pick(self, transition_group, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
"""
Perform peak picking.
Expand All @@ -85,19 +97,19 @@ def pick(self, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
List[TransitionGroupFeature]: The list of transition group features.
"""
# Transform data into required input
print("Preprocessing data...")
conformer_preprocessor = ConformerPreprocessor(self.transition_group)
input_data = conformer_preprocessor.preprocess()
print("Loading model...")
LOGGER.info("Loading model...")
self.load_model()
print("Predicting...")
LOGGER.info("Preprocessing data...")
conformer_preprocessor = ConformerPreprocessor(transition_group, self.window_size)
input_data = conformer_preprocessor.preprocess(self.library)
LOGGER.info("Predicting...")
ort_input = {self.onnx_session.get_inputs()[0].name: input_data}
ort_output = self.onnx_session.run(None, ort_input)
print("Getting predicted boundaries...")
LOGGER.info("Getting predicted boundaries...")
peak_info = conformer_preprocessor.find_top_peaks(ort_output[0], ["precursor"], self.prediction_threshold, self.prediction_type)
# Get actual peak boundaries
peak_info = conformer_preprocessor.get_peak_boundaries(peak_info, self.transition_group, self.window_size)
print(f"Peak info: {peak_info}")
peak_info = conformer_preprocessor.get_peak_boundaries(peak_info)
LOGGER.info(f"Peak info: {peak_info}")
return self._convertConformerFeatureToTransitionGroupFeatures(peak_info, max_int_transition)

def _convertConformerFeatureToTransitionGroupFeatures(self, peak_info: dict, max_int_transition: int=1000) -> List[TransitionGroupFeature]:
Expand Down
64 changes: 25 additions & 39 deletions massdash/preprocess/ConformerPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .GenericPreprocessor import GenericPreprocessor
# Structs
from ..structs.TransitionGroup import TransitionGroup
from ..loaders.SpectralLibraryLoader import SpectralLibraryLoader
# Utils
from ..util import check_package

Expand All @@ -36,9 +37,13 @@ class ConformerPreprocessor(GenericPreprocessor):
"""

def __init__(self, transition_group: TransitionGroup):
def __init__(self, transition_group: TransitionGroup, window_size: int=175):
super().__init__(transition_group)

## pad the transition group to the window size
self.transition_group = self.transition_group.adjust_length(window_size)
self.window_size = window_size

@staticmethod
def min_max_scale(data, min: float=None, max: float=None) -> np.ndarray:
"""
Expand Down Expand Up @@ -101,14 +106,14 @@ def sigmoid(x: np.ndarray) -> np.ndarray:
"""
return 1 / (1 + np.exp(-x))

def preprocess(self, window_size: int=175) -> np.ndarray:
def preprocess(self, library: SpectralLibraryLoader) -> np.ndarray:
"""
Preprocesses the data by scaling and transforming it into a numpy array.
Code adapted from CAPE
Args:
window_size (int): The desired window size for trimming the data. Default is 175.
library (SpectralLibraryLoader): The spectral library loader.
Returns:
np.ndarray: The preprocessed data as a numpy array with shape (1, 21, len(data[0])).
Expand All @@ -122,17 +127,19 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
# Row index 19: library retention time diff
# Row index 20: precursor charge

# initialize empty numpy array
data = np.empty((0, len(self.transition_group.transitionData[0].intensity)), float)
if len(self.transition_group.transitionData) != 6:
raise ValueError(f"Transition group must have 6 transitions, but has {len(self.transition_group.transitionData)}.")

lib_int_data = np.empty((0, len(self.transition_group.transitionData[0].intensity)), float)
# initialize empty numpy array
data = np.empty((0, self.window_size), float)
lib_int_data = np.empty((0, self.window_size), float)

for chrom in self.transition_group.transitionData:
# append ms2 intensity data to data
data = np.append(data, [chrom.intensity], axis=0)

lib_int = self.transition_group.targeted_transition_list[self.transition_group.targeted_transition_list.Annotation==chrom.label]['LibraryIntensity'].values
lib_int = np.repeat(lib_int, len(chrom.intensity))
lib_int = library.get_fragment_library_intensity(self.transition_group.sequence, self.transition_group.precursor_charge, chrom.label)
lib_int = np.repeat(lib_int, self.window_size)
lib_int_data = np.append(lib_int_data, [lib_int], axis=0)

# initialize empty numpy array to store scaled data
Expand All @@ -148,20 +155,7 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
)

## MS1 trace data
# padd precursor intensity data with zeros to match ms2 intensity data
len_trans = len(self.transition_group.transitionData[0].intensity)
len_prec = len(self.transition_group.precursorData[0].intensity)
if len_prec!=len_trans:
if len_prec < len_trans:
prec_int = np.pad(self.transition_group.precursorData[0].intensity, (0, len_trans-len_prec), 'constant', constant_values=(0, 0))
if len_prec > len_trans:
prec_int = self.transition_group.precursorData[0].intensity
# compute number of points to trim from either side of the middle point
remove_n_points = len_prec - len_trans
# trim precursor intensity data
prec_int = prec_int[remove_n_points//2:-remove_n_points//2]
else:
prec_int = self.transition_group.precursorData[0].intensity
prec_int = self.transition_group.precursorData[0].intensity

# append ms1 intensity data to data
new_data[12] = self.min_max_scale(prec_int)
Expand Down Expand Up @@ -190,18 +184,11 @@ def preprocess(self, window_size: int=175) -> np.ndarray:
new_data[19] = tmp_arr

## Add charge state
new_data[20] = self.transition_group.targeted_transition_list.PrecursorCharge.values[0] * np.ones(len(data[0]))
new_data[20] = self.transition_group.precursor_charge * np.ones(len(data[0]))

## Convert to float32
new_data = new_data.astype(np.float32)

## trim data if does not match window size starting at the centre
if len(new_data[0]) > window_size:
middle_index = len(data[0]) // 2
trim_start = middle_index - (window_size // 2)
trim_end = middle_index + (window_size // 2) + 1
new_data = new_data[:, trim_start:trim_end]

# convert the shape to be (1, 21, len(data[0]))
new_data = np.expand_dims(new_data, axis=0)

Expand Down Expand Up @@ -297,43 +284,42 @@ def find_top_peaks(self, preds, seq_classes: List[str]='input_precursor', thresh

return peak_info

def get_peak_boundaries(self, peak_info: dict, tr_group: TransitionGroup, window_size: int=175):
def get_peak_boundaries(self, peak_info: dict):
"""
Adjusts the peak boundaries in the peak_info dictionary based on the window size and the dimensions of the input rt_array.
Calculates the actual RT values from the rt_array and appends them to the peak_info dictionary.
Args:
peak_info (dict): A dictionary containing information about the peaks.
tr_group (TransitionGroup): The transition group containing the data.
window_size (int, optional): The size of the window used for trimming the rt_array. Defaults to 175.
Returns:
dict: The updated peak_info dictionary with adjusted peak boundaries and RT values.
"""
rt_array = tr_group.transitionData[0].data
if rt_array.shape[0] != window_size:
print(f"input_data {rt_array.shape[0]} was trimmed to {window_size}, adjusting peak_info indexes to map to the original datas dimensions")
rt_array = self.transition_group.transitionData[0].data
if rt_array.shape[0] != self.window_size:
print(f"input_data {rt_array.shape[0]} was trimmed to {self.window_size}, adjusting peak_info indexes to map to the original datas dimensions")
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['max_idx_org'] = peak_info[key][i]['max_idx']
peak_info[key][i]['start_idx_org'] = peak_info[key][i]['start_idx']
peak_info[key][i]['end_idx_org'] = peak_info[key][i]['end_idx']
new_max_idx = peak_info[key][i]['max_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
new_max_idx = peak_info[key][i]['max_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if not new_max_idx < 0:
peak_info[key][i]['max_idx'] = new_max_idx

new_start_idx = peak_info[key][i]['start_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
new_start_idx = peak_info[key][i]['start_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if not new_start_idx < 0:
peak_info[key][i]['start_idx'] = new_start_idx

peak_info[key][i]['end_idx'] = peak_info[key][i]['end_idx'] + (window_size // 2) - (rt_array.shape[0] // 2)
peak_info[key][i]['end_idx'] = peak_info[key][i]['end_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)

# get actual RT value from RT array and append to peak_info
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['rt_apex'] = rt_array[peak_info[key][i]['max_idx']]
peak_info[key][i]['rt_start'] = rt_array[peak_info[key][i]['start_idx']]
peak_info[key][i]['rt_end'] = rt_array[peak_info[key][i]['end_idx']]
peak_info[key][i]['int_apex'] = np.max([tg.intensity[peak_info[key][i]['max_idx']] for tg in tr_group.transitionData])
peak_info[key][i]['int_apex'] = np.max([tg.intensity[peak_info[key][i]['max_idx']] for tg in self.transition_group.transitionData])

return peak_info
4 changes: 2 additions & 2 deletions massdash/server/ExtractedIonChromatogramAnalysisServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def main(self):
tr_group.targeted_transition_list = transition_list_ui.target_transition_list
print(f"Pretrained model file: {peak_picking_settings.peak_picker_algo_settings.pretrained_model_file}")

peak_picker = ConformerPeakPicker(tr_group, peak_picking_settings.peak_picker_algo_settings.pretrained_model_file, window_size=peak_picking_settings.peak_picker_algo_settings.conformer_window_size, prediction_threshold=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_threshold, prediction_type=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_type)
peak_picker = ConformerPeakPicker(self.massdash_gui.file_input_settings.osw_file_path, peak_picking_settings.peak_picker_algo_settings.pretrained_model_file, window_size=peak_picking_settings.peak_picker_algo_settings.conformer_window_size, prediction_threshold=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_threshold, prediction_type=peak_picking_settings.peak_picker_algo_settings.conformer_prediction_type)
# get the transition in tr_group with the max intensity
max_int_transition = np.max([transition.intensity for transition in tr_group.transitionData])
peak_features = peak_picker.pick(max_int_transition)
peak_features = peak_picker.pick(tr_group, max_int_transition)
tr_group_feature_data[file.filename] = peak_features
st.write(f"Performing Conformer Peak Picking... Elapsed time: {elapsed_time()}")
else:
Expand Down
5 changes: 3 additions & 2 deletions massdash/server/OneDimensionPlotterServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ class OneDimensionPlotterServer:
def __init__(self,
feature_map_dict: Dict[str, FeatureMap],
transition_list_ui: TransitionListUISettings, chrom_plot_settings: ChromatogramPlotUISettings,
peak_picking_settings: PeakPickingUISettings,
peak_picking_settings: PeakPickingUISettings, spectral_library_path: str=None,
verbose: bool=False):
self.feature_map_dict = feature_map_dict
self.transition_list_ui = transition_list_ui
self.chrom_plot_settings = chrom_plot_settings
self.peak_picking_settings = peak_picking_settings
self.spectral_library_path = spectral_library_path
self.plot_obj_dict = {}
self.verbose = verbose

Expand All @@ -74,7 +75,7 @@ def generate_chromatogram_plots(self):
tr_group = feature_map.to_chromatograms()
# Perform peak picking if enabled
peak_picker = PeakPickingServer(self.peak_picking_settings, self.chrom_plot_settings)
tr_group_feature_data = peak_picker.perform_peak_picking(tr_group_data={'tmp':tr_group}, transition_list_ui=self.transition_list_ui)
tr_group_feature_data = peak_picker.perform_peak_picking(tr_group_data={'tmp':tr_group}, transition_list_ui=self.transition_list_ui, spec_lib=self.spectral_library_path)
plot_settings_dict = self._get_plot_settings('Retention Time (s)', 'Intensity', file, 'chromatogram')
plot_obj = self._generate_plot(tr_group, plot_settings_dict, tr_group_feature_data['tmp'])
run_plots_list.append(plot_obj)
Expand Down
Loading

0 comments on commit 9320cc6

Please sign in to comment.