Interpretable and Personalized Apprenticeship Scheduling: Learning Interpretable Scheduling Policies from Heterogeneous User Demonstrations
This is the codebase for creating Personalized Neural Trees (PNTs) and generating performance results in each domain including the low-dim, synthetic scheduling, and Taxi domain.
Requirements are included in the requirements.txt
file, and this repo itself is a requirement. Install by running the following in the main directory:
$ pip install -r requirements.txt
$ pip install -e .
The table below relates author names within the results above to filenames.
Method Name | FileNAme |
---|---|
Our Method | PNT |
Sammut et. al. | NN |
Nikolaidis et. al. | k-means_NN |
Tamar et. al. | SNN |
Hsiao et. al. | cVAE |
Gombolay et. al. | DT_pairwise |
InfoGAIL | InfoGAIL |
Our Method (interpretable) | Crisp |
Our Method (DT+\omega) | DT_our_embeddings |
For the below domains, results can be generated by running the command
python file_name.py
For this environment, the generate_environment.py
file is in charge of creating a dataset. The datasets folder has several datasets we created and train/test upon.
Each of the methods are found within this subfolder, corresponding to algorithm names within their respective papers.
For this environment, generating successful schedule trajectories takes some time. The folder create_scheduling_data
has the files necessary to generate large datasets. Running create_data.py can create any number of schedules, averaging about 1000 a day.
As these datasets are large, we have added them to a box repo https://gatech.box.com/s/ogm20qcry0h12gf1r8b5r2vuud0vm6ye. However, as we use seeds, you are able to generate your own data. Using 250 schedules will produce the same dataset as ours. Once you have created a set of demonstrated schedules, please update the path to locate these.
Each of the methods are found within this subfolder, corresponding to algorithm names within their respective papers. Note that each algorithm will run several times for k-fold cross` validation, sampling a different subset of 150 schedules to train upon.
For this environment, data generation is done through taxi_sim.py
. This file is the images in the folder taxi_tree_images
represented as if statements and generates trajectories by sampling from users.
We have several datasets creating by randomly sampling demonstrators and initial states. Permuting them for training and testing gives several accuracy measures we use to compute the k-fold cross validation accuracy reported within the paper.
Each of the methods are found within this subfolder, corresponding to algorithm names within their respective papers.
- Pretraining the policy before starting variational inference typically leads to better performance and stability during training.
- Discretization is very sensitive and does not have a simple relationship with policy performance.