
Add tensorboard #209

Open · wants to merge 6 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ build
dist
.cache
__pycache__
runs

htmlcov
.coverage
2 changes: 2 additions & 0 deletions ms2deepscore/SettingsMS2Deepscore.py
@@ -104,6 +104,8 @@ def __init__(self, **settings):
self.patience = 30
self.loss_function = "mse"
self.weighting_factor = 0
self.use_tensorboard = True
self.log_dir = "runs"

# Folder names for storing
self.model_file_name = "ms2deepscore_model.pt"
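For context, the two new settings behave like any other keyword accepted by `SettingsMS2Deepscore`. A minimal sketch of overriding them, assuming the class is imported from `ms2deepscore.SettingsMS2Deepscore` as in this repository's layout (the experiment subfolder name is just an example):

```python
from ms2deepscore.SettingsMS2Deepscore import SettingsMS2Deepscore

# Sketch: override the TensorBoard-related defaults introduced in this PR.
settings = SettingsMS2Deepscore(
    use_tensorboard=True,           # default; set to False to skip SummaryWriter creation
    log_dir="runs/experiment_01",   # hypothetical subfolder; the default is "runs"
)
```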
55 changes: 44 additions & 11 deletions ms2deepscore/models/SiameseSpectralModel.py
@@ -1,7 +1,9 @@
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from ms2deepscore.__version__ import __version__
from ms2deepscore.models.helper_functions import (initialize_device,
@@ -147,6 +149,24 @@ def forward(self, spectra_tensors, metadata_tensors):
return x


def initialize_training(model, learning_rate, use_tensorboard, log_dir="runs"):
"""Initializes device (cpu or gpu) as well as the optimizer and Tensorboard writer.
"""
device = initialize_device()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

if use_tensorboard:
# TensorBoard writer
if not os.path.exists(log_dir):
os.makedirs(log_dir)

writer = SummaryWriter(log_dir)
else:
writer = None
return device, optimizer, writer


def train(model: SiameseSpectralModel,
data_generator,
num_epochs: int,
@@ -157,11 +177,13 @@ def train(model: SiameseSpectralModel,
checkpoint_filename: str = None,
loss_function="MSE",
weighting_factor=0,
monitor_rmse: bool = True,
collect_all_targets: bool = False,
lambda_l1: float = 0,
lambda_l2: float = 0,
progress_bar: bool = True):
progress_bar: bool = True,
use_tensorboard: bool = True,
log_dir: str = "runs",
):
"""Train a model with given parameters.

Parameters
@@ -186,8 +208,6 @@ def train(model: SiameseSpectralModel,
Pass a loss function (e.g. a pytorch default or a custom function).
weighting_factor
Default is set to 0, set to value between 0 and 1 to shift attention to higher target scores.
monitor_rmse
If True rmse will be monitored turing training.
collect_all_targets
If True, all training targets will be collected (e.g. for later statistics).
lambda_l1
@@ -196,15 +216,12 @@ def train(model: SiameseSpectralModel,
L2 regularization strength.
"""
# pylint: disable=too-many-arguments, too-many-locals
device = initialize_device()
model.to(device)
device, optimizer, writer = initialize_training(model, learning_rate, use_tensorboard, log_dir=log_dir)

if loss_function.lower() not in LOSS_FUNCTIONS:
raise ValueError(f"Unknown loss function. Must be one of: {LOSS_FUNCTIONS.keys()}")
criterion = LOSS_FUNCTIONS[loss_function.lower()]

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

history = {
"losses": [],
"val_losses": [],
@@ -237,8 +254,7 @@ def train(model: SiameseSpectralModel,
loss += l1_regularization(model, lambda_l1) + l2_regularization(model, lambda_l2)
batch_losses.append(float(loss))

if monitor_rmse:
batch_rmse.append(rmse_loss(outputs, targets.to(device)).cpu().detach().numpy())
batch_rmse.append(rmse_loss(outputs, targets).cpu().detach().numpy())

# Backward pass and optimize
loss.backward()
@@ -249,15 +265,29 @@ def train(model: SiameseSpectralModel,
loss=float(loss),
rmse=np.mean(batch_rmse),
)
# Monitor
avg_loss = np.mean(batch_losses)
avg_rmse = np.mean(batch_rmse)
if use_tensorboard:
writer.add_scalar('LOSS/train', avg_loss, epoch)
writer.add_scalar('RMSE/train', avg_rmse, epoch)
writer.flush()

history["losses"].append(np.mean(batch_losses))
history["rmse"].append(np.mean(batch_rmse))

if validation_loss_calculator is not None:
val_losses = validation_loss_calculator.compute_binned_validation_loss(model,
loss_types=(loss_function, "rmse"))
val_loss = val_losses[loss_function]

# Monitor
history["val_losses"].append(val_loss)
history["val_rmse"].append(val_losses["rmse"])
if use_tensorboard:
writer.add_scalar('LOSS/val', val_loss, epoch)
writer.add_scalar('RMSE/val', val_losses["rmse"], epoch)
writer.flush()
if val_loss < min_val_loss:
if checkpoint_filename:
print("Saving checkpoint model.")
@@ -271,9 +301,12 @@ def train(model: SiameseSpectralModel,
break

# Print statistics
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {np.mean(batch_losses):.4f}")
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
if validation_loss_calculator is not None:
print(f"Validation Loss: {val_loss:.4f} (RMSE: {val_losses['rmse']:.4f}).")

if use_tensorboard:
writer.close()
return history


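For readers less familiar with the writer lifecycle the new code follows (create once, `add_scalar` plus `flush` each epoch, `close` at the end), here is a standalone illustrative sketch; the tag name mirrors the one used above, while the loop and metric values are placeholders:

```python
import os

from torch.utils.tensorboard import SummaryWriter

log_dir = "runs"
os.makedirs(log_dir, exist_ok=True)  # the PR guards this with os.path.exists instead
writer = SummaryWriter(log_dir)

for epoch in range(3):
    train_loss = 1.0 / (epoch + 1)  # placeholder for the epoch's mean training loss
    writer.add_scalar("LOSS/train", train_loss, epoch)
    writer.flush()  # make new points visible to a running TensorBoard immediately

writer.close()
```

The resulting event files can then be inspected with `tensorboard --logdir runs`.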
5 changes: 4 additions & 1 deletion ms2deepscore/train_new_model/train_ms2deepscore.py
@@ -50,7 +50,10 @@ def train_ms2ds_model(
validation_loss_calculator=validation_loss_calculator,
patience=settings.patience,
loss_function=settings.loss_function,
checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0)
checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0,
use_tensorboard=settings.use_tensorboard,
log_dir=settings.log_dir,
)
# Save plot of history
plot_history(history["losses"], history["val_losses"], ms2ds_history_plot_file_name)

1 change: 1 addition & 0 deletions setup.py
@@ -33,6 +33,7 @@
"numpy>=1.20.3",
"pandas",
"scikit-learn",
"tensorboard",
"torch",
"tqdm",
"matplotlib==3.7.2"
3 changes: 2 additions & 1 deletion tests/test_training_wrapper_function.py
@@ -23,7 +23,8 @@ def test_train_wrapper_ms2ds_model(tmp_path):
"same_prob_bins": np.array([(0, 0.2), (0.2, 1.0)]),
"average_pairs_per_bin": 2,
"batch_size": 2, # to speed up tests --> usually larger
"random_seed": 42
"random_seed": 42,
"use_tensorboard": False
})

model_directory_name = train_ms2deepscore_wrapper(spectra_file_name, settings, validation_split_fraction=5)
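Disabling TensorBoard here keeps the fast unit test from writing event files into the working directory. If a test did want to exercise the logging path, one option (a sketch only, not part of this PR, reusing the imports and `tmp_path` fixture already present in this test module) would be to point `log_dir` at the temporary test folder:

```python
# Sketch only: exercise TensorBoard logging without polluting the repository.
settings = SettingsMS2Deepscore(**{
    "use_tensorboard": True,
    "log_dir": str(tmp_path / "runs"),  # event files land in the pytest tmp directory
})
```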