Add tests to HF
Dref360 committed May 27, 2024
1 parent 9f791f6 commit d5b44b0
Showing 2 changed files with 54 additions and 4 deletions.
7 changes: 3 additions & 4 deletions baal/experiments/transformers.py
@@ -12,15 +12,14 @@ class TransformersAdapter(FrameworkAdapter):
     def __init__(self, wrapper: BaalTransformersTrainer):
         self.wrapper = wrapper
         self._init_weight = deepcopy(self.wrapper.model.state_dict())
-        self._init_scheduler = deepcopy(self.wrapper.lr_scheduler.state_dict())
-        self._init_optimizer = deepcopy(self.wrapper.optimizer.state_dict())

     def reset_weights(self):
         self.wrapper.model.load_state_dict(self._init_weight)
-        self.wrapper.lr_scheduler.load_state_dict(self._init_scheduler)
-        self.wrapper.optimizer.load_state_dict(self._init_optimizer)
+        self.wrapper.lr_scheduler = None
+        self.wrapper.optimizer = None

     def train(self, al_dataset: Dataset) -> Dict[str, float]:
         self.wrapper.train_dataset = al_dataset
         return self.wrapper.train().metrics

     def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
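The net effect of this change: reset_weights restores the model from the saved state dict and simply drops the optimizer and LR scheduler, so transformers' Trainer rebuilds fresh ones on the next train() call instead of reloading stale state. A minimal sketch of the resulting reset-and-retrain cycle, using the same tiny checkpoint as the tests below; train_dataset is a placeholder for any HF-style dataset of {"input_ids", "labels"} items:

    from transformers import TrainingArguments, pipeline

    from baal.experiments.transformers import TransformersAdapter
    from baal.transformers_trainer_wrapper import BaalTransformersTrainer

    # Tiny checkpoint so the sketch runs quickly; any sequence-classification model works.
    classifier = pipeline(
        task="text-classification",
        model="hf-internal-testing/tiny-random-distilbert",
        framework="pt",
    )
    adapter = TransformersAdapter(
        BaalTransformersTrainer(
            model=classifier.model,
            args=TrainingArguments("/tmp", no_cuda=True),
        )
    )

    # train_dataset is assumed: any dataset yielding {"input_ids": ..., "labels": ...}
    # items (see MockHFDataset in the test file below).
    adapter.train(train_dataset)    # first fit: Trainer builds an optimizer and scheduler
    adapter.reset_weights()         # restore initial weights, drop optimizer/scheduler
    adapter.train(train_dataset)    # second fit starts from the initial weights again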
51 changes: 51 additions & 0 deletions tests/experiments/test_adapter.py
@@ -5,13 +5,29 @@
from torch.nn import MSELoss
from torch.optim import SGD
from torch.utils.data import Dataset
from transformers import pipeline, TrainingArguments

from baal import ModelWrapper, ActiveLearningDataset
from baal.active.heuristics import BALD
from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
from baal.experiments.base import ActiveLearningExperiment
from baal.experiments.modelwrapper import ModelWrapperAdapter
from baal.experiments.transformers import TransformersAdapter
from baal.modelwrapper import TrainingArgs
from baal.transformers_trainer_wrapper import BaalTransformersTrainer


class MockHFDataset:
    def __init__(self):
        self.x = torch.randint(0, 5, (10,))
        self.y = torch.randn((2,))
        self.length = 5

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.y}


class MockDataset(Dataset):
@@ -63,6 +79,17 @@ def model_wrapper():
)


@pytest.fixture
def hf_trainer():
    text_classifier = pipeline(
        task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
    )
    return TransformersAdapter(BaalTransformersTrainer(
        model=text_classifier.model,
        args=TrainingArguments('/tmp', no_cuda=True)
    ))


def test_adapter_reset(model_wrapper):
    dataset = MockDataset()
    before_params = list(
@@ -77,6 +104,7 @@ def test_adapter_reset(model_wrapper):
        map(lambda x: x.clone(), model_wrapper.wrapper.model.parameters()))
    assert all([np.allclose(i.detach(), j.detach())
                for i, j in zip(before_params, reset_params)])
    model_wrapper.train(dataset)

    result = model_wrapper.evaluate(dataset, average_predictions=1)
    assert 'test_loss' in result
@@ -85,6 +113,29 @@
    assert predictions.shape == (len(dataset), 1, 10)


def test_hf_wrapper(hf_trainer):
    dataset = MockHFDataset()
    before_params = list(
        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
    hf_trainer.train(dataset)
    after_params = list(
        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
    assert not all([np.array_equal(i.detach(), j.detach())
                    for i, j in zip(before_params, after_params)])
    hf_trainer.reset_weights()
    reset_params = list(
        map(lambda x: x.clone(), hf_trainer.wrapper.model.parameters()))
    assert all([np.allclose(i.detach(), j.detach())
                for i, j in zip(before_params, reset_params)])
    hf_trainer.train(dataset)

    result = hf_trainer.evaluate(dataset, average_predictions=1)
    assert 'eval_loss' in result

    predictions = hf_trainer.predict(dataset, iterations=10)
    assert predictions.shape == (len(dataset), 2, 10)


def test_experiment(model_wrapper):
    al_dataset = ActiveLearningDataset(MockDataset())
    al_dataset.label_randomly(10)
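The new test drives the adapter the same way an acquisition loop would: predict() returns probabilities stacked as (n_samples, n_classes, iterations), which is the layout baal heuristics such as BALD expect. A hand-rolled sketch of that loop, using only the methods exercised above; adapter and full_dataset are assumptions (see the earlier sketch), and ActiveLearningExperiment automates roughly this sequence:

    from baal import ActiveLearningDataset
    from baal.active.heuristics import BALD

    # full_dataset is assumed: any labelled HF-style dataset; adapter is the
    # TransformersAdapter built in the sketch above.
    al_dataset = ActiveLearningDataset(full_dataset)
    al_dataset.label_randomly(10)          # seed the labelled pool
    heuristic = BALD()

    for _ in range(3):                     # three acquisition steps
        adapter.reset_weights()            # retrain from the initial weights each step
        adapter.train(al_dataset)
        probs = adapter.predict(al_dataset.pool, iterations=10)
        most_uncertain = heuristic(probs)[:5]   # rank pool items, keep the top 5
        al_dataset.label(most_uncertain)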
