This is an implementation of Adaptive Computation Time (Graves, 2016) in PyTorch.
Adaptive Computation Time is a drop-in replacement for RNN structures that allows the model to take a variable number of computation steps on each input token. More information can be found in the paper or in this blog post.
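The core mechanism works as follows. At each input token, the cell repeatedly updates its state and emits a halting probability h after every update. It stops at the first step N where the running sum of h reaches 1 - ε, or at a step cap M. The last step is assigned the remainder R = 1 - (sum of the earlier h), and the output is the probability-weighted mean of the intermediate states. The ponder cost N + R is what the time penalty acts on. Below is a minimal sketch of that rule in plain Python, using scalar states so it runs without PyTorch. The name `act_ponder`, the toy cell, the constant halting probability, and the default ε and step cap are illustrative assumptions, not this repository's API.

```python
import math
from typing import Callable, List, Tuple


def act_ponder(
    state: float,
    x: float,
    transition: Callable[[float, float, float], float],
    halting: Callable[[float], float],
    epsilon: float = 0.01,   # illustrative default
    max_steps: int = 100,    # illustrative step cap M
) -> Tuple[float, float, List[float]]:
    """Ponder on a single input token with ACT-style halting (sketch).

    Returns (mean_state, ponder_cost, weights), where weights are the
    per-step probabilities used to average the intermediate states.
    """
    cumulative = 0.0          # sum of halting probabilities of earlier steps
    remainder = 1.0
    weights: List[float] = []
    states: List[float] = []
    s = state
    for n in range(1, max_steps + 1):
        flag = 1.0 if n == 1 else 0.0      # flag marks the first step on this token
        s = transition(s, x, flag)
        h = halting(s)                     # in the paper: sigmoid(W_h s + b_h)
        states.append(s)
        if cumulative + h >= 1.0 - epsilon or n == max_steps:
            remainder = 1.0 - cumulative   # R: leftover probability mass
            weights.append(remainder)
            break
        cumulative += h
        weights.append(h)
    mean_state = sum(p * s_n for p, s_n in zip(weights, states))
    ponder_cost = len(weights) + remainder  # N + R
    return mean_state, ponder_cost, weights


# Toy example: a scalar "RNN" cell and a constant halting probability of 0.3.
def toy_transition(s: float, x: float, flag: float) -> float:
    return math.tanh(0.5 * s + x + flag)


state, cost, weights = act_ponder(0.0, 1.0, toy_transition, lambda s: 0.3)
print(len(weights))     # 4 ponder steps: 0.3 + 0.3 + 0.3 + remainder
print(round(cost, 6))   # 4.1
```

Because the weights always sum to 1, the output stays a convex combination of intermediate states no matter how many steps are taken. The remainder term also makes the ponder cost differentiable with respect to the halting probabilities.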
Requirements:
- Python 3.6
- PyTorch 0.3.0
- matplotlib
- argparse
I am still in the process of replicating the following experiments from the paper:
- Bit Parity
- Logical Gates
- Addition
- Sorting
- Word Prediction
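For reference, here is how the paper describes the bit parity task (paraphrased). Each sequence is a single 64-element vector in which a random number of entries, from 1 to 64, are set to +1 or -1 and the rest to 0. The target is 1 if the number of +1 entries is odd and 0 otherwise. The sketch below generates such examples in plain Python. It illustrates that reading of the paper and is not this repository's data loader. The function name `make_parity_example` is hypothetical, and choosing the non-zero positions uniformly at random is an assumption.

```python
import random
from typing import List, Optional, Tuple


def make_parity_example(
    size: int = 64, rng: Optional[random.Random] = None
) -> Tuple[List[int], int]:
    """Return (input_vector, target) for one bit-parity example (sketch)."""
    if rng is None:
        rng = random.Random()
    vec = [0] * size
    k = rng.randint(1, size)                  # number of non-zero entries, 1..size
    for i in rng.sample(range(size), k):      # assumption: positions chosen at random
        vec[i] = rng.choice([-1, 1])
    target = vec.count(1) % 2                 # 1 if an odd number of +1 entries
    return vec, target


rng = random.Random(0)
for _ in range(4):
    vec, target = make_parity_example(rng=rng)
    # non-zero entries, number of +1 entries, parity target
    print(sum(v != 0 for v in vec), vec.count(1), target)
```

Because the number of non-zero entries varies widely, the task rewards a model that spends more ponder steps on harder inputs. This is what makes it a natural first test for ACT.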
- Clone this repository with git.
- Train/evaluate the model on a given task/parameter setting, e.g.:
python run_train.py \
    --task=parity \
    --use_act=False \
    --model_save_path="outputs/models/parity/rnn"

python run_train.py \
    --task=parity \
    --use_act=True \
    --act_ponder_penalty=0.001 \
    --model_save_path="outputs/models/parity/act_0.001"
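In the paper, the training objective adds a time penalty to the task loss: total_loss = task_loss + tau * sum_t rho_t, where rho_t = N(t) + R(t) is the ponder cost at input step t. Assuming `--act_ponder_penalty` plays the role of tau (the flag name suggests this, but check `run_train.py`), the combination looks roughly like the sketch below. `act_total_loss` is a hypothetical helper. A real training loop would do the same with PyTorch tensors and might average over the batch rather than sum.

```python
from typing import Sequence


def act_total_loss(
    task_loss: float, ponder_costs: Sequence[float], ponder_penalty: float
) -> float:
    """Task loss plus the time penalty tau * sum_t rho_t (sketch)."""
    return task_loss + ponder_penalty * sum(ponder_costs)


ponder_costs = [4.1, 2.0]  # e.g. per-step N(t) + R(t) from an ACT forward pass
for tau in (0.0, 0.001, 0.01):
    print(tau, act_total_loss(0.5, ponder_costs, tau))
```

A larger penalty pushes the model toward fewer ponder steps at some cost in accuracy. Sweeping values such as 0.001 and saving each run under its own path, as in the second command, makes that trade-off easy to compare.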