Skip to content

Commit

Permalink
Merge branch 'v1.3.x' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
chschroeder committed May 4, 2024
2 parents 893dae6 + 668c326 commit 733f455
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ On the other hand, this also allowed us to deal with further issues that contain

---

## Version 1.4.0 - unreleased

### Fixed

- Changed how the seed is controlled in `SetFitClassification`, since the seed was previously fixed unless explicitly set via the respective trainer keyword argument.

### Changed

- Documentation: Updated showcase section.

---

## Version 1.3.3 - 2023-12-29

### Changed
Expand Down
2 changes: 2 additions & 0 deletions small_text/integrations/transformers/classifiers/setfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,14 @@ def _get_train_and_valid_sets(self, x_train, y_train, x_valid, y_valid):
return sub_train, sub_valid

def _fit(self, sub_train, sub_valid, setfit_train_kwargs):
seed = np.random.randint(2**32-1)
trainer = SetFitTrainer(
self.model,
sub_train,
eval_dataset=sub_valid,
batch_size=self.mini_batch_size,
use_amp=self.amp_args.use_amp, # TODO: device_type and dtype are not used
seed=seed,
**self.trainer_kwargs
)
if not 'show_progress_bar' in setfit_train_kwargs:
Expand Down
6 changes: 6 additions & 0 deletions small_text/integrations/transformers/utils/setfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ def _check_trainer_kwargs(trainer_kwargs):
raise ValueError('Invalid keyword argument in trainer_kwargs: '
'Argument "batch_size" can be set via "mini_batch_size" in '
'SetFitClassification.')

if 'show_progress_bar' in trainer_kwargs:
raise ValueError('Invalid keyword argument in trainer_kwargs: '
'Argument "show_progress_bar" can be controlled via "verbosity" in '
'SetFitClassification.')

if 'seed' in trainer_kwargs:
raise ValueError('Invalid keyword argument in trainer_kwargs: '
'Argument "seed" cannot be set via train_kwargs.')

return trainer_kwargs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,26 @@ def test_fit_with_non_default_settings(self):
self.assertEqual(1, train_mock.call_count)
self.assertEqual(max_seq_len, train_mock.call_args_list[0].kwargs['max_length'])

def test_fit_prevent_fixed_seed(self):
    """Two consecutive fit() calls must pass different seeds to SetFit.

    Patches setfit's ``set_seed`` and verifies that it is invoked once per
    fit and that the seed argument differs between the two calls, i.e. the
    seed is not silently fixed to a constant.
    """
    num_classes = 5
    dataset = twenty_news_text(10, num_classes=self.num_classes, multi_label=self.multi_label)

    model_args = SetFitModelArguments('sentence-transformers/all-MiniLM-L6-v2')
    train_kwargs = {'show_progress_bar': False}

    with patch('setfit.trainer.set_seed') as set_seed_mock:
        clf = SetFitClassification(model_args, num_classes, multi_label=self.multi_label)

        clf.fit(dataset, setfit_train_kwargs=train_kwargs)
        self.assertEqual(1, set_seed_mock.call_count)
        # positional args of the first set_seed() call
        seed_args_first = set_seed_mock.call_args_list[0][0]

        clf.fit(dataset, setfit_train_kwargs=train_kwargs)
        self.assertEqual(2, set_seed_mock.call_count)
        seed_args_second = set_seed_mock.call_args_list[1][0]

        # a repeated seed here would indicate a fixed (non-random) seed
        self.assertNotEqual(seed_args_first, seed_args_second)


@pytest.mark.pytorch
@pytest.mark.optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ def test_init_with_misplaced_show_progress_bar_kwargs(self):
with self.assertRaisesRegex(ValueError, 'Invalid keyword argument in trainer_kwargs'):
SetFitClassification(setfit_model_args, num_classes, trainer_kwargs=trainer_kwargs)

def test_init_with_misplaced_seed_kwargs(self):
    """Constructing SetFitClassification with 'seed' in trainer_kwargs must raise.

    The seed is managed internally, so passing it through trainer_kwargs is
    rejected with a ValueError at construction time.
    """
    num_classes = 5
    model_args = SetFitModelArguments('sentence-transformers/all-MiniLM-L6-v2')
    kwargs_with_seed = {'seed': 4242}

    with self.assertRaisesRegex(ValueError, 'Invalid keyword argument in trainer_kwargs'):
        SetFitClassification(model_args, num_classes, trainer_kwargs=kwargs_with_seed)


class _SetFitClassification(object):

Expand Down

0 comments on commit 733f455

Please sign in to comment.