Skip to content

Commit

Permalink
Merge pull request #165 from matchms/pytorch_version
Browse files Browse the repository at this point in the history
First implementation of the switch to PyTorch; further development and integration will be done in dev_pytorch.
  • Loading branch information
niekdejonge committed Jan 16, 2024
2 parents 23efa4b + b49abbc commit 8f3a1fe
Show file tree
Hide file tree
Showing 13 changed files with 1,043 additions and 159 deletions.
37 changes: 1 addition & 36 deletions .github/workflows/CI_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10', '3.11']
exclude:
# already tested in first_check job
- python-version: 3.9
Expand All @@ -69,38 +69,3 @@ jobs:
- name: Run tests
run: |
pytest
tensorflow_check:
name: Tensorflow version check / python-3.8 / ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Python info
run: |
which python
python --version
- name: Install Tensorflow version 2.6
run: |
python -m pip install --upgrade pip
pip install "tensorflow>=2.6,<2.7"
- name: Install other dependencies
run: |
pip install -e .[dev,train]
- name: Show pip list
run: |
pip list
- name: Run test with tensorflow version 2.6
run: pytest
- name: Install Tensorflow version 2.8
run: |
pip install --upgrade "numpy<1.24.0"
pip install --upgrade "tensorflow>=2.8,<2.9"
- name: Show pip list
run: |
pip list
- name: Run test with tensorflow version 2.8
run: pytest
48 changes: 48 additions & 0 deletions ms2deepscore/MetadataFeatureGenerator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,55 @@
import json
from importlib import import_module
from typing import List, Union
import torch
from matchms import Metadata
from matchms.typing import SpectrumType
from tqdm import tqdm
from .typing import BinnedSpectrumType


class MetadataVectorizer:
    """Create a numerical vector of selected metadata fields, including transformations.

    Each entry of ``additional_metadata`` is expected to expose a
    ``generate_features(metadata)`` method returning one numerical value;
    the vector for a spectrum is the sequence of those values.
    """

    def __init__(self,
                 additional_metadata=()):
        """
        Parameters
        ----------
        additional_metadata:
            Sequence of feature generators, one per entry of the metadata
            vector. Default is (), which gives vectors of length 0.
        """
        self.additional_metadata = additional_metadata

    def transform(self, spectra: List["SpectrumType"],
                  progress_bar=False) -> torch.Tensor:
        """Transform the input *spectra* into metadata vectors as needed for MS2DeepScore.

        Parameters
        ----------
        spectra
            List of spectra. Only the ``metadata`` attribute of each spectrum is read.
        progress_bar
            Show progress bar if set to True. Default is False.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(len(spectra), self.size)``; row ``i`` holds the
            metadata vector of ``spectra[i]``.
        """
        metadata_vectors = torch.zeros((len(spectra), self.size))
        # `enumerate` has no length, so the total is passed explicitly to let
        # the progress bar show actual progress.
        for i, spec in tqdm(enumerate(spectra),
                            total=len(spectra),
                            desc="Create metadata vectors",
                            disable=(not progress_bar)):
            metadata_vectors[i, :] = torch.tensor(
                [feature_generator.generate_features(spec.metadata)
                 for feature_generator in self.additional_metadata])
        return metadata_vectors

    @property
    def size(self):
        """Number of entries in each metadata vector."""
        return len(self.additional_metadata)


