diff --git a/README.md b/README.md index 7995e1e..2e6fd5e 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,9 @@ [![Documentation Status](https://readthedocs.org/projects/daart/badge/?version=latest)](https://daart.readthedocs.io/en/latest/?badge=latest) [![DOI](https://zenodo.org/badge/334987729.svg)](https://zenodo.org/badge/latestdoi/334987729) -A collection of tools for the discrete classification of animal behaviors using low-dimensional representations of videos (such as skeletons provided by tracking algorithms). Our approach combines strong supervision, weak supervision, and self-supervision to improve model performance. See the preprint [here](https://www.biorxiv.org/content/10.1101/2021.06.16.448685v1) for more details. This repo currently supports fitting the following types of base models on behavioral time series data: +A collection of tools for the discrete classification of animal behaviors using low-dimensional representations of videos (such as skeletons provided by tracking algorithms). Our approach combines strong supervision, weak supervision, and self-supervision to improve model performance. See the preprint [here](https://www.biorxiv.org/content/10.1101/2021.06.16.448685v1) for more details. + +This repo currently supports fitting the following types of base models on behavioral time series data: * Dense MLP network with initial 1D convolutional layer * RNNs - both LSTMs and GRUs * Temporal Convolutional Networks (TCNs) @@ -19,46 +21,3 @@ If you use daart in your analysis of behavioral data, please cite our preprint! year={2021}, publisher={Cold Spring Harbor Laboratory} } - -## Installation - - - -## Getting started - -To fit models from the command line using [test-tube](https://williamfalcon.github.io/test-tube/) -for hyperparameter searching and model fitting, see `fit_models.py` in the `examples` directory. -This script fits one or more models based on three yaml configuration files: one describing the -data, one describing the model, and one describing the training procedure. Example configuration -files can be found in the `configs` directory. - -**_Note:_** Test-tube will automatically perform a hyperparameter search over any field that is -provided as a list; for example, in the `model.yaml` file, change `n_hid_layers: 1` to -`n_hid_layers: [1, 2, 3]` to search over the number of hidden layers in the model. - -Once you have set the desired parameters in these files (see comment on data paths below), you can -then fit models like so: - -``` -(daart) $: python fit_models.py --data_config /path/to/data.yaml - --model_config /path/to/model.yaml --train_config /path/to/train.yaml -``` - -#### Data paths - -The `data.yaml` file has a field for listing experiment/session/video ids (`expt_ids`), as well as -a `data_dir` field. The `fit_models.py` script assumes data is stored in the following way, though -this can easily be adapted by changing the appropriate lines in a copy of the `fit_example.py` -script: - -* markers: `data_dir/markers/[expt_id]_labeled.csv` or `data_dir/markers/[expt_id]_labeled.h5`; -the standard file formats used by DLC/DGP are currently supported. - -* hand labels: `data_dir/labels-hand/[expt_id]_labels.csv`; a binary matrix of shape -`(T, n_classes + 1)`, where the first column represents the `background` class; the gradients -contributed by these time points are zeroed out during training. - -* heuristic labels: `data_dir/labels-heuristic/[expt_id]_labels.csv`; same format as the hand -labels - -See the directory `daart/data` for example fly data used in the preprint. diff --git a/docs/index.rst b/docs/index.rst index a8a42a1..1cdb4e5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ data: :caption: Contents: source/installation + source/user_guide source/api Indices and tables diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst new file mode 100644 index 0000000..b617e06 --- /dev/null +++ b/docs/source/user_guide.rst @@ -0,0 +1,17 @@ +.. _user_guide: + +########## +User guide +########## + +This guide walks you through the steps required for using the daart package for semi-supervised +discrete behavior classification. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + user_guide/organizing_your_data + user_guide/config_files + user_guide/training + user_guide/inference diff --git a/docs/source/user_guide/config_files.rst b/docs/source/user_guide/config_files.rst new file mode 100644 index 0000000..66899cc --- /dev/null +++ b/docs/source/user_guide/config_files.rst @@ -0,0 +1,60 @@ +.. _user_guide_configs: + +####################### +The configuration files +####################### + +Users interact with daart through a set of configuration (yaml) files. +These files point to the data directories, define the type of model to fit, and specify a wide +range of hyperparameters. + +An example set of configuration files can be found +`here `_. +When training a model on a new dataset, you must copy/paste these templates onto your local +machine and update the arguments to match your data. + +There are three configuration files: + +* :ref:`data `: where data is stored and model input type +* :ref:`model `: model class and various network hyperparameters +* :ref:`train `: training epochs, batch size, etc. + +The sections below describe the most important parameters in each file; +see the example configs for all possible options. + +.. _config_data: + +Data +==== + +* **input_type**: name of directory containing input data: 'markers' | 'features' | ... +* **output_size**: number of classes (including background) +* **expt_ids**: list of experiment ids used for training the model +* **data_dir**: absolute path to directory that contains the data +* **results_dir**: absolute path to directory that stores model fitting results + +.. _config_model: + +Model +===== + +* **labmda_weak**: weight on heuristic/pseudo label classification loss +* **lambda_strong**: weight on hand label classification loss (can always leave this as 1) +* **lambda_recon**: weight on input reconstruction loss +* **lambda_pred**: weight on next-step-ahead prediction loss + +So, for example, to fit a fully supervised classification model, set ``lambda_strong: 1`` and +all other "lambda" options to 0. + +To fit a model that uses heuristic labels, set ``lambda_strong: 1``, ``lambda_weak: 1``, and +all other "lambda" options to 0. You can try several values of ``lambda_weak`` to see what works +best for your data. + +.. _config_train: + +Train +===== + +* **min/max_epochs**: control length of training +* **enable_early_stop**: exit training early if validation loss begins to increase +* **trial_splits**: fraction of data to use for train;val;test;gap; you can always set "gap" to 0 as long as you validate your model on completely held-out videos diff --git a/docs/source/user_guide/inference.rst b/docs/source/user_guide/inference.rst new file mode 100644 index 0000000..731a6e1 --- /dev/null +++ b/docs/source/user_guide/inference.rst @@ -0,0 +1,82 @@ +.. _user_guide_inference: + +######### +Inference +######### + +Once you have trained a model you'll likely want to run inference on new videos. + +Similar to training, there are a set of high-level functions used to perform inference and evaluate +performance; this page details some of the main steps. + + +Load model +========== + +Using a provided model directory, construct a model and load the weights. + +.. code-block:: python + + import os + import torch + import yaml + + from daart.models import Segmenter + + model_dir = /path/to/model_dir + model_file = os.path.join(model_dir, 'best_val_model.pt') + + hparams_file = os.path.join(model_dir, 'hparams.yaml') + hparams = yaml.safe_load(open(hparams_file, 'rb')) + + model = Segmenter(hparams) + model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage)) + model.to(hparams['device']) + model.eval() + +Build data generator +==================== + +To run inference on a new session, you must provide a csv file that contains markers or features +from a new session (you must use the same type of inputs the model was trained on). + +.. code-block:: python + + from daart.data import DataGenerator + from daart.transforms import ZScore + + sess_id = + input_file = /path/to/markers_or_features_csv + + # define data generator signals + signals = ['markers'] # same for markers or features + transforms = [ZScore()] + paths = [input_file] + + # build data generator + data_gen_test = DataGenerator( + [sess_id], [signals], [transforms], [paths], device=hparams['device'], + sequence_length=hparams['sequence_length'], batch_size=hparams['batch_size'], + trial_splits=hparams['trial_splits'], + sequence_pad=hparams['sequence_pad'], input_type=hparams['input_type'], + ) + +Run inference +============= + +Inference can be performed by passing the newly constructed data generator to the model's +``predict_labels`` method: + +.. code-block:: python + + import numpy as np + + # predict probabilities from model + print('computing states for %s...' % sess_id, end='') + tmp = model.predict_labels(data_gen_test, return_scores=True) + probs = np.vstack(tmp['labels'][0]) + print('done') + + # get discrete state by taking argmax over probabilities at each time point + states = np.argmax(probs, axis=1) + diff --git a/docs/source/user_guide/organizing_your_data.rst b/docs/source/user_guide/organizing_your_data.rst new file mode 100644 index 0000000..8a3d549 --- /dev/null +++ b/docs/source/user_guide/organizing_your_data.rst @@ -0,0 +1,214 @@ +.. _user_guide_data: + +#################### +Organizing your data +#################### + +The daart models use a low-dimensional representation of behavioral videos (such as pose estimates) +to predict a set of discrete behavioral classes. +Heuristic labels may supplement hand labels and lead to better classification performance +(see the preprint referenced on the home page). +Each daart project will have a data directory that contains markers (or other features), +hand labels, and, if desired, heuristic labels. +See the +`example data `_ +directory which contains two sessions of the head-fixed fly experiment analyzed in the paper. + +.. note:: + + The base daart package does **not** contain tools for labeling data; + we recommend the `DeepEthogram labeling GUI `_. + +.. note:: + + Currently the models perform multinomial classification, so that only a single behavior + is predicted at each time step. + +Data directory structure +------------------------ + +The data directory structure contains subdirectories for each data type. +At minimum, a supervised model requires model inputs (either ``markers`` or ``features``) +and model outputs (``labels-hand``). +Semi-supervised models may also require heuristic labels (``labels-heuristic``). + +The example directory structure below shows the naming convention for the different data types. +For each data type the data must be separated by experimental session, and each session must have a +unique ID. + +Multiple types of features can be stored, and the set of features desired for a particular model +can be specified in the :ref:`configuration files `. +In the example below, there are two sets of features: ``features-base`` and ``features-simba``. +The naming convention is the same for both. + +Videos are not required for fitting the daart models, but may be useful for downstream analysis. + +.. code-block:: + + data_directory + ├── features-base + │ ├── _labeled.csv + │ └── _labeled.csv + ├── features-simba + │ ├── _labeled.csv + │ └── _labeled.csv + ├── labels-hand + │ ├── _labels.csv + │ └── _labels.csv + ├── labels-heuristic + │ ├── _labels.csv + │ └── _labels.csv + ├── markers + │ ├── _labeled.csv + │ └── _labeled.csv + └── videos + ├── .mp4 + └── .mp4 + + +Data formats +------------ + +Each data type requires its own (quite general) format for use with the daart code. + +Markers format +************** + +The current code accepts either csv or h5 files that are output by DLC or Lightning Pose. +The csv files must look like the following: + +.. list-table:: markers/_labeled.csv + :widths: 25 25 25 25 25 25 25 + :header-rows: 3 + + * - scorer + - scorer_name + - scorer_name + - scorer_name + - scorer_name + - scorer_name + - scorer_name + * - bodyparts + - bodypart 1 + - bodypart 1 + - bodypart 1 + - bodypart 2 + - bodypart 2 + - bodypart 2 + * - coords + - x + - y + - likelihood + - x + - y + - likelihood + * - 0 + - 274.3 + - 184.5 + - 0.87 + - 23.4 + - 13.0 + - 0.99 + * - 1 + - 275.6 + - 183.0 + - 0.88 + - 23.3 + - 13.0 + - 0.99 + * - 2 + - 276.9 + - 182.5 + - 0.87 + - 23.3 + - 12.9 + - 0.99 + * - 3 + - 278.4 + - 181.0 + - 0.87 + - 23.4 + - 13.1 + - 0.99 + +Features format +*************** + +Features should also be stored in csv files, with a single header row giving the feature name for +each column. The first column denotes the frame number. + +.. list-table:: features/_labeled.csv + :widths: 10 25 25 25 + :header-rows: 1 + + * - + - feature0_name + - feature1_name + - feature2_name + * - 0 + - 458.3 + - 0.12 + - 13.8 + * - 1 + - 500.2 + - 0.06 + - 14.7 + * - 2 + - 523.8 + - -0.06 + - 15.6 + * - 3 + - 567.4 + - -0.08 + - 16.5 + +Hand labels format +****************** + +The hand labels are stored in a csv file; the first (header) row denotes the behavior class names +(with the first column containing an empty cell). +The remaining rows contain the hand labels for each time point. +The first column denotes the frame number. +The second column denotes the "background" class, and for each row the entry should be 1 if +no other behavior is labeled at that time point, or 0 if at least one other behavior is labeled at +that time point. +The remaining columns correspond to the dataset-specific behavioral classes, and are binary as well +(0s and 1s). +There should only be a single "1" per row. + +.. list-table:: labels-hand/_labels.csv + :widths: 10 25 25 25 25 + :header-rows: 1 + + * - + - background + - behavior0_name + - behavior1_name + - behavior2_name + * - 0 + - 1 + - 0 + - 0 + - 0 + * - 1 + - 1 + - 0 + - 0 + - 0 + * - 2 + - 0 + - 0 + - 0 + - 1 + * - 3 + - 0 + - 0 + - 0 + - 1 + +For a complete example see the csv files in the `example data `_. + +Heuristic labels format +*********************** + +Same format as the hand labels. diff --git a/docs/source/user_guide/training.rst b/docs/source/user_guide/training.rst new file mode 100644 index 0000000..9da2965 --- /dev/null +++ b/docs/source/user_guide/training.rst @@ -0,0 +1,31 @@ +.. _user_guide_training: + +######## +Training +######## + +daart provides several tools for training models: + +1. A set of high-level functions used for creating data loaders, models, trainers, etc. + You can combine these to create your own custom training script. +2. An `example training script `_ + that demonstrates how to combine the high-level functions for model training and evaluation. + This is a complete training script and you may simply use it as-is. + +Additionally, daart uses the `test-tube `_ package +for hyperparameter searching and model fitting. +``test-tube`` will automatically perform a hyperparameter search over any field that is provided as +a list; +for example, in the ``model.yaml`` file, change ``n_hid_layers: 1`` to ``n_hid_layers: [1, 2, 3]`` +to search over the number of hidden layers in the model. + +Once you have set the desired parameters in the :ref:`configuration files ` +(make sure to update the data paths!), move to the directory where your copy of ``fit_models.py`` +is stored and run the following from the terminal: + +.. code-block:: console + + python fit_models.py --data_config /path/to/data.yaml --model_config /path/to/model.yaml --train_config /path/to/train.yaml + +You will see configuration details printed in the terminal, followed by a training progress bar. +Upon training completion the model will be saved in the location specified in the data config.