Author's PyTorch implementation of SR-DICE for marginalized importance sampling


sfujim/SR-DICE


A Deep Reinforcement Learning Approach to Marginalized Importance Sampling with the Successor Representation

Code for Successor Representation DIstribution Correction Estimation (SR-DICE), a marginalized importance sampling method that builds on the deep successor representation. The paper was published at ICML 2021.
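The connection between marginalized importance sampling and the successor representation can be seen in a small tabular case. The sketch below uses numpy rather than the repository's deep PyTorch code, and the random MDP, target policy, rewards and data distribution are purely illustrative assumptions; the helper names `successor_representation` and `occupancy` are hypothetical. Over state-action pairs, the successor representation Psi = (I - gamma * P_pi)^-1 gives the normalized discounted occupancy d_pi = (1 - gamma) * mu0^T Psi. Reweighting samples drawn from a data distribution d_D by w(s, a) = d_pi(s, a) / d_D(s, a) then recovers the normalized return of the target policy from off-policy data.

```python
import numpy as np


def successor_representation(P, pi, gamma):
    """Tabular successor representation over state-action pairs.

    P:  transitions, shape (S, A, S), P[s, a, s'] = Pr(s' | s, a)
    pi: target policy, shape (S, A), pi[s, a] = Pr(a | s)
    Returns Psi = (I - gamma * P_pi)^-1, shape (S*A, S*A).
    """
    S, A, _ = P.shape
    # P_pi[(s, a), (s', a')] = P(s' | s, a) * pi(a' | s')
    P_pi = np.einsum("ijk,kl->ijkl", P, pi).reshape(S * A, S * A)
    return np.linalg.inv(np.eye(S * A) - gamma * P_pi)


def occupancy(rho0, pi, Psi, gamma):
    """Normalized discounted occupancy d_pi = (1 - gamma) * mu0^T Psi."""
    mu0 = (rho0[:, None] * pi).reshape(-1)  # distribution of (s0, a0)
    return (1.0 - gamma) * mu0 @ Psi


# Illustrative random MDP, policy and rewards (all hypothetical).
rng = np.random.default_rng(0)
S, A, gamma = 3, 2, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
pi = rng.random((S, A))
pi /= pi.sum(axis=1, keepdims=True)
rho0 = np.array([1.0, 0.0, 0.0])
r = rng.random(S * A)  # reward for each flattened (s, a) pair

Psi = successor_representation(P, pi, gamma)
d_pi = occupancy(rho0, pi, Psi, gamma)
true_value = d_pi @ r  # (1 - gamma) * E_pi[sum_t gamma^t r_t]

# Off-policy data drawn from a different (hypothetical) distribution d_D.
d_D = rng.random(S * A) + 0.5
d_D /= d_D.sum()
w = d_pi / d_D  # marginalized importance weights w(s, a)

idx = rng.choice(S * A, size=200_000, p=d_D)
mis_estimate = np.mean(w[idx] * r[idx])
print(f"true {true_value:.4f}  MIS estimate {mis_estimate:.4f}")
```

SR-DICE works with a learned deep successor representation rather than an exact matrix inverse, so this sketch only illustrates the quantities involved, not how they are estimated from samples.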

Code is provided for both continuous and discrete domains. Results were collected with MuJoCo 1.50 on OpenAI Gym 0.17.2. Networks are trained using PyTorch 1.4.0 and Python 3.7.

Usage

Continuous

Train expert:

python train_expert.py

Collect data & train SR-DICE:

python main.py

Discrete

Train expert:

python main.py --train_behavioral

Collect data:

python main.py --generate_buffer

Train SR-DICE:

python main.py
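The README lists each domain's steps in order: the expert is trained first, then data is collected, then SR-DICE is trained. A small driver can chain them. This is a hypothetical convenience script, not part of the repository. It uses only the scripts and flags listed above and assumes the `cwd` argument points at the folder holding that domain's scripts, which the README does not name.

```python
import subprocess
import sys

# Steps in the order the README lists them; each entry is script + flags.
PIPELINES = {
    "continuous": [
        ["train_expert.py"],                # train expert
        ["main.py"],                        # collect data & train SR-DICE
    ],
    "discrete": [
        ["main.py", "--train_behavioral"],  # train expert
        ["main.py", "--generate_buffer"],   # collect data
        ["main.py"],                        # train SR-DICE
    ],
}


def pipeline_commands(domain, python=sys.executable):
    """Build the full command lines for one domain."""
    return [[python, *step] for step in PIPELINES[domain]]


def run_pipeline(commands, cwd=".", dry_run=False):
    """Run commands in order; check=True raises on the first non-zero exit.

    cwd is a placeholder: point it at the folder containing that
    domain's scripts (the README does not name it).
    """
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    run_pipeline(pipeline_commands("discrete"), dry_run=True)
```

Using `check=True` means a failed expert-training step aborts the run instead of silently collecting data with a missing policy.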

BibTeX

@InProceedings{fujimoto2021srdice,
  title     = {A Deep Reinforcement Learning Approach to Marginalized Importance Sampling with the Successor Representation},
  author    = {Fujimoto, Scott and Meger, David and Precup, Doina},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages     = {3518--3529},
  year      = {2021},
  volume    = {139},
  publisher = {PMLR},
}
