
Unify initialize() methods of pytorch classifiers (#57)
Signed-off-by: Christopher Schröder <chschroeder@users.noreply.github.com>
chschroeder committed May 12, 2024
1 parent 733f455 commit 80bc1e2
Showing 7 changed files with 24 additions and 23 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -37,7 +37,8 @@ On the other hand, this also allowed us to deal with further issues that contain
 
 - General
   - Moved `split_data()` method from `small_text.data.datasets` to `small_text.data.splits`.
-
+- Classification:
+  - The `initialize()` methods of all PyTorch-classifiers (KimCNN, TransformerBasedClassification, SetFitClassification) are now more unified. ([#57](https://github.com/webis-de/small-text/issues/57))
 - Utils
   - `init_kmeans_plusplus_safe()` now supports weighted kmeans++ initialization for `scikit-learn>=1.3.0`.
 
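As an illustration of the changelog entry above (not part of the commit): after this change, KimCNN, TransformerBasedClassification, and SetFitClassification all expose an argument-free `initialize()` that builds the underlying model, assigns it to `self.model`, and returns it. A minimal sketch of code that relies only on that shared contract:

```python
# Sketch of the unified initialize() contract from #57 (illustrative helper, not part of the commit).
def warm_up(classifier):
    """Build the classifier's model eagerly instead of waiting for fit() to do it lazily."""
    model = classifier.initialize()    # same call for KimCNN-, transformer- and SetFit-based classifiers
    assert model is classifier.model   # initialize() stores the model on the instance and returns it
    return model
```

Before this commit, such a helper would have needed per-class calls: `initialize_kimcnn_model()` for KimCNN, `initialize_transformer(cache_dir)` for TransformerBasedClassification, and SetFit's `initialize()`, which returned a model without assigning `self.model`.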
5 changes: 3 additions & 2 deletions small_text/integrations/pytorch/classifiers/kimcnn.py
@@ -352,7 +352,7 @@ def fit(self, train_set, validation_set=None, weights=None, early_stopping=None,
     def _fit_main(self, sub_train, sub_valid, weights, early_stopping, model_selection,
                   optimizer, scheduler):
         if self.model is None:
-            self.initialize_kimcnn_model()
+            self.initialize()
 
         _check_optimizer_and_scheduler_config(optimizer, scheduler)
         scheduler = scheduler if scheduler is not None else None
@@ -370,7 +370,7 @@ def _fit_main(self, sub_train, sub_valid, weights, early_stopping, model_selection,
 
         return self
 
-    def initialize_kimcnn_model(self):
+    def initialize(self):
         vocab_size = self.embedding_matrix.shape[0]
         embed_dim = self.embedding_matrix.shape[1]
         self.model = KimCNN(vocab_size, self.max_seq_len, num_classes=self.num_classes,
@@ -381,6 +381,7 @@ def initialize_kimcnn_model(self):
                             kernel_heights=self.kernel_heights)
 
         self.model = _compile_if_possible(self.model, compile_model=self.compile_model)
+        return self.model
 
     def _default_optimizer(self, base_lr):
         params = [param for param in self.model.parameters() if param.requires_grad]
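A short usage sketch for the renamed method above (not part of the diff; it assumes the `KimCNNClassifier` class from this module and its usual keyword arguments, with toy values standing in for real pretrained embeddings):

```python
# Illustrative sketch: a small random embedding matrix replaces real pretrained embeddings.
import torch

from small_text.integrations.pytorch.classifiers.kimcnn import KimCNNClassifier

embedding_matrix = torch.rand(5000, 100)  # vocab_size x embed_dim, as read by initialize()
clf = KimCNNClassifier(num_classes=2, embedding_matrix=embedding_matrix)

model = clf.initialize()   # previously clf.initialize_kimcnn_model(); now also returns the model
assert model is clf.model  # fit() still calls initialize() lazily when clf.model is None
```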
8 changes: 4 additions & 4 deletions small_text/integrations/transformers/classifiers/classification.py
@@ -387,7 +387,7 @@ def fit(self, train_set, validation_set=None, weights=None, early_stopping=None,
     def _fit_main(self, sub_train, sub_valid, weights, early_stopping, model_selection,
                   optimizer, scheduler):
         if self.model is None:
-            self.initialize_transformer(self.cache_dir)
+            self.initialize()
 
         _check_optimizer_and_scheduler_config(optimizer, scheduler)
         scheduler = scheduler if scheduler is not None else 'linear'
@@ -405,14 +405,14 @@ def _fit_main(self, sub_train, sub_valid, weights, early_stopping, model_selection,
 
         return self
 
-    def initialize_transformer(self, cache_dir):
-
+    def initialize(self):
         self.config, self.tokenizer, self.model = _initialize_transformer_components(
             self.transformer_model,
             self.num_classes,
-            cache_dir,
+            self.cache_dir,
         )
         self.model = _compile_if_possible(self.model, self.transformer_model.compile_model)
+        return self.model
 
     def _default_optimizer(self, base_lr):
 
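The practical effect of the change above, sketched for callers (not part of the diff): `initialize()` no longer takes a `cache_dir` argument but reads `self.cache_dir`. Import paths and constructor arguments below are assumptions based on the usual small-text API, and loading `bert-base-uncased` requires a network connection or a local model cache:

```python
# Illustrative sketch, assuming small-text's transformers integration is installed.
from small_text.integrations.transformers import (
    TransformerBasedClassification,
    TransformerModelArguments,
)

clf = TransformerBasedClassification(TransformerModelArguments('bert-base-uncased'), num_classes=2)

model = clf.initialize()   # previously clf.initialize_transformer(clf.cache_dir)
assert model is clf.model  # initialize() now also returns the model
```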
7 changes: 3 additions & 4 deletions small_text/integrations/transformers/classifiers/setfit.py
@@ -295,23 +295,22 @@ def _fit(self, sub_train, sub_valid, setfit_train_kwargs):
         return self
 
     def initialize(self):
-        # TODO: make sure the initialize() methods of all classifiers are similar
         from_pretrained_options = _get_arguments_for_from_pretrained_model(
             self.setfit_model_args.model_loading_strategy
         )
         model_kwargs = self.model_kwargs.copy()
         if self.multi_label and 'multi_target_strategy' not in model_kwargs:
             model_kwargs['multi_target_strategy'] = 'one-vs-rest'
 
-        model = SetFitModel.from_pretrained(
+        self.model = SetFitModel.from_pretrained(
             self.setfit_model_args.sentence_transformer_model,
             use_differentiable_head=self.use_differentiable_head,
             force_download=from_pretrained_options.force_download,
             local_files_only=from_pretrained_options.local_files_only,
             **model_kwargs
         )
-        model.model_body = _compile_if_possible(model.model_body, compile_model=self.setfit_model_args.compile_model)
-        return model
+        self.model.model_body = _compile_if_possible(self.model.model_body, compile_model=self.setfit_model_args.compile_model)
+        return self.model
 
     def validate(self, _validation_set):
         if self.use_differentiable_head:
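For the SetFit change above, a sketch of what callers can now rely on (not part of the diff): `initialize()` assigns the loaded `SetFitModel` to `self.model` instead of only returning a local object. Class names and arguments are assumed from the usual small-text SetFit integration, and loading the sentence-transformers model needs a network connection or a local cache:

```python
# Illustrative sketch, assuming small-text's SetFit integration is installed.
from small_text.integrations.transformers import SetFitClassification, SetFitModelArguments

model_args = SetFitModelArguments('sentence-transformers/paraphrase-MiniLM-L3-v2')
clf = SetFitClassification(model_args, num_classes=2)

model = clf.initialize()   # loads the SetFitModel and compiles its model body if configured
assert model is clf.model  # previously the model was returned but clf.model stayed unset
```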
4 changes: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ def _test_with_no_amp_args_configured(self, clf):
         self.assertEqual('cpu', clf.amp_args.device_type)
         self.assertEqual(torch.bfloat16, clf.amp_args.dtype)
 
-        clf.initialize_kimcnn_model()
+        clf.initialize()
         amp_args = clf.amp_args
         self.assertIsNotNone(amp_args)
         self.assertFalse(amp_args.use_amp)
@@ -78,7 +78,7 @@ def _test_with_amp_args_configured(self, clf):
         self.assertEqual('cuda', clf.amp_args.device_type)
         self.assertEqual(torch.float16, clf.amp_args.dtype)
 
-        clf.initialize_kimcnn_model()
+        clf.initialize()
         amp_args = clf.amp_args
         self.assertIsNotNone(amp_args)
         self.assertFalse(amp_args.use_amp)
8 changes: 4 additions & 4 deletions
@@ -328,7 +328,7 @@ def test_initialize_with_pytorch_geq_v2_and_compile_enabled(self):
 
         with patch('torch.__version__', new='2.0.0'), \
                 patch('torch.compile', wraps=torch.compile) as compile_spy:
-            classifier.initialize_kimcnn_model()
+            classifier.initialize()
             compile_spy.assert_called()
 
     def test_initialize_with_pytorch_geq_v2_and_compile_disabled(self):
@@ -342,7 +342,7 @@ def test_initialize_with_pytorch_geq_v2_and_compile_disabled(self):
 
         with patch('torch.__version__', new='2.0.0'), \
                 patch('torch.compile', wraps=torch.compile) as compile_spy:
-            classifier.initialize_kimcnn_model()
+            classifier.initialize()
             compile_spy.assert_not_called()
 
     def test_initialize_with_pytorch_lesser_v2_and_compile_enabled(self):
@@ -356,7 +356,7 @@ def test_initialize_with_pytorch_lesser_v2_and_compile_enabled(self):
         with patch.object(torch, 'compile', return_value=None):
             with patch('torch.__version__', new='1.9.0'), \
                     patch('torch.compile', wraps=torch.compile) as compile_spy:
-                classifier.initialize_kimcnn_model()
+                classifier.initialize()
                 compile_spy.assert_not_called()
 
     def test_initialize_with_pytorch_lesser_v2_and_compile_disabled(self):
@@ -370,7 +370,7 @@ def test_initialize_with_pytorch_lesser_v2_and_compile_disabled(self):
         with patch.object(torch, 'compile', return_value=None):
             with patch('torch.__version__', new='2.0.0'), \
                     patch('torch.compile', wraps=torch.compile) as compile_spy:
-                classifier.initialize_kimcnn_model()
+                classifier.initialize()
                 compile_spy.assert_not_called()


12 changes: 6 additions & 6 deletions
@@ -415,7 +415,7 @@ def test_initialize_with_pytorch_geq_v2_and_compile_enabled(self):
 
         with patch('torch.__version__', new='2.0.0'), \
                 patch('torch.compile', wraps=torch.compile) as compile_spy:
-            classifier.initialize_transformer(classifier.cache_dir)
+            classifier.initialize()
             compile_spy.assert_called()
 
     def test_initialize_with_pytorch_geq_v2_and_compile_disabled(self):
@@ -429,7 +429,7 @@ def test_initialize_with_pytorch_geq_v2_and_compile_disabled(self):
 
         with patch('torch.__version__', new='2.0.0'), \
                 patch('torch.compile', wraps=torch.compile) as compile_spy:
-            classifier.initialize_transformer(classifier.cache_dir)
+            classifier.initialize()
             compile_spy.assert_not_called()
 
     def test_initialize_with_pytorch_lesser_v2_and_compile_enabled(self):
@@ -441,7 +441,7 @@ def test_initialize_with_pytorch_lesser_v2_and_compile_enabled(self):
         with patch.object(torch, 'compile', return_value=None):
             with patch('torch.__version__', new='1.9.0'), \
                     patch('torch.compile', wraps=torch.compile) as compile_spy:
-                classifier.initialize_transformer(classifier.cache_dir)
+                classifier.initialize()
                 compile_spy.assert_not_called()
 
     def test_initialize_with_pytorch_lesser_v2_and_compile_disabled(self):
@@ -453,7 +453,7 @@ def test_initialize_with_pytorch_lesser_v2_and_compile_disabled(self):
         with patch.object(torch, 'compile', return_value=None):
             with patch('torch.__version__', new='1.9.0'), \
                     patch('torch.compile', wraps=torch.compile) as compile_spy:
-                classifier.initialize_transformer(classifier.cache_dir)
+                classifier.initialize()
                 compile_spy.assert_not_called()


@@ -470,7 +470,7 @@ def test_with_no_amp_args_configured(self):
         self.assertEqual('cpu', clf.amp_args.device_type)
         self.assertEqual(torch.bfloat16, clf.amp_args.dtype)
 
-        clf.initialize_transformer(clf.cache_dir)
+        clf.initialize()
         amp_args = clf.amp_args
         self.assertIsNotNone(amp_args)
         self.assertFalse(amp_args.use_amp)
@@ -495,7 +495,7 @@ def test_with_amp_args_configured(self):
         self.assertEqual('cuda', clf.amp_args.device_type)
         self.assertEqual(torch.float16, clf.amp_args.dtype)
 
-        clf.initialize_transformer(clf.cache_dir)
+        clf.initialize()
         amp_args = clf.amp_args
         self.assertIsNotNone(amp_args)
         self.assertFalse(amp_args.use_amp)

