Skip to content

Commit

Permalink
Finish Adapter work, start test and editing experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed May 26, 2024
1 parent d99019f commit 9f791f6
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 245 deletions.
4 changes: 3 additions & 1 deletion baal/active/heuristics/heuristics_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def predict_on_dataset(
half=False,
verbose=True,
):
return super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])
return (
super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])
) # type: ignore

def predict_on_batch(self, data, iterations=1):
"""Rank the predictions according to their uncertainties."""
Expand Down
9 changes: 4 additions & 5 deletions baal/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import abc
from typing import Dict
from typing import Dict, List, Union

from numpy._typing import NDArray

from baal.active.dataset.base import Dataset


class FrameworkAdapter(abc.ABC):
    """Abstract bridge between a training framework and the active-learning loop.

    Concrete subclasses (e.g. a ModelWrapper- or transformers-based adapter)
    implement these four operations so the experiment driver can stay
    framework-agnostic.
    """

    def reset_weights(self):
        """Restore the underlying model to its initial (pre-training) state."""
        raise NotImplementedError

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        """Fit the model on the labelled portion of `al_dataset`.

        Returns:
            Training metrics keyed by name.
        """
        raise NotImplementedError

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        """Run `iterations` stochastic forward passes over `dataset`.

        Returns:
            A single array, or a list of arrays for multi-output models.
        """
        raise NotImplementedError

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        """Evaluate on `dataset`, averaging `average_predictions` predictions.

        Returns:
            Evaluation metrics keyed by name.
        """
        raise NotImplementedError
62 changes: 46 additions & 16 deletions baal/experiments/base.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,77 @@
import itertools
from typing import Union, Optional
from typing import Union, Optional, TYPE_CHECKING

import pandas as pd
import structlog
from tqdm import tqdm
from transformers import Trainer

from baal import ModelWrapper, ActiveLearningDataset
from baal.active.dataset.base import Dataset
from baal.active.heuristics import AbstractHeuristic
from baal.active.stopping_criteria import StoppingCriterion, LabellingBudgetStoppingCriterion
from baal.experiments import FrameworkAdapter
from baal.experiments.modelwrapper import ModelWrapperAdapter

# Optional HuggingFace integration: the Baal adapter classes themselves import
# `transformers`, so they must be imported in the SUCCESS branch of the guard.
# (Importing them in the `except ImportError` branch would re-raise the very
# error being handled, and leave the names undefined when transformers IS
# installed, crashing `_get_adapter` with a NameError.)
try:
    import transformers  # noqa: F401  # probe for the optional dependency

    from baal.transformers_trainer_wrapper import BaalTransformersTrainer
    from baal.experiments.transformers import TransformersAdapter

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

log = structlog.get_logger(__name__)


class ActiveLearningExperiment:
    """Drive a complete active-learning experiment.

    Repeatedly trains on the labelled set, evaluates, ranks the unlabelled
    pool with `heuristic`, and labels the top `query_size` items until
    `criterion` says to stop.
    """

    def __init__(
        self,
        trainer: Union[ModelWrapper, "BaalTransformersTrainer"],
        al_dataset: ActiveLearningDataset,
        eval_dataset: Dataset,
        heuristic: AbstractHeuristic,
        query_size: int = 100,
        iterations: int = 20,
        criterion: Optional[StoppingCriterion] = None,
    ):
        """Store experiment configuration and resolve the framework adapter.

        Args:
            trainer: A ModelWrapper or BaalTransformersTrainer to train/predict with.
            al_dataset: Active-learning dataset holding labelled/unlabelled splits.
            eval_dataset: Held-out dataset used for evaluation each step.
            heuristic: Uncertainty heuristic used to rank the pool.
            query_size: Number of pool items to label per step.
            iterations: Stochastic prediction iterations for uncertainty estimation.
            criterion: Stopping criterion; defaults to exhausting the unlabelled pool.
        """
        self.al_dataset = al_dataset
        self.eval_dataset = eval_dataset
        self.heuristic = heuristic
        self.query_size = query_size
        self.iterations = iterations
        # Default: stop once the entire initial unlabelled pool has been labelled.
        self.criterion = criterion or LabellingBudgetStoppingCriterion(
            al_dataset, labelling_budget=al_dataset.n_unlabelled
        )
        self.adapter = self._get_adapter(trainer)

    def start(self):
        """Run active-learning steps until the stopping criterion triggers.

        Returns:
            A list with one dict per step, merging train and eval metrics.
        """
        records = []
        _start = len(self.al_dataset)
        # itertools.count gives an unbounded loop; the criterion breaks out.
        for _ in tqdm(itertools.count(start=0)):
            # Retrain from scratch on the enlarged labelled set each step.
            self.adapter.reset_weights()
            train_metrics = self.adapter.train(self.al_dataset)
            eval_metrics = self.adapter.evaluate(
                self.eval_dataset, average_predictions=self.iterations
            )
            ranks, uncertainty = self.heuristic.get_ranks(
                self.adapter.predict(self.al_dataset.pool, iterations=self.iterations)
            )
            self.al_dataset.label(ranks[: self.query_size])
            metrics = {**train_metrics, **eval_metrics}
            records.append(metrics)
            if self.criterion.should_stop(eval_metrics, uncertainty):
                log.info("Experiment complete", num_labelled=len(self.al_dataset) - _start)
                break
            # BUG FIX: `pd.DataFrame.from_dict` raises
            # "If using all scalar values, you must pass an index" on a flat
            # metrics dict; build a one-row frame from a list instead.
            print(pd.DataFrame([metrics]))
        return records

    def _get_adapter(
        self, trainer: Union[ModelWrapper, "BaalTransformersTrainer"]
    ) -> FrameworkAdapter:
        """Wrap `trainer` in the matching FrameworkAdapter or raise ValueError."""
        if isinstance(trainer, ModelWrapper):
            return ModelWrapperAdapter(trainer)
        elif TRANSFORMERS_AVAILABLE and isinstance(trainer, BaalTransformersTrainer):
            return TransformersAdapter(trainer)
        raise ValueError(
            f"{type(trainer)} is not a supported trainer."
            " Baal supports ModelWrapper and BaalTransformersTrainer"
        )
38 changes: 12 additions & 26 deletions baal/experiments/modelwrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from copy import deepcopy
from typing import Dict, Union, List

from numpy._typing import NDArray

Expand All @@ -10,33 +11,18 @@
class ModelWrapperAdapter(FrameworkAdapter):
    """Bridge a Baal ``ModelWrapper`` to the ``FrameworkAdapter`` interface."""

    def __init__(self, wrapper: ModelWrapper):
        """Snapshot the wrapper's initial weights for later resets.

        Args:
            wrapper: The ModelWrapper that owns the model and training config.
        """
        self.wrapper = wrapper
        # deepcopy is required: state_dict() returns references to the live
        # tensors, so a shallow snapshot would be silently mutated by training.
        self._init_weight = deepcopy(self.wrapper.state_dict())

    def reset_weights(self):
        """Restore the wrapped model to its initial weights."""
        self.wrapper.load_state_dict(self._init_weight)

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        """Train on the labelled portion of `al_dataset`; return train metrics."""
        self.wrapper.train_on_dataset(al_dataset)
        return self.wrapper.get_metrics("train")

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        """Run `iterations` stochastic forward passes over `dataset`."""
        return self.wrapper.predict_on_dataset(dataset, iterations=iterations)

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        """Evaluate on `dataset`, averaging predictions; return test metrics."""
        self.wrapper.test_on_dataset(dataset, average_predictions=average_predictions)
        return self.wrapper.get_metrics("test")
30 changes: 30 additions & 0 deletions baal/experiments/transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from copy import deepcopy
from typing import Dict, cast, List, Union

from numpy._typing import NDArray

from baal.active.dataset.base import Dataset
from baal.experiments import FrameworkAdapter
from baal.transformers_trainer_wrapper import BaalTransformersTrainer


class TransformersAdapter(FrameworkAdapter):
    """Bridge a ``BaalTransformersTrainer`` to the ``FrameworkAdapter`` interface."""

    def __init__(self, wrapper: BaalTransformersTrainer):
        """Snapshot model, scheduler and optimizer state for later resets."""
        self.wrapper = wrapper
        # deepcopy so the snapshots are not mutated by subsequent training.
        self._init_weight = deepcopy(wrapper.model.state_dict())
        self._init_scheduler = deepcopy(wrapper.lr_scheduler.state_dict())
        self._init_optimizer = deepcopy(wrapper.optimizer.state_dict())

    def reset_weights(self):
        """Restore model, LR scheduler and optimizer to their initial snapshots."""
        restorable = (
            (self.wrapper.model, self._init_weight),
            (self.wrapper.lr_scheduler, self._init_scheduler),
            (self.wrapper.optimizer, self._init_optimizer),
        )
        for component, snapshot in restorable:
            component.load_state_dict(snapshot)

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        """Run one HuggingFace training session and return its metrics.

        NOTE(review): `al_dataset` is unused — the trainer presumably already
        holds the active-learning dataset as its train_dataset; confirm.
        """
        outcome = self.wrapper.train()
        return outcome.metrics

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        """Run `iterations` stochastic forward passes over `dataset`."""
        return self.wrapper.predict_on_dataset(dataset, iterations=iterations)

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        """Evaluate on `dataset` and return the trainer's metric dict.

        NOTE(review): `average_predictions` is not forwarded to the trainer —
        verify whether evaluation should average stochastic predictions.
        """
        metrics = self.wrapper.evaluate(dataset)
        return cast(Dict[str, float], metrics)
5 changes: 3 additions & 2 deletions baal/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Callable, Optional, Union, List

import numpy as np
import structlog
import torch
from numpy._typing import NDArray
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
Expand Down Expand Up @@ -236,7 +237,7 @@ def predict_on_dataset(
iterations: int,
half=False,
verbose=True,
):
) -> Union[NDArray, List[NDArray]]:
"""
Use the model to predict on a dataset `iterations` time.
Expand Down
3 changes: 2 additions & 1 deletion baal/transformers_trainer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import structlog
import torch
from numpy._typing import NDArray
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedModel, TrainingArguments
Expand Down Expand Up @@ -140,7 +141,7 @@ def predict_on_dataset(
iterations: int = 1,
half: bool = False,
ignore_keys: Optional[List[str]] = None,
):
) -> Union[NDArray, List[NDArray]]:

"""
Use the model to predict on a dataset `iterations` time.
Expand Down
56 changes: 12 additions & 44 deletions experiments/nlp_bert_mcdropout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import random
from copy import deepcopy
from pprint import pprint

import torch
import torch.backends
Expand All @@ -16,7 +17,9 @@
from baal.active import get_heuristic
from baal.active.active_loop import ActiveLearningLoop
from baal.active.dataset.nlp_datasets import active_huggingface_dataset, HuggingFaceDatasets
from baal.active.heuristics import BALD
from baal.bayesian.dropout import patch_module
from baal.experiments.base import ActiveLearningExperiment
from baal.transformers_trainer_wrapper import BaalTransformersTrainer

"""
Expand All @@ -26,14 +29,11 @@

def parse_args(argv=None):
    """Parse command-line options for the BERT MC-Dropout experiment.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.
            (Backward compatible: existing callers pass nothing.)

    Returns:
        argparse.Namespace with the experiment hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--initial_pool", default=1000, type=int)
    parser.add_argument("--model", default="bert-base-uncased", type=str)
    parser.add_argument("--query_size", default=100, type=int)
    parser.add_argument("--iterations", default=20, type=int)
    parser.add_argument("--learning_epoch", default=20, type=int)
    return parser.parse_args(argv)

Expand Down Expand Up @@ -67,8 +67,6 @@ def main():

hyperparams = vars(args)

heuristic = get_heuristic(hyperparams["heuristic"], hyperparams["shuffle_prop"])

model = BertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=hyperparams["model"]
)
Expand All @@ -85,8 +83,6 @@ def main():
if use_cuda:
model.cuda()

init_weights = deepcopy(model.state_dict())

training_args = TrainingArguments(
output_dir="/app/baal/results", # output directory
num_train_epochs=hyperparams["learning_epoch"], # total # of training epochs
Expand All @@ -104,44 +100,16 @@ def main():
eval_dataset=test_set,
tokenizer=None,
)

logs = {}
logs["epoch"] = 0

# In this case, nlp data is fast to process and we do NoT need to use a smaller batch_size
active_loop = ActiveLearningLoop(
active_set,
model.predict_on_dataset,
heuristic,
hyperparams.get("query_size", 1),
iterations=hyperparams["iterations"],
experiment = ActiveLearningExperiment(
trainer=model,
al_dataset=active_set,
eval_dataset=test_set,
heuristic=BALD(),
query_size=hyperparams["query_size"],
iterations=20,
criterion=None,
)

for epoch in tqdm(range(args.epoch)):
# we use the default setup of HuggingFace for training (ex: epoch=1).
# The setup is adjustable when BaalHuggingFaceTrainer is defined.
model.train()

# Validation!
eval_metrics = model.evaluate()

# We reorder the unlabelled pool at the frequency of learning_epoch
# This helps with speed while not changing the quality of uncertainty estimation.
should_continue = active_loop.step()

# We reset the model weights to relearn from the new trainset.
model.load_state_dict(init_weights)
model.lr_scheduler = None
if not should_continue:
break
active_logs = {
"epoch": epoch,
"labeled_data": active_set.labelled_map,
"Next Training set size": len(active_set),
}

logs = {**eval_metrics, **active_logs}
print(logs)
pprint(experiment.start())


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 9f791f6

Please sign in to comment.