This is the implementation of the Moore Machine Network (MMN) introduced in the paper "Learning Finite State Representations of Recurrent Policy Networks".
If you find it useful in your research, please cite it using:
@inproceedings{
koul2018learning,
title={Learning Finite State Representations of Recurrent Policy Networks},
author={Anurag Koul and Alan Fern and Sam Greydanus},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=S1gOpsCctm},
}
A poster for this work was also presented at ICLR 2019.
- Python 3.5+
- PyTorch
- gym_x
- To install dependencies:
pip install -r requirements.txt
We use main_mce.py, main_tomita.py, and main_atari.py to run experiments on the Mode Counter Environment (a.k.a. Gold Rush), Tomita Grammar, and Atari, respectively.
In the following, we describe usage with respect to main_atari.py; the same applies to the other two scripts.
usage: main_atari.py [-h] [--generate_train_data] [--generate_bn_data]
[--generate_max_steps GENERATE_MAX_STEPS] [--gru_train]
[--gru_test] [--gru_size GRU_SIZE] [--gru_lr GRU_LR]
[--bhx_train] [--ox_train] [--bhx_test] [--ox_test]
[--bgru_train] [--bgru_test] [--bhx_size BHX_SIZE]
[--bhx_suffix BHX_SUFFIX] [--ox_size OX_SIZE]
[--train_epochs TRAIN_EPOCHS] [--batch_size BATCH_SIZE]
[--bgru_lr BGRU_LR] [--gru_scratch] [--bx_scratch]
[--generate_fsm] [--evaluate_fsm]
[--bn_episodes BN_EPISODES] [--bn_epochs BN_EPOCHS]
[--no_cuda] [--env ENV] [--env_seed ENV_SEED]
[--result_dir RESULT_DIR]
GRU to FSM
optional arguments:
-h, --help show this help message and exit
--generate_train_data
Generate Train Data
--generate_bn_data Generate Bottle-Neck Data
--generate_max_steps GENERATE_MAX_STEPS
Maximum number of steps to be used for data generation
--gru_train Train GRU Network
--gru_test Test GRU Network
--gru_size GRU_SIZE No. of GRU Cells
--gru_lr GRU_LR No. of GRU Cells
--bhx_train Train bx network
--ox_train Train ox network
--bhx_test Test bx network
--ox_test Test ox network
--bgru_train Train binary gru network
--bgru_test Test binary gru network
--bhx_size BHX_SIZE binary encoding size
--bhx_suffix BHX_SUFFIX
suffix fo bhx folder
--ox_size OX_SIZE binary encoding size
--train_epochs TRAIN_EPOCHS
No. of training episodes
--batch_size BATCH_SIZE
batch size used for training
--bgru_lr BGRU_LR Learning rate for binary GRU
--gru_scratch use scratch gru for BGRU
--bx_scratch use scratch bx network for BGRU
--generate_fsm extract fsm from fmm net
--evaluate_fsm evaluate fsm
--bn_episodes BN_EPISODES
No. of episodes for generating data for Bottleneck
Network
--bn_epochs BN_EPOCHS
No. of Training epochs
--no_cuda no cuda usage
--env ENV Name of the environment
--env_seed ENV_SEED Seed for the environment
--result_dir RESULT_DIR
Directory Path to store results
For most of our experiments, we set generate_max_steps = 100. You can change it to suit the environment you are using. All other parameters were left at their default values, except for ox_size, bhx_size, and gru_size, which were set per experiment.
Forming an MMN requires several steps, which are described one by one below. For bhx_size=64 and ox_size=100, these steps can also be executed sequentially with the following script, which can be easily customized for other environments.
./run_atari.sh PongDeterministic-v4
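As a rough illustration, the sequence of stages can also be assembled in Python. This is only a sketch built from the commands documented in this README; it is not the contents of run_atari.sh, and the helper name `build_pipeline` is hypothetical.

```python
import shlex


def build_pipeline(env, gru_size=32, bhx_size=64, ox_size=100, max_steps=100):
    """Return the documented main_atari.py commands in execution order.

    Hypothetical helper: it only mirrors the command lines shown in this
    README and does not inspect the repository's scripts.
    """
    base = ["python", "main_atari.py", "--env", env]
    gru = ["--gru_size", str(gru_size)]
    bhx = ["--bhx_size", str(bhx_size)]
    ox = ["--ox_size", str(ox_size)]
    steps = ["--generate_max_steps", str(max_steps)]
    stages = [
        ["--gru_test"] + gru,                        # optional: test the RNN
        ["--generate_bn_data"] + gru + steps,        # bottleneck data
        ["--bhx_train"] + bhx + gru + steps,         # hidden-state QBN
        ["--ox_train"] + ox + bhx + gru + steps,     # observation QBN
        ["--bgru_train"] + ox + bhx + gru + steps,   # form and fine-tune MMN
        ["--generate_fsm"] + bhx + ox + gru + steps, # extract Moore machine
    ]
    return [base + stage for stage in stages]


if __name__ == "__main__":
    for cmd in build_pipeline("PongDeterministic-v4"):
        # Print each command; to execute one, pass the list to
        # subprocess.run(cmd, check=True) from the repository root.
        print(" ".join(shlex.quote(part) for part in cmd))
```

Keeping each command as a list of arguments avoids shell quoting issues if a stage is later run through `subprocess.run`.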
- Test RNN: We assume that a pre-trained RNN model exists. This optional step evaluates the performance of that model:
python main_atari.py --env PongDeterministic-v4 --gru_test --gru_size 32
- Generate Bottleneck Data: Generates and stores the data used to train the quantized bottleneck networks (QBNs).
python main_atari.py --env PongDeterministic-v4 --generate_bn_data --gru_size 32 --generate_max_steps 100
- Train BHX: Trains the QBN for the hidden state (hx). After each epoch, the QBN is inserted into the original RNN model and the combined model is evaluated in the environment. The best-performing QBN is saved.
python main_atari.py --env PongDeterministic-v4 --bhx_train --bhx_size 64 --gru_size 32 --generate_max_steps 100
After it's done, the model and plots will be saved here:
results/Atari/PongDeterministic-v4/gru_32_bhx_64/
- Test BHX (optional): Inserts the saved BHX model into the original RNN model and evaluates the combined model in the environment.
python main_atari.py --env PongDeterministic-v4 --bhx_test --bhx_size 64 --gru_size 32 --generate_max_steps 100
- Train OX: Trains the QBN for the learned observation features (X) given as input to the RNN.
python main_atari.py --env PongDeterministic-v4 --ox_train --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100
After it's done, the model and plots will be saved here:
results/Atari/PongDeterministic-v4/gru_32_ox_100/
- Test OX (optional): Inserts the saved OX model into the original RNN model and evaluates the combined model in the environment.
python main_atari.py --env PongDeterministic-v4 --ox_test --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100
- MMN: We form the Moore Machine Network by inserting both the BHX and OX QBNs into the original RNN model. The performance of the MMN is then evaluated in the environment. If performance drops, which can be caused by the accumulated error of the two QBNs, the MMN is fine-tuned.
python main_atari.py --env PongDeterministic-v4 --bgru_train --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100
When fine-tuning is done, the model and plots will be saved here:
results/Atari/PongDeterministic-v4/gru_32_hx_(64,100)_bgru
- Test MMN (optional): Loads and tests the saved MMN model.
python main_atari.py --env PongDeterministic-v4 --bgru_test --bhx_size 64 --ox_size 100 --gru_size 32 --generate_max_steps 100
- Extract Moore Machine: In this final step, the quantized observation and hidden state spaces are enumerated to form a Moore machine, which is then minimized.
python main_atari.py --env PongDeterministic-v4 --generate_fsm --bhx_size 64 --ox_size 100 --gru_size 32 --generate_max_steps 100
The final results before and after minimization are stored in text files (fsm.txt and minimized_moore_machine.txt) here:
results/Atari/PongDeterministic-v4/gru_32_hx_(64,100)_bgru/
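To make the last step more concrete, here is a minimal sketch of enumerating quantized vectors into a Moore machine and then merging indistinguishable states. It uses textbook partition refinement and is not necessarily the procedure implemented in this repository; the function names, the input format, and the treatment of unseen (state, observation) pairs are all assumptions made for illustration.

```python
def enumerate_machine(seen, action_of):
    """Assign integer ids to the distinct quantized vectors seen in rollouts.

    seen:      iterable of (h, o, h_next) tuples of hashable quantized
               vectors (hypothetical input format).
    action_of: {h: action} giving the action emitted in each hidden state.
    """
    state_id, obs_id = {}, {}
    outputs, transitions = {}, {}
    for h, o, h_next in seen:
        for vec in (h, h_next):
            sid = state_id.setdefault(vec, len(state_id))
            outputs[sid] = action_of[vec]
        oid = obs_id.setdefault(o, len(obs_id))
        transitions[(state_id[h], oid)] = state_id[h_next]
    return outputs, transitions


def minimize_moore(outputs, transitions):
    """Merge indistinguishable states by partition refinement.

    Returns {state: block_id}. Missing (state, observation) pairs are
    treated as "never observed", which is an assumption of this sketch.
    """
    observations = sorted({obs for (_, obs) in transitions})
    block = dict(outputs)  # start by grouping states by emitted action
    while True:
        ids, new_block = {}, {}
        for s in sorted(outputs):
            # Signature: own block plus the block reached per observation.
            sig = (block[s],) + tuple(
                block.get(transitions.get((s, o))) for o in observations
            )
            new_block[s] = ids.setdefault(sig, len(ids))
        if len(set(new_block.values())) == len(set(block.values())):
            return new_block  # no block was split, so the partition is stable
        block = new_block


if __name__ == "__main__":
    seen = [((1, 0), (0,), (0, 1)), ((0, 1), (0,), (1, 0))]
    action_of = {(1, 0): "LEFT", (0, 1): "LEFT"}
    outputs, transitions = enumerate_machine(seen, action_of)
    merged = minimize_moore(outputs, transitions)
    print(len(set(merged.values())), "state(s) after minimization")
```

Because each signature includes the state's current block, every pass can only split blocks, so an unchanged block count means the partition has converged.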
To make the results easy to reproduce, we provide GRU models previously trained on different environments. You can use them to train new QBNs and reproduce the results presented in the paper. The models are available in the results/Atari/ directory. The GRU cell size can be determined from a model's path: if a model is saved in a folder named gru_32, then the GRU cell size is 32.
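If you need the sizes programmatically, the folder naming convention can be parsed with a few regular expressions. The sketch below is based only on the folder names that appear in this README (gru_32, gru_32_bhx_64, gru_32_ox_100, gru_32_hx_(64,100)_bgru). The helper name `parse_result_dir` is hypothetical, and reading the pair in parentheses as (bhx_size, ox_size) is an assumption inferred from the example commands.

```python
import re
from pathlib import PurePosixPath


def parse_result_dir(path):
    """Extract network sizes from a result folder name (hypothetical helper)."""
    name = PurePosixPath(path).name  # last path component, trailing "/" ignored
    match = re.match(r"gru_(\d+)", name)
    if match is None:
        raise ValueError("not a GRU result folder: %r" % name)
    sizes = {"gru_size": int(match.group(1))}

    bhx = re.search(r"_bhx_(\d+)", name)
    if bhx:
        sizes["bhx_size"] = int(bhx.group(1))

    ox = re.search(r"_ox_(\d+)", name)
    if ox:
        sizes["ox_size"] = int(ox.group(1))

    # Assumption: "(64,100)" lists the hidden-state size first, then the
    # observation size, matching --bhx_size 64 --ox_size 100.
    pair = re.search(r"_hx_\((\d+),(\d+)\)", name)
    if pair:
        sizes["bhx_size"] = int(pair.group(1))
        sizes["ox_size"] = int(pair.group(2))
    return sizes


if __name__ == "__main__":
    print(parse_result_dir("results/Atari/PongDeterministic-v4/gru_32_bhx_64/"))
```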
With a pretrained GRU model in place, you can follow the step-by-step instructions above to start training the QBNs.
The table below presents the Mode Counter Environment (MCE) results: the number of states and observations of the Moore machines (MMs) extracted from the MMNs, both before and after minimization. Moore machine extraction for MCE (Table 1 in the paper):

| Game | Bh | Bf | Fine-tuning score, before (%) | Fine-tuning score, after (%) | \|H\| before min. | \|O\| before min. | Acc (%) before min. | \|H\| after min. | \|O\| after min. | Acc (%) after min. |
|---|---|---|---|---|---|---|---|---|---|---|
| Amnesia (gold rush read) | 4 | 4 | 98 | 100 | 7 | 5 | 100 | 4 | 4 | 100 |
| | 4 | 8 | 99 | 100 | 7 | 7 | 100 | 4 | 4 | 100 |
| | 8 | 4 | 100 | - | 6 | 5 | 100 | 4 | 4 | 100 |
| | 8 | 8 | 99 | 100 | 7 | 7 | 100 | 4 | 4 | 100 |
| Blind (gold rush blind) | 4 | 4 | 100 | - | 12 | 6 | 100 | 10 | 1 | 100 |
| | 4 | 8 | 100 | - | 12 | 8 | 100 | 10 | 1 | 100 |
| | 8 | 4 | 100 | - | 15 | 6 | 100 | 10 | 1 | 100 |
| | 8 | 8 | 78 | 100 | 13 | 8 | 100 | 10 | 1 | 100 |
| Tracker (gold rush sneak) | 4 | 4 | 98 | 98 | 58 | 5 | 98 | 50 | 4 | 98 |
| | 4 | 8 | 99 | 100 | 23 | 5 | 100 | 10 | 4 | 100 |
| | 8 | 4 | 98 | 100 | 91 | 5 | 100 | 10 | 4 | 100 |
| | 8 | 8 | 99 | 100 | 85 | 5 | 100 | 10 | 4 | 100 |
The table below presents test results for the trained RNNs, giving accuracy over a test set of 100 strings drawn from the same distribution used for training. Moore machine extraction for Tomita grammars (Table 2 in the paper):

| Grammar | RNN Acc (%) | Bh | Fine-tuning score, before (%) | Fine-tuning score, after (%) | \|H\| before min. | Acc (%) before min. | \|H\| after min. | Acc (%) after min. |
|---|---|---|---|---|---|---|---|---|
| 1 | 100 | 8 | 100 | - | 13 | 100 | 2 | 100 |
| | 100 | 16 | 100 | - | 28 | 100 | 2 | 100 |
| 2 | 100 | 8 | 100 | - | 13 | 100 | 3 | 100 |
| | 100 | 16 | 100 | - | 14 | 100 | 3 | 100 |
| 3 | 100 | 8 | 100 | - | 34 | 100 | 5 | 100 |
| | 100 | 16 | 100 | - | 39 | 100 | 5 | 100 |
| 4 | 100 | 8 | 100 | - | 17 | 100 | 4 | 100 |
| | 100 | 16 | 100 | - | 18 | 100 | 4 | 100 |
| 5 | 100 | 8 | 95 | 96 | 192 | 96 | 115 | 96 |
| | 100 | 16 | 100 | - | 316 | 100 | 4 | 100 |
| 6 | 99 | 8 | 98 | 98 | 100 | 98 | 12 | 98 |
| | 99 | 16 | 99 | 99 | 518 | 99 | 11 | 99 |
| 7 | 100 | 8 | 100 | - | 25 | 100 | 5 | 100 |
| | 100 | 16 | 100 | - | 107 | 100 | 5 | 100 |
To run the full pipeline on control tasks, you only need to run the run_control.sh script. For example:
sh run_control.sh Acrobot-v1 32 64 64
We ran additional experiments on control tasks. The results are presented in the following table:

| Game (# of actions) | Bh | Bf | \|H\| before min. | \|O\| before min. | Score before min. | \|H\| after min. | \|O\| after min. | Score after min. |
|---|---|---|---|---|---|---|---|---|
| Cart Pole (2) | 64 | 64 | 27 | 859 | 500 | 4 | 32 | 500 |
| Lunar Lander (4) | 128 | 64 | 1502 | 1165 | 198 | 52 | 89 | 115 |
| Acrobot (3) | 64 | 64 | 769 | 649 | -73.95 | 11 | 23 | -89.4 |
This table shows the performance of the trained MMNs before and after fine-tuning for different combinations of Bh and Bf. A few more games were investigated, and their results are added to Table 3 of the paper. Results may vary slightly.

| Game (# of actions) | RNN (score) | Bh | Bf | Fine-tuning score, before | Fine-tuning score, after | \|H\| before min. | \|O\| before min. | Score before min. | \|H\| after min. | \|O\| after min. | Score after min. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Pong (3) | 21 | 64 | 100 | 20 | 21 | 380 | 374 | 21 | 4 | 12 | 21 |
| | | 64 | 400 | 20 | 21 | 373 | 372 | 21 | 3 | 10 | 21 |
| | | 128 | 100 | 20 | 21 | 383 | 373 | 21 | 3 | 12 | 21 |
| | | 128 | 400 | 20 | 21 | 379 | 371 | 21 | 3 | 11 | 21 |
| Freeway (3) | 21 | 64 | 100 | 21 | - | 1 | 1 | 21 | 1 | 1 | 21 |
| | | 64 | 400 | 21 | - | 1 | 1 | 21 | 1 | 1 | 21 |
| | | 128 | 100 | 21 | - | 1 | 1 | 21 | 1 | 1 | 21 |
| | | 128 | 400 | 21 | - | 1 | 1 | 21 | 1 | 1 | 21 |
| Breakout (4) | 773 | 64 | 100 | 32 | 423 | 1898 | 1874 | 423 | 8 | 30 | 423 |
| | | 64 | 400 | 25 | 415 | 1888 | 1871 | 415 | 8 | 30 | 415 |
| | | 128 | 100 | 41 | 377 | 1583 | 1514 | 377 | 11 | 27 | 377 |
| | | 128 | 400 | 85 | 379 | 1729 | 1769 | 379 | 8 | 30 | 379 |
| Space Invaders (4) | 1820 | 64 | 100 | 520 | 1335 | 1495 | 1502 | 1335 | 8 | 29 | 1335 |
| | | 64 | 400 | 365 | 1235 | 1625 | 1620 | 1235 | 12 | 29 | 1235 |
| | | 128 | 100 | 390 | 1040 | 1563 | 1457 | 1040 | 12 | 35 | 1040 |
| | | 128 | 400 | 520 | 1430 | 1931 | 1921 | 1430 | 6 | 27 | 1430 |
| Bowling (6) | 60 | 64 | 100 | 60 | - | 49 | 1 | 60 | 33 | 1 | 60 |
| | | 64 | 400 | 60 | - | 49 | 1 | 60 | 33 | 1 | 60 |
| | | 128 | 100 | 60 | - | 26 | 1 | 60 | 24 | 1 | 60 |
| | | 128 | 400 | 60 | - | 26 | 1 | 60 | 24 | 1 | 60 |
| Boxing (18) | 100 | 64 | 100 | 94 | 100 | 1173 | 1167 | 100 | 13 | 79 | 100 |
| | | 64 | 400 | 98 | 100 | 2621 | 2605 | 100 | 14 | 119 | 100 |
| | | 128 | 100 | 94 | 97 | 2499 | 2482 | 97 | 14 | 106 | 97 |
| | | 128 | 400 | 97 | 100 | 1173 | 1169 | 100 | 14 | 88 | 100 |
Chopper Command(18) | 5300 | 64 | 100 | 4000 | 3710 | 3731 | 4000 | 38 | 182 | 1890 |