Skip to content

Commit

Permalink
- minor fix in naming class variables of callback "NES.utils.BestWeig…
Browse files Browse the repository at this point in the history
…hts"
  • Loading branch information
sgrubas committed Dec 30, 2023
1 parent 3332746 commit 88073af
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions NES/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,34 +290,35 @@ def get_monitor_value(self, logs):
return monitor_value


class NES_BestWeights(tf.keras.callbacks.Callback):
class BestWeights(tf.keras.callbacks.Callback):
"""Keeps only the best weights of the model.
Args:
NES: NES instance
nn_model: neural network model which weights will be watched
monitor: Quantity to be monitored ('loss' or 'val_loss'). By default, 'loss'
freq: how often save weights
verbose: verbosity mode (0 or 1).
conversion: conversion function from "loss" to "RMAE"
"""

def __init__(self,
NES,
nn_model,
monitor='loss',
freq=50,
verbose=1,
conversion=lambda x: x * 10**(-0.16),

):
super(NES_BestWeights, self).__init__()
super(BestWeights, self).__init__()
assert monitor == 'loss' or monitor == 'val_loss', \
"Only 'loss' and 'val_loss' are supported for monitor metric"

self.model = NES
self.model = nn_model
self.monitor = monitor
self.verbose = verbose
self.freq = freq
self.best_monitor = np.inf
self.best_rmae = np.inf
self.best_epoch = None
self.best_weights = None
self.conversion = conversion
Expand Down

0 comments on commit 88073af

Please sign in to comment.