class MetadataFeatureGenerator:
Expand Down
235 changes: 235 additions & 0 deletions ms2deepscore/data_generators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" Data generators for training/inference with siamese Keras model.
"""
from typing import Iterator, List, NamedTuple, Optional
from matchms import Spectrum
import numba
import numpy as np
import pandas as pd
import torch
from tensorflow.keras.utils import Sequence # pylint: disable=import-error
from ms2deepscore.SpectrumBinner import SpectrumBinner
from ms2deepscore.train_new_model.spectrum_pair_selection import \
Expand All @@ -11,6 +14,238 @@
from .typing import BinnedSpectrumType


class DataGeneratorPytorch:
    """Generates batches of PyTorch tensors for training a siamese model.

    The generator works on a curated set of compound pairs. It uses
    pre-selected compound pairs, allowing more control over the training
    process, particularly in scenarios where certain compound pairs are of
    specific interest or have higher significance in the training dataset.

    Each batch is a tuple
    ``(spectra_1, spectra_2, metadata_1, metadata_2, targets)`` of tensors.
    """

    def __init__(self, spectrums: "list[Spectrum]",
                 selected_compound_pairs: "SelectedCompoundPairs",
                 min_mz, max_mz, mz_bin_width, intensity_scaling,
                 metadata_vectorizer,
                 **settings):
        """Set up the generator.

        Parameters
        ----------
        spectrums
            List of matchms Spectrum objects. Every spectrum must carry an
            "inchikey" metadata entry (its first 14 characters are used).
        selected_compound_pairs
            SelectedCompoundPairs object which contains selected compounds pairs and the
            respective similarity scores.
        min_mz
            Lower bound for m/z values to consider.
        max_mz
            Upper bound for m/z values to consider.
        mz_bin_width
            Bin width for m/z sampling.
        intensity_scaling
            To put more attention on small and medium intensity peaks, peak intensities are
            scaled by intensity to the power of intensity_scaling.
        metadata_vectorizer
            MetadataVectorizer object to use if the model should receive
            metadata entries as input. Pass None to generate empty metadata tensors.
        settings
            The available settings can be found in GeneratorSettings.

        Raises
        ------
        ValueError
            If there are fewer unique inchikeys than the batch size.
        """
        # pylint: disable=too-many-arguments
        self.current_index = 0
        self.spectrums = spectrums

        # Collect all inchikeys (first 14 characters only)
        self.spectrum_inchikeys = np.array([s.get("inchikey")[:14] for s in self.spectrums])

        # Set all other settings to input (or otherwise to defaults):
        self.settings = GeneratorSettings(settings)
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.mz_bin_width = mz_bin_width
        self.intensity_scaling = intensity_scaling
        self.num_bins = int((max_mz - min_mz) / mz_bin_width)
        self.metadata_vectorizer = metadata_vectorizer

        unique_inchikeys = np.unique(self.spectrum_inchikeys)
        if len(unique_inchikeys) < self.settings.batch_size:
            raise ValueError("The number of unique inchikeys must be larger than the batch size.")
        self.fixed_set = {}
        self.selected_compound_pairs = selected_compound_pairs
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch."""
        # NOTE(review): `self.indexes` holds num_turns * n entries, so when n is
        # not a multiple of batch_size and num_turns > 1 this count can exceed
        # ceil(num_turns * n / batch_size) and the last batch(es) can be short
        # or empty. Kept as is to preserve the existing epoch length.
        return int(self.settings.num_turns)\
            * int(np.ceil(len(self.selected_compound_pairs.scores) / self.settings.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; reset and reshuffle once the epoch is exhausted."""
        if self.current_index < self.__len__():
            batch = self.__getitem__(self.current_index)
            self.current_index += 1
            return batch
        self.current_index = 0  # make generator executable again
        self.on_epoch_end()
        raise StopIteration

    def _spectrum_pair_generator(self, batch_index: int):
        """Use the provided SelectedCompoundPairs object to pick pairs.

        Yields ``(spectrum1, spectrum2, score)`` for every index of the batch.
        """
        batch_size = self.settings.batch_size
        indexes = self.indexes[batch_index * batch_size:(batch_index + 1) * batch_size]
        for index in indexes:
            inchikey1 = self.selected_compound_pairs.idx_to_inchikey[index]
            score, inchikey2 = self.selected_compound_pairs.next_pair_for_inchikey(inchikey1)
            spectrum1 = self._get_spectrum_with_inchikey(inchikey1)
            spectrum2 = self._get_spectrum_with_inchikey(inchikey2)
            yield (spectrum1, spectrum2, score)

    def on_epoch_end(self):
        """Updates indexes after each epoch."""
        self.indexes = np.tile(np.arange(len(self.selected_compound_pairs.scores)),
                               int(self.settings.num_turns))
        if self.settings.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, batch_index: int):
        """Generate one batch of data.

        If use_fixed_set=True we try retrieving the batch from self.fixed_set (or store it if
        this is the first epoch). This ensures a fixed set of data is generated each epoch.
        Data augmentation is only applied when use_fixed_set is False.
        """
        if self.settings.use_fixed_set and batch_index in self.fixed_set:
            return self.fixed_set[batch_index]
        if self.settings.random_seed is not None and batch_index == 0:
            np.random.seed(self.settings.random_seed)
        spectrum_pairs = self._spectrum_pair_generator(batch_index)
        spectra_1, spectra_2, meta_1, meta_2, targets = self._tensorize_all(spectrum_pairs)

        if self.settings.use_fixed_set:
            # Store batches for later epochs
            self.fixed_set[batch_index] = (spectra_1, spectra_2, meta_1, meta_2, targets)
        else:
            spectra_1 = self._data_augmentation(spectra_1)
            spectra_2 = self._data_augmentation(spectra_2)
        return spectra_1, spectra_2, meta_1, meta_2, targets

    def _tensorize_all(self, spectrum_pairs):
        """Convert (spectrum1, spectrum2, score) triples into batch tensors."""
        spectra_1 = []
        spectra_2 = []
        targets = []
        for spectrum_1, spectrum_2, score in spectrum_pairs:
            spectra_1.append(spectrum_1)
            spectra_2.append(spectrum_2)
            targets.append(score)

        binned_spectra_1, metadata_1 = tensorize_spectra(
            spectra_1,
            self.metadata_vectorizer,
            self.min_mz, self.max_mz,
            self.mz_bin_width, self.intensity_scaling
        )
        binned_spectra_2, metadata_2 = tensorize_spectra(
            spectra_2,
            self.metadata_vectorizer,
            self.min_mz, self.max_mz,
            self.mz_bin_width, self.intensity_scaling
        )
        return (binned_spectra_1, binned_spectra_2, metadata_1, metadata_2,
                torch.tensor(targets, dtype=torch.float32))

    def _get_spectrum_with_inchikey(self, inchikey: str) -> "Spectrum":
        """
        Get a random spectrum matching the `inchikey` argument.

        NB: A compound (identified by an inchikey) can have multiple measured
        spectrums in a dataset.

        Raises
        ------
        ValueError
            If no spectrum has the given (14 character) inchikey.
        """
        matching_spectrum_id = np.where(self.spectrum_inchikeys == inchikey)[0]
        if len(matching_spectrum_id) <= 0:
            raise ValueError("No matching inchikey found (note: expected first 14 characters)")
        return self.spectrums[np.random.choice(matching_spectrum_id)]

    def _data_augmentation(self, spectra_tensors):
        """Augment every row of `spectra_tensors`; the tensor is modified in place."""
        for i in range(spectra_tensors.shape[0]):
            spectra_tensors[i, :] = self._data_augmentation_spectrum(spectra_tensors[i, :])
        return spectra_tensors

    def _data_augmentation_spectrum(self, spectrum_tensor):
        """Data augmentation of a single binned spectrum.

        Parameters
        ----------
        spectrum_tensor
            Spectrum in Pytorch tensor form (1-D). May be modified in place.

        Returns
        -------
        The augmented spectrum tensor.
        """
        # Augmentation 1: peak removal.
        # Going by the setting names, `augment_removal_intensity` is the
        # intensity threshold below which a peak may be removed, and
        # `augment_removal_max` is the maximum fraction of those peaks to remove.
        if self.settings.augment_removal_max and self.settings.augment_removal_intensity:
            candidates = torch.where(
                (spectrum_tensor > 0)
                & (spectrum_tensor < self.settings.augment_removal_intensity))[0]
            # Scalar draw (not a 1-element array) so the int() below is well defined.
            removal_part = np.random.random() * self.settings.augment_removal_max
            num_to_remove = min(len(candidates),
                                int(np.ceil(removal_part * len(candidates))))
            if num_to_remove > 0:
                # Sample without replacement so exactly num_to_remove peaks are removed.
                removed = np.random.choice(candidates.numpy(), num_to_remove, replace=False)
                spectrum_tensor[torch.from_numpy(removed)] = 0

        # Augmentation 2: Change peak intensities
        if self.settings.augment_intensity:
            # Scale every bin by a random factor in
            # [1 - augment_intensity, 1 + augment_intensity].
            spectrum_tensor = spectrum_tensor \
                * (1 - self.settings.augment_intensity * 2 * (torch.rand(spectrum_tensor.shape) - 0.5))

        # Augmentation 3: Peak addition
        if self.settings.augment_noise_max and self.settings.augment_noise_max > 0:
            empty_bins = torch.where(spectrum_tensor == 0)[0]
            if len(empty_bins) > self.settings.augment_noise_max:
                noise_bins = np.random.choice(
                    empty_bins.numpy(),
                    np.random.randint(0, self.settings.augment_noise_max),
                    replace=False,
                )
                if len(noise_bins) > 0:
                    spectrum_tensor[torch.from_numpy(noise_bins)] = \
                        self.settings.augment_noise_intensity * torch.rand(len(noise_bins))
        return spectrum_tensor


def tensorize_spectra(
    spectra,
    metadata_vectorizer,
    min_mz,
    max_mz,
    mz_bin_width,
    intensity_scaling
):
    """Turn matchms Spectrum objects into PyTorch peak and metadata tensors.

    Returns a tuple ``(binned_spectra, metadata_tensors)``. The first tensor
    has one row per spectrum and ``int((max_mz - min_mz) / mz_bin_width)``
    columns. The second is whatever ``metadata_vectorizer.transform`` returns,
    or an empty ``(len(spectra), 0)`` tensor when no vectorizer is given.
    """
    # pylint: disable=too-many-arguments
    n_spectra = len(spectra)

    # Metadata first, peaks second (same order as before).
    if metadata_vectorizer is not None:
        metadata_tensors = metadata_vectorizer.transform(spectra)
    else:
        metadata_tensors = torch.zeros((n_spectra, 0))

    peak_matrix = torch.zeros((n_spectra, int((max_mz - min_mz) / mz_bin_width)))
    for row, spectrum in enumerate(spectra):
        dense_vector = vectorize_spectrum(
            spectrum.peaks.mz,
            spectrum.peaks.intensities,
            min_mz,
            max_mz,
            mz_bin_width,
            intensity_scaling,
        )
        peak_matrix[row, :] = torch.tensor(dense_vector)
    return peak_matrix, metadata_tensors


@numba.jit(nopython=True)
def vectorize_spectrum(mz_array, intensities_array, min_mz, max_mz, mz_bin_width, intensity_scaling):
    """Fast function to convert mz and intensity arrays into dense spectrum vector.

    Parameters
    ----------
    mz_array
        m/z values of the peaks.
    intensities_array
        Intensities of the peaks, same length as `mz_array`.
    min_mz
        Lower bound (inclusive) for m/z values to consider.
    max_mz
        Upper bound (exclusive) for m/z values to consider.
    mz_bin_width
        Width of each m/z bin.
    intensity_scaling
        Exponent applied to every intensity before binning.

    Returns
    -------
    1-D numpy array with ``int((max_mz - min_mz) / mz_bin_width)`` entries,
    each holding the highest scaled intensity that fell into that bin.
    """
    # pylint: disable=too-many-arguments
    num_bins = int((max_mz - min_mz) / mz_bin_width)
    vector = np.zeros(num_bins)
    for mz, intensity in zip(mz_array, intensities_array):
        if min_mz <= mz < max_mz:
            bin_index = int((mz - min_mz) / mz_bin_width)
            # When (max_mz - min_mz) is not an exact multiple of mz_bin_width
            # (or due to float rounding) a peak just below max_mz can map to
            # index num_bins. Compiled code does not bounds-check by default,
            # so such peaks are skipped instead of written out of range.
            if bin_index < num_bins:
                # Take max intensity peak per bin
                vector[bin_index] = max(vector[bin_index], intensity ** intensity_scaling)
                # Alternative: Sum all intensities for all peaks in each bin
                # vector[bin_index] += intensity ** intensity_scaling
    return vector


class SpectrumPair(NamedTuple):
"""
Represents a pair of binned spectrums
Expand Down
Loading

0 comments on commit 8f3a1fe

Please sign in to comment.