From a4ceccd276e05f18301c6f0ca669f14a8ce47959 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Branchaud-Charron?=
Date: Mon, 27 May 2024 19:51:42 -0400
Subject: [PATCH] Add documentation for criterion (#297)

---
 docs/api/index.md                        |   3 +-
 docs/api/stopping_criteria.md            |  35 ++++++++
 mkdocs.yml                               |   1 +
 notebooks/production/baal_prod_cls.ipynb | 104 +++++------------------
 4 files changed, 60 insertions(+), 83 deletions(-)
 create mode 100644 docs/api/stopping_criteria.md

diff --git a/docs/api/index.md b/docs/api/index.md
index 1c5464d5..b15afe8b 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -6,12 +6,13 @@
 * [baal.bayesian](./bayesian.md)
 * [baal.active](./dataset_management.md)
 * [baal.active.heuristics](./heuristics.md)
+* [baal.active.stopping_criteria](./stopping_criteria.md)
 * [baal.calibration](./calibration.md)
 * [baal.utils](./utils.md)
 
 ### :material-file-tree: Compatibility
 
-* [baal.utils.pytorch_lightning] (./compatibility/pytorch-lightning)
+* [baal.utils.pytorch_lightning](./compatibility/pytorch-lightning)
 * [baal.transformers_trainer_wrapper](./compatibility/huggingface)
 
 
diff --git a/docs/api/stopping_criteria.md b/docs/api/stopping_criteria.md
new file mode 100644
index 00000000..d77116e7
--- /dev/null
+++ b/docs/api/stopping_criteria.md
@@ -0,0 +1,35 @@
+# Stopping Criteria
+
+Stopping criteria are used to determine when to stop your active learning experiment.
+
+They are simple to use on their own, but they are best put into practice with `ActiveExperiment`.
+
+**Example**
+```python
+from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
+from baal.active.dataset import ActiveLearningDataset
+
+al_dataset: ActiveLearningDataset = ... # len(al_dataset) == 10
+criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)
+
+assert not criterion.should_stop({}, [])
+
+# After labelling 50 items, len(al_dataset) == 60.
+al_dataset.label_randomly(50)
+assert not criterion.should_stop({}, [])
+
+# After 50 more, len(al_dataset) == 110: 100 items labelled, the budget is exhausted.
+al_dataset.label_randomly(50)
+assert criterion.should_stop({}, [])
+```
+
+
+### API
+
+### baal.active.stopping_criteria
+
+::: baal.active.stopping_criteria.LabellingBudgetStoppingCriterion
+
+::: baal.active.stopping_criteria.LowAverageUncertaintyStoppingCriterion
+
+::: baal.active.stopping_criteria.EarlyStoppingCriterion
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 60774b2c..4af60038 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -78,6 +78,7 @@ nav:
   - api/calibration.md
   - api/dataset_management.md
   - api/heuristics.md
+  - api/stopping_criteria.md
   - api/modelwrapper.md
   - api/utils.md
   - Compatibility:
diff --git a/notebooks/production/baal_prod_cls.ipynb b/notebooks/production/baal_prod_cls.ipynb
index 0542ecf9..21ae544a 100644
--- a/notebooks/production/baal_prod_cls.ipynb
+++ b/notebooks/production/baal_prod_cls.ipynb
@@ -35,15 +35,6 @@
     "is_executing": false
    }
   },
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Train: 5174, Valid: 1725, Num. classes : 8\n"
-    ]
-   }
-  ],
   "source": [
    "from glob import glob\n",
    "import os\n",
@@ -52,7 +43,8 @@
    "classes = os.listdir('/tmp/natural_images')\n",
    "train, test = train_test_split(files, random_state=1337) # Split 75% train, 25% validation\n",
    "print(f\"Train: {len(train)}, Valid: {len(test)}, Num. 
classes : {len(classes)}\")\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -79,7 +71,6 @@ "is_executing": false } }, - "outputs": [], "source": [ "from baal.active import FileDataset, ActiveLearningDataset\n", "from torchvision import transforms\n", @@ -101,7 +92,8 @@ "# We use -1 to specify that the data is unlabeled.\n", "test_dataset = FileDataset(test, [-1] * len(test), test_transform)\n", "active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -129,7 +121,6 @@ "is_executing": false } }, - "outputs": [], "source": [ "import torch\n", "from torch import nn, optim\n", @@ -149,7 +140,8 @@ "# ModelWrapper is an object similar to keras.Model.\n", "baal_model = ModelWrapper(model, criterion)\n", "\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -170,11 +162,11 @@ "is_executing": false } }, - "outputs": [], "source": [ "from baal.active.heuristics import BALD\n", "heuristic = BALD(shuffle_prop=0.1)\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -193,13 +185,13 @@ "is_executing": false } }, - "outputs": [], "source": [ "# This function would do the work that a human would do.\n", "def get_label(img_path):\n", " return classes.index(img_path.split('/')[-2])\n", "\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -223,15 +215,6 @@ "is_executing": false } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Num. labeled: 100/5174\n" - ] - } - ], "source": [ "import numpy as np\n", "# 1. Label all the test set and some samples from the training set.\n", @@ -246,7 +229,8 @@ "active_learning_ds.label(train_idxs, labels)\n", "\n", "print(f\"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}\")\n" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -256,56 +240,19 @@ "is_executing": false } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001B[32minfo ] Starting training dataset=100 epoch=5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/lib/python3.9/site-packages/torch/utils/data/dataloader.py:478: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 1, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n", - "/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n", - " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001B[32minfo ] Training complete train_loss=2.058176279067993\n", - "[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001B[32minfo ] Starting evaluating dataset=1725\n", - "[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001B[32minfo ] Evaluation complete test_loss=2.0671451091766357\n", - "Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}\n" - ] - } - ], "source": [ "# 2. Train the model for a few epoch on the training set.\n", "baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)\n", "baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)\n", "\n", "print(\"Metrics:\", {k:v.avg for k,v in baal_model.metrics.items()})\n" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001B[32minfo ] Start Predict dataset=5074\n" - ] - } - ], "source": [ "# 3. Select the K-top uncertain samples according to the heuristic.\n", "pool = active_learning_ds.pool\n", @@ -316,21 +263,13 @@ "predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)\n", "# We will label the 10 most uncertain samples.\n", "top_uncertainty = heuristic(predictions)[:10]\n" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]\n" - ] - } - ], "source": [ "# 4. Label those samples.\n", "oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n", @@ -338,7 +277,8 @@ "print(list(zip(labels, oracle_indices)))\n", "active_learning_ds.label(top_uncertainty, labels)\n", "\n" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -348,7 +288,6 @@ "is_executing": true } }, - "outputs": [], "source": [ "# 5. If not done, go back to 2.\n", "for step in range(5): # 5 Active Learning step!\n", @@ -372,7 +311,8 @@ " active_learning_ds.label(top_uncertainty, labels)\n", " \n", " " - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -386,14 +326,14 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [], "source": [ "torch.save({\n", " 'active_dataset': active_learning_ds.state_dict(),\n", " 'model': baal_model.state_dict(),\n", " 'metrics': {k:v.avg for k,v in baal_model.metrics.items()}\n", "}, '/tmp/baal_output.pth')\n" - ] + ], + "outputs": [] }, { "cell_type": "markdown",