Skip to content

Commit

Permalink
Fix model selection: Behavior was not consistent with interface (#21)
Browse files Browse the repository at this point in the history
Signed-off-by: Christopher Schröder <chschroeder@users.noreply.github.com>
  • Loading branch information
chschroeder committed Oct 8, 2022
1 parent 5f507e7 commit 67ae83a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
3 changes: 0 additions & 3 deletions small_text/integrations/pytorch/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ def _save_model(self, optimizer, model_selection, model_id, train_acc, train_los

def _perform_model_selection(self, optimizer, model_selection):
model_selection_result = model_selection.select()
# TODO: can we test the load_state_dict() calls here?
if model_selection_result is not None:
# this currently does not check if this model is the last one (and thus does not need
# to be reloaded
self.model.load_state_dict(torch.load(model_selection_result.model_path))
optimizer_path = model_selection_result.model_path.with_suffix('.pt.optimizer')
optimizer.load_state_dict(torch.load(optimizer_path))
Expand Down
7 changes: 4 additions & 3 deletions small_text/training/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def select(self, select_by=None):
else:
select_by = self.default_select_by

# valid rows are rows where no early stopping has been triggered
valid_rows = np.not_equal(self.models[ModelSelection.FIELD_NAME_EARLY_STOPPING], True)
if not np.any(valid_rows): # checks if we have no valid rows
return None

rows = self.models[valid_rows]

metrics_dict = {metric.name: metric for metric in self.metrics}
Expand All @@ -228,9 +232,6 @@ def select(self, select_by=None):
)
indices = np.lexsort(tuples)

if indices.shape[0] == 0:
return ModelSelectionResult(0, self.last_model_id, None, {})

model_id = rows['model_id'][indices[0]]
model_path = rows['model_path'][indices[0]]

Expand Down
42 changes: 42 additions & 0 deletions tests/unit/small_text/training/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,33 @@ def test_add_model(self):
self.assertEqual((1,), model_selection.models.shape)
self.assertEqual(model_id, model_selection.last_model_id)

def test_select_with_single_model(self):
model_selection = ModelSelection()
measured_values_list = [
{'val_loss': 0.043, 'val_acc': 0.78, 'train_loss': 0.023, 'train_acc': 0.85}
]
for i, measured_values in enumerate(measured_values_list):
model_selection.add_model(str(i+1), f'/any/path/to/model_{i+1}.bin',
measured_values)

model_selection_result = model_selection.select(select_by='val_loss')

self.assertEqual('1', model_selection_result.model_id)
self.assertEqual('/any/path/to/model_1.bin', model_selection_result.model_path)

self.assertEqual(4, len(model_selection_result.measured_values))
for key, val in measured_values_list[0].items():
self.assertEqual(val, model_selection_result.measured_values[key])

self.assertEqual(1, len(model_selection_result.fields))
self.assertFalse(model_selection_result.fields[ModelSelection.FIELD_NAME_EARLY_STOPPING])

def test_select_without_model(self):
model_selection = ModelSelection()

model_selection_result = model_selection.select(select_by='val_loss')
self.assertIsNone(model_selection_result)

def test_add_model_missing_metrics(self):
model_selection = ModelSelection()
# val_loss is missing
Expand Down Expand Up @@ -240,6 +267,21 @@ def test_select_with_early_stopping(self):
self.assertEqual(1, len(model_selection_result.fields))
self.assertFalse(model_selection_result.fields[ModelSelection.FIELD_NAME_EARLY_STOPPING])

def test_select_with_early_stopping_and_single_model(self):
model_selection = ModelSelection()
measured_values_list = [
{'val_loss': 0.043, 'val_acc': 0.78, 'train_loss': 0.023, 'train_acc': 0.85}
]
fields = [
{ModelSelection.FIELD_NAME_EARLY_STOPPING: True}
]
for i, measured_values in enumerate(measured_values_list):
model_selection.add_model(str(i+1), f'/any/path/to/model_{i+1}.bin',
measured_values, fields=fields[i])

model_selection_result = model_selection.select(select_by='val_loss')
self.assertIsNone(model_selection_result)

def test_select_with_early_stopping_select_by_second_metric(self):
model_selection = ModelSelection()
measured_values_list = [
Expand Down

0 comments on commit 67ae83a

Please sign in to comment.