diff --git a/baal/experiments/transformers.py b/baal/experiments/transformers.py
index bd3edb5..41973ea 100644
--- a/baal/experiments/transformers.py
+++ b/baal/experiments/transformers.py
@@ -12,15 +12,14 @@ class TransformersAdapter(FrameworkAdapter):
     def __init__(self, wrapper: BaalTransformersTrainer):
         self.wrapper = wrapper
         self._init_weight = deepcopy(self.wrapper.model.state_dict())
-        self._init_scheduler = deepcopy(self.wrapper.lr_scheduler.state_dict())
-        self._init_optimizer = deepcopy(self.wrapper.optimizer.state_dict())
 
     def reset_weights(self):
         self.wrapper.model.load_state_dict(self._init_weight)
-        self.wrapper.lr_scheduler.load_state_dict(self._init_scheduler)
-        self.wrapper.optimizer.load_state_dict(self._init_optimizer)
+        self.wrapper.lr_scheduler = None
+        self.wrapper.optimizer = None
 
     def train(self, al_dataset: Dataset) -> Dict[str, float]:
+        self.wrapper.train_dataset = al_dataset
         return self.wrapper.train().metrics
 
     def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
diff --git a/tests/experiments/test_adapter.py b/tests/experiments/test_adapter.py
index f71a663..fd1a133 100644
--- a/tests/experiments/test_adapter.py
+++ b/tests/experiments/test_adapter.py
@@ -5,13 +5,29 @@ from torch.nn import MSELoss
 from torch.optim import SGD
 from torch.utils.data import Dataset
 
+from transformers import pipeline, TrainingArguments
+
 from baal import ModelWrapper, ActiveLearningDataset
 from baal.active.heuristics import BALD
 from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
 from baal.experiments.base import ActiveLearningExperiment
 from baal.experiments.modelwrapper import ModelWrapperAdapter
+from baal.experiments.transformers import TransformersAdapter
 from baal.modelwrapper import TrainingArgs
+from baal.transformers_trainer_wrapper import BaalTransformersTrainer
+
+
+class MockHFDataset:
+    def __init__(self):
+        self.x = torch.randint(0, 5, (10,))
+        self.y = torch.randn((2,))
+        self.length = 5
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        return {"input_ids": self.x, "labels": self.y}
 
 
 class MockDataset(Dataset):
@@ -63,6 +79,17 @@ def model_wrapper():
     )
 
 
+@pytest.fixture
+def hf_trainer():
+    text_classifier = pipeline(
+        task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
+    )
+    return TransformersAdapter(BaalTransformersTrainer(
+        model=text_classifier.model,
+        args=TrainingArguments('/tmp', no_cuda=True)
+    ))
+
+
 def test_adapter_reset(model_wrapper):
     dataset = MockDataset()
     before_params = list(
@@ -77,6 +104,7 @@ def test_adapter_reset(model_wrapper):
         map(lambda x: x.clone(), model_wrapper.wrapper.model.parameters()))
     assert all([np.allclose(i.detach(), j.detach()) for i, j in zip(before_params, reset_params)])
 
+    model_wrapper.train(dataset)
     result = model_wrapper.evaluate(dataset, average_predictions=1)
     assert 'test_loss' in result
 
@@ -85,6 +113,29 @@ def test_adapter_reset(model_wrapper):
     assert predictions.shape == (len(dataset), 1, 10)
 
 
+def test_hf_wrapper(hf_trainer):
+    dataset = MockHFDataset()
+    before_params = list(
+        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
+    hf_trainer.train(dataset)
+    after_params = list(
+        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
+    assert not all([np.array_equal(i.detach(), j.detach())
+                    for i, j in zip(before_params, after_params)])
+    hf_trainer.reset_weights()
+    reset_params = list(
+        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
+    assert all([np.allclose(i.detach(), j.detach())
+                for i, j in zip(before_params, reset_params)])
+    hf_trainer.train(dataset)
+
+    result = hf_trainer.evaluate(dataset, average_predictions=1)
+    assert 'eval_loss' in result
+
+    predictions = hf_trainer.predict(dataset, iterations=10)
+    assert predictions.shape == (len(dataset), 2, 10)
+
+
 def test_experiment(model_wrapper):
     al_dataset = ActiveLearningDataset(MockDataset())
     al_dataset.label_randomly(10)