A deep learning model to predict the speedup obtained from applying a sequence of transformations on an input program.
Install the environment using the `environment.yml` file as follows:

```shell
conda env create -f environment.yml
```
This should create an environment called `cost_model_env`.
Whenever you want to use the model, activate the environment first:

```shell
conda activate cost_model_env
```
All of the main scripts use Hydra for configuration management. To configure the repository, fill in the configuration template `conf/config.yaml` with the required paths and parameters.
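As a rough illustration, a filled-in `conf/config.yaml` might look like the sketch below. The section name `data_generation` and the `batch_size` key come from the override example in this README; every other key and path is hypothetical and must be matched to the actual template shipped in `conf/`.

```yaml
# Hypothetical sketch only -- keys other than data_generation.batch_size
# are illustrative; use the real template in conf/config.yaml.
data_generation:
  dataset_path: /path/to/raw/json/dataset   # hypothetical key name
  batch_size: 1024                          # overridable, e.g. data_generation.batch_size=512
```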
When running any of the script files below, you can override any value from the configuration file on the command line. Each parameter must be prefixed with its section name. For example, to set the batch size to 512 when generating the dataset:

```shell
python generate_dataset.py data_generation.batch_size=512
```
We have separated data loading from training because loading is very time-consuming and we do not want to redo it for every training run. Instead, a script loads the raw data (JSON), extracts the representation for each datapoint, and saves the batched data in a `.pt` file that can be loaded directly into memory for training. We call this process data generation.
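The data-generation step described above can be sketched roughly as follows. Here `extract_representation` is a hypothetical stand-in for the repository's real feature extraction, and `pickle` stands in for the actual serialization (the real script writes batched tensors to a `.pt` file, typically via `torch.save`):

```python
import json
import pickle

def extract_representation(datapoint):
    # Hypothetical stand-in: the real code builds the model's input
    # representation from the program and its schedule of transformations.
    return {"features": datapoint["program"], "speedup": datapoint["speedup"]}

def make_batches(points, batch_size):
    # Group datapoints into fixed-size batches so training can load them directly.
    return [points[i:i + batch_size] for i in range(0, len(points), batch_size)]

# Stand-in for the raw JSON data; the real dataset is read from disk.
raw = json.loads('[{"program": "p0", "speedup": 1.3}, {"program": "p1", "speedup": 0.9}]')
points = [extract_representation(d) for d in raw]
batches = make_batches(points, batch_size=1)

with open("dataset_batches.pkl", "wb") as f:
    pickle.dump(batches, f)  # the repository writes a .pt file at this step instead
```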
To generate the dataset, run the `generate_dataset.py` script (after configuring the repository):

```shell
python generate_dataset.py
```
A sample from the dataset is provided in the `dataset_samples` folder as a pickle file. This sample contains approximately 80,000 data points, divided into a training set and a validation set. The training set covers 600 synthetic Tiramisu programs (~60,000 schedules), while the validation set covers 125 synthetic programs (~20,000 schedules).
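To inspect the sample before training, the pickle file can be loaded as below. The `{"train": ..., "val": ...}` layout used here is an assumption for illustration (a tiny stand-in file is created first so the snippet is self-contained); check the real object's type and keys after loading the file from `dataset_samples`.

```python
import pickle

# Tiny stand-in mimicking an assumed {"train": ..., "val": ...} layout;
# replace the path with the real pickle file under dataset_samples/.
sample = {"train": [{"speedup": 1.2}], "val": [{"speedup": 0.8}]}
with open("sample.pkl", "wb") as f:
    pickle.dump(sample, f)

with open("sample.pkl", "rb") as f:
    data = pickle.load(f)

# Inspect what was actually stored before assuming a structure.
print(type(data), len(data.get("train", [])), len(data.get("val", [])))
```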
To run the training, run the `train_model.py` script (after configuring the repository and generating the dataset):

```shell
python train_model.py
```
The repository supports Weights & Biases (wandb) for visualization. To enable it, log into wandb from the command line, then set the `use_wandb` parameter to `True` and specify a project name. The project does not have to already exist in wandb. During training, progress can be tracked on the wandb platform.
To evaluate the trained model, run the `evaluate_model.py` script (after configuring the repository and generating the dataset):

```shell
python evaluate_model.py
```
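The evaluation script defines its own metrics; as one common illustration for speedup prediction, the mean absolute percentage error (MAPE) between predicted and measured speedups can be computed as below. The numbers are made up for the example:

```python
def mape(predicted, measured):
    # Mean absolute percentage error between predicted and measured speedups.
    assert len(predicted) == len(measured) and len(measured) > 0
    return sum(abs(p - m) / abs(m) for p, m in zip(predicted, measured)) / len(measured)

predicted = [1.10, 0.95, 2.00]   # made-up model outputs
measured = [1.00, 1.00, 2.50]    # made-up ground-truth speedups
print(f"MAPE: {mape(predicted, measured):.2%}")
```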