Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve inchikey pair selection and data generators #232

Open
wants to merge 76 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
0819cf7
Add test for issue
niekdejonge Aug 29, 2024
11325ad
Skipp cases where no pair is available
niekdejonge Aug 29, 2024
9a7b518
Rename to current_batch_index
niekdejonge Aug 29, 2024
6191e15
Add shuffling option to SelectedCompoundNames.generator
niekdejonge Aug 29, 2024
c23290c
Update tests to test shuffling option of SelectedCompoundNames.generator
niekdejonge Aug 29, 2024
560fb6a
Use CompoundPairSelector.generator() in DataGeneratorPytorch, results…
niekdejonge Aug 29, 2024
227fd01
Update test_DataGeneratorPytorch, since the length of an epoch is now…
niekdejonge Aug 29, 2024
363e6da
linting
niekdejonge Aug 29, 2024
5726542
Undo removal of StopIteration and fix actual linting issue by adding …
niekdejonge Aug 29, 2024
40cc133
Separate nr_of_batches from length
niekdejonge Aug 29, 2024
4c93afb
Added test for equal inchikey distribution (which currently fails)
niekdejonge Sep 2, 2024
ba71d7f
Improve documentation and code readability of spectrum_pair_selection.py
niekdejonge Sep 2, 2024
d7b6eaf
Rename spectrum_pair_selection to inchikey_pair_selection.py
niekdejonge Sep 2, 2024
d525805
Add new methods for balanced spectrum pair selection
niekdejonge Sep 2, 2024
7e263c1
Add SelectedInchikeyPairs
niekdejonge Sep 2, 2024
a48155c
Fix bug selecting only one spectrum per inchikey (test still needs to…
niekdejonge Sep 2, 2024
b761419
Use inchikeys instead of indexes in available_inchikey_counts
niekdejonge Sep 2, 2024
48b21de
Remove spectrums as output from select_compound_pairs_wrapper in test
niekdejonge Sep 2, 2024
64b0477
Make test for inchikey balance less strict
niekdejonge Sep 2, 2024
88c574d
Switch order of generator output
niekdejonge Sep 3, 2024
b065144
Remove SelectedCompoundPairs and replace tests with tests for Selecte…
niekdejonge Sep 3, 2024
0a93443
Fix test model training (still expected spectra as output from select…
niekdejonge Sep 3, 2024
d92b466
Remove outdated test and add test for balanced inchikeys for selectin…
niekdejonge Sep 3, 2024
9c22063
Add methods for calculating inchikey counts to SelectedInchikeyPairs …
niekdejonge Sep 3, 2024
1836d8c
Only select each pair once in convert_selected_pairs_matrix (remove r…
niekdejonge Sep 3, 2024
38c6be2
Change inchikey pair selection algorithm to also select least frequen…
niekdejonge Sep 3, 2024
3fe3e2b
Move create test spectra
niekdejonge Sep 3, 2024
ce2273f
Add test for balanced score when selecting inchikey pairs
niekdejonge Sep 3, 2024
a6abf48
Add test to check that there are no repeating pairs in when selecting…
niekdejonge Sep 3, 2024
4b9caab
Add docstring to test
niekdejonge Sep 3, 2024
65f148d
Add new tests for data generators
niekdejonge Sep 3, 2024
38b4def
Remove prepare_folders_and_generators, to make more modular.
niekdejonge Sep 3, 2024
ca5abd4
Add test for create_data_generator
niekdejonge Sep 3, 2024
44c2f7c
Fix test_model_training
niekdejonge Sep 3, 2024
272c1c6
Remove spectra selected as output from compute_fingerprints_for_training
niekdejonge Sep 3, 2024
3fa6fb2
linting
niekdejonge Sep 3, 2024
b1cf326
linting
niekdejonge Sep 3, 2024
5ddccfa
Remove unused import
niekdejonge Sep 4, 2024
a7a79ed
Update CHANGELOG.md
niekdejonge Sep 4, 2024
2bbcf70
Add general description of pairs sampling in DataGeneratorPytorch doc…
niekdejonge Sep 4, 2024
16ced2a
Return inchikey_counts from select_pairs
niekdejonge Sep 5, 2024
db8789d
Change convert_selected_pairs_matrix for speed optimization (was extr…
niekdejonge Sep 12, 2024
a6bde44
Fix typo
niekdejonge Sep 16, 2024
2f37e1c
Add docstring
niekdejonge Sep 16, 2024
d2d05b0
Add resampling as option
niekdejonge Sep 16, 2024
ef26b8c
Add max_pair_resampling to SettingsMS2Deepscore.py
niekdejonge Sep 16, 2024
7556b89
Add validation_function that checks that same_prob_bins are submitted…
niekdejonge Sep 17, 2024
ec7acb9
Change the inchikey_pair_selection to always select bins by > bin[0] …
niekdejonge Sep 17, 2024
4338a58
Don't validate settings when loading the model (to reduce unnecessary…
niekdejonge Sep 17, 2024
454b204
Update bins in tests to start at -0.01
niekdejonge Sep 17, 2024
9770d01
Remove the sorting bases on lowest number of pairs and instead use th…
niekdejonge Sep 17, 2024
d70fe99
Restructure tests in test_inchikey_pair_selection.py and add tests fo…
niekdejonge Sep 17, 2024
510eedc
Add save as json
niekdejonge Sep 23, 2024
ed4d7c3
Add description to tqdm
niekdejonge Sep 23, 2024
0e9cde0
Add saving train_generator pairs as json in train_ms2ds_model
niekdejonge Sep 23, 2024
74ac106
Remove select one spectrum per inchikey function, since redundant
niekdejonge Sep 23, 2024
9d00498
Fix create_test_spectra.py
niekdejonge Sep 23, 2024
f103696
Fix ValidationLossCalculator.py for equal multiple spectra per inchik…
niekdejonge Sep 23, 2024
0599a7f
Fix tests that expected 2 spectra per inchikey
niekdejonge Sep 23, 2024
04faf50
Optimized speed by working with numpy matrixes
niekdejonge Sep 24, 2024
828fae8
Remove the need for convert_selected_pairs_matrix in tests
niekdejonge Sep 24, 2024
1648799
Remove unused import
niekdejonge Sep 24, 2024
c222380
Remove unused functions
niekdejonge Sep 24, 2024
77e7526
Improve progress bars.
niekdejonge Sep 24, 2024
bfe25e2
Add progress bar when loading in spectra
niekdejonge Sep 24, 2024
a67e5b0
Remove unused conversion to coo arrays and adjust tests that still re…
niekdejonge Sep 25, 2024
66aefb8
Rename SelectedInchikeyPairs to InchikeyPairGenerator
niekdejonge Sep 25, 2024
c404243
move inchikeyPairGenerator to data_generators.py
niekdejonge Sep 25, 2024
1aad8bc
Rename DataGeneratorPytorch to SpectrumPairGenerator
niekdejonge Sep 25, 2024
3b95b32
reformatting file
niekdejonge Sep 25, 2024
d20e23a
Reordering function order and adding docstrings
niekdejonge Sep 25, 2024
fbd8bcc
Add docstring to select_balanced_pairs
niekdejonge Sep 25, 2024
57252fb
Change output of select compound pairs wrapper to list of pairs inste…
niekdejonge Sep 25, 2024
7022b47
Fixing prospector warnings
niekdejonge Sep 25, 2024
3d3fe5b
Remove unreliable test.
niekdejonge Sep 25, 2024
4970a41
Fix sonarcloud issue
niekdejonge Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### fixed
- A bug of spectrum pair sampling during training was fixed. Due to this bug for each spectrum only one unique spectrum was sampled, even if multiple spectra were available. The bug was introduced with MS2Deepscore 2.0

### Changed
- The inchikey pair selection and data generator has been refactored. The new data generator results in a more balanced inchikey distribution. For details see [#232](https://github.com/matchms/ms2deepscore/pull/232)

### Changed
- dense layers are not build with leaky ReLU instead of ReLU [#222](https://github.com/matchms/ms2deepscore/pull/222).

Expand Down
53 changes: 49 additions & 4 deletions ms2deepscore/SettingsMS2Deepscore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import warnings
from collections import Counter
from datetime import datetime
from json import JSONEncoder
from typing import Optional
Expand Down Expand Up @@ -90,8 +91,11 @@ class SettingsMS2Deepscore:
Array of metadata entries (and their transformation) to be used in training.
See `MetadatFeatureGenerator` for more information.
Default is set to empty list.
max_pair_resampling
The maximum number a inchikey pair can be resampled. Resampling is done to balance inchikey pairs over
the tanimoto scores. The minimum is 1, meaning that no resampling is performed.
"""
def __init__(self, **settings):
def __init__(self, validate_settings=True, **settings):
# model structure
self.base_dims = (2000, 2000, 2000)
self.embedding_dim = 400
Expand Down Expand Up @@ -133,10 +137,12 @@ def __init__(self, **settings):
# Compound pairs selection settings
self.average_pairs_per_bin = 20
self.max_pairs_per_bin = 100
self.same_prob_bins = np.array([(x / 10, x / 10 + 0.1) for x in range(0, 10)])
self.same_prob_bins = np.array([(0.8, 0.9), (0.7, 0.8), (0.9, 1.0), (0.6, 0.7), (0.5, 0.6),
(0.4, 0.5), (0.3, 0.4), (0.2, 0.3), (0.1, 0.2), (-0.01, 0.1)])
self.include_diagonal = True
self.val_spectra_per_inchikey = 1
self.random_seed: Optional[int] = None
self.max_pair_resampling = 1

# Tanimioto score setings
self.fingerprint_type: str = "daylight"
Expand All @@ -158,9 +164,15 @@ def __init__(self, **settings):
f"the type given is {type(value)}, the value given is {value}")
setattr(self, key, value)
else:
raise ValueError(f"Unknown setting: {key}")
if validate_settings:
raise ValueError(f"Unknown setting: {key}")
# When loading an older model, there can be incompatibilities between training settings.
# If these settings were just used during training it should not break the loading of a model,
# since it does not affect how the model runs.
setattr(self, key, value)

self.validate_settings()
if validate_settings:
self.validate_settings()
if self.random_seed is not None:
np.random.seed(self.random_seed)

Expand All @@ -174,6 +186,7 @@ def validate_settings(self):
assert isinstance(self.random_seed, int), "Random seed must be integer number."
if self.loss_function.lower() not in LOSS_FUNCTIONS:
raise ValueError(f"Unknown loss function. Must be one of: {LOSS_FUNCTIONS.keys()}")
validate_bin_order(self.same_prob_bins)

def number_of_bins(self):
return int((self.max_mz - self.min_mz) / self.mz_bin_width)
Expand All @@ -192,6 +205,38 @@ def default(self, o):
json.dump(self.__dict__, file, indent=4, cls=NumpyArrayEncoder)


def validate_bin_order(score_bins):
"""Checks that the given bins are of the correct format

The bins should cover everything between 0 and 1.0 and the lowest bin should be below 0
(since pairs > are selected and we want to include zero)"""
# check that the correct same_prob_bins are selected
bin_borders_below_zero = 0
bin_borders_1 = 0
not_starting_or_ending_borders = []
for score_bin in score_bins:
if score_bin[0] > score_bin[1]:
raise ValueError("The first number in the bin should be smaller than the second")
for bin_border in score_bin:
if bin_border < 0:
bin_borders_below_zero += 1
elif bin_border == 1:
bin_borders_1 += 1
else:
not_starting_or_ending_borders.append(bin_border)
border_counts = Counter(not_starting_or_ending_borders)
if bin_borders_below_zero != 1:
raise ValueError(f"There should be one bin border with a value below 0. "
f"But {bin_borders_below_zero} bin borders with value below 0 are found")
if bin_borders_1 != 1:
raise ValueError(
f"There should be one bin border with value 1. "
f"But {bin_borders_below_zero} bin borders with value 1 are found")
for count in border_counts.values():
if count != 2:
raise ValueError("There is a gap in the bins, the bins should cover everything between 0 and 1.")


class SettingsEmbeddingEvaluator:
"""Contains all the settings used for training a EmbeddingEvaluator model.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm
from ms2deepscore import MS2DeepScore
from ms2deepscore.models import load_model
from ms2deepscore.train_new_model.spectrum_pair_selection import \
from ms2deepscore.train_new_model.inchikey_pair_selection import \
select_inchi_for_unique_inchikeys
from ms2deepscore.utils import save_pickled_file

Expand Down
2 changes: 1 addition & 1 deletion ms2deepscore/models/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def load_model(filename: Union[str, Path]) -> SiameseSpectralModel:
model_params = model_settings["model_params"]

# Instantiate the SiameseSpectralModel with the loaded parameters
model = SiameseSpectralModel(settings=SettingsMS2Deepscore(**model_params))
model = SiameseSpectralModel(settings=SettingsMS2Deepscore(**model_params, validate_settings=False))
model.load_state_dict(model_settings["model_state_dict"])
model.eval()
return model
Expand Down
25 changes: 11 additions & 14 deletions ms2deepscore/train_new_model/ValidationLossCalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,21 @@ def select_spectra_per_inchikey(spectra,
spectra_per_inchikey: int = 1):
"""Pick spectra_per_inchikey spectra for every unique inchikey14 (when possible).
"""
if spectra_per_inchikey < 1:
raise ValueError
inchikeys14_array = np.array([s.get("inchikey")[:14] for s in spectra])
unique_inchikeys = np.unique(inchikeys14_array)
rng = np.random.default_rng(seed=random_seed)
selected_spectra = []
for inchikey in unique_inchikeys:
matching_spectra_idx = np.where(inchikeys14_array == inchikey)[0]
if (spectra_per_inchikey > 1) & (spectra_per_inchikey <= len(matching_spectra_idx)):
spectrum_id = rng.choice(matching_spectra_idx, spectra_per_inchikey, replace=False)
selected_spectra.extend([spectra[i] for i in spectrum_id])
else:
spectrum_id = rng.choice(matching_spectra_idx)
selected_spectra.append(spectra[spectrum_id])
if len(matching_spectra_idx) == 0:
raise ValueError("Expected at least one spectrum per inchikey")
selected_spectrum_ids = []
for i in range(int(spectra_per_inchikey//len(matching_spectra_idx))):
selected_spectrum_ids.extend(list(matching_spectra_idx))
additional_spectrum_ids = rng.choice(matching_spectra_idx, spectra_per_inchikey%len(matching_spectra_idx),
replace=False)
selected_spectrum_ids.extend(additional_spectrum_ids)
selected_spectra.extend([spectra[i] for i in selected_spectrum_ids])
return selected_spectra


def select_one_spectrum_per_inchikey(spectra,
random_seed):
return select_spectra_per_inchikey(
spectra,
random_seed,
spectra_per_inchikey = 1)
9 changes: 4 additions & 5 deletions ms2deepscore/train_new_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .data_generators import DataGeneratorPytorch
from .spectrum_pair_selection import (SelectedCompoundPairs,
select_compound_pairs_wrapper)
from .data_generators import SpectrumPairGenerator, InchikeyPairGenerator
from .inchikey_pair_selection import (select_compound_pairs_wrapper)


__all__ = [
"DataGeneratorPytorch",
"SpectrumPairGenerator",
"select_compound_pairs_wrapper",
"SelectedCompoundPairs",
"InchikeyPairGenerator"
]
Loading
Loading