Showing 46 changed files with 5,242 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
log/* | ||
*.tar.gz | ||
*.csv | ||
*.pyc | ||
*.html | ||
*.jpg | ||
table*.txt | ||
*.pickle | ||
figures/* | ||
recording/* | ||
*.png | ||
*.pickle | ||
configs | ||
data | ||
Replica-Dataset |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,50 @@ | ||
Code will be released soon! | ||
<a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/">Creative Commons Attribution-NonCommercial 4.0 International License</a>. | ||
|
||
--- | ||
|
||
Implementation of **PVR for Control**, as presented | ||
in [The (Un)Surprising Effectiveness of Pre-Trained Vision Models for Control](https://arxiv.org/abs/2203.03580). | ||
|
||
Part of the code was built on the [RIDE repository](https://github.com/facebookresearch/impact-driven-exploration). | ||
|
||
## Codebase Installation | ||
``` | ||
conda create -n pvr python=3.8 | ||
conda activate pvr | ||
git clone git@github.com:sparisi/habitat_pvr.git pvr | ||
cd pvr | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Habitat Installation | ||
* Clone `https://github.com/sparisi/habitat-lab` and do a **full install** with `habitat_baselines`. | ||
> The main differences between this fork and the original Habitat repository are: | ||
> 1) the `STOP` action is removed, | ||
> 2) a fix for a bug where the agent was placed slightly above the ground, so the | ||
terminal goal condition was never triggered. | ||
|
||
* Download and extract Replica scenes in the root folder of `pvr`. | ||
> WARNING! The dataset is very large! | ||
``` | ||
sudo apt-get install pigz | ||
git clone https://github.com/facebookresearch/Replica-Dataset.git | ||
cd Replica-Dataset | ||
./download.sh replica-path | ||
``` | ||
|
||
If you have already downloaded it somewhere else, just create a symbolic link: | ||
``` | ||
ln -s path/to/Replica-Dataset Replica-Dataset | ||
``` | ||
|
||
## How to Run Experiments | ||
There are three main scripts to run behavioral cloning: | ||
* `main_bc_1.py` loads raw trajectories saved as pickles, passes observations (images) | ||
through the embedding, and then learns on the embedded observations. | ||
* `main_bc_2.py` directly loads embedded observations that have already been passed | ||
through the embedding, in order to save time. | ||
* `main_bc_finetune.py` is used to finetune the random PVR. | ||
|
||
For more details on how to generate trajectories and pickles, see the README in | ||
the `behavioral_cloning` folder. |
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
Place / save here all data used for behavioral cloning (BC), i.e., all | ||
optimal trajectories and the corresponding embedded observations (PVRs). | ||
|
||
## Scripts Description | ||
|
||
* `save_opt_trajectories.py` generates optimal trajectories using Habitat's native solver. | ||
Data is saved as `.pickle` and takes a lot of space. | ||
* `save_opt_trajectories_png.py` does the same, but image observations are saved as `.png` to save space. | ||
* `save_opt_trajectories_jpeg.py` does the same, but image observations are saved as `.jpeg`. | ||
This is a lossy format compared to `.png`, but it is also the format used by ImageNet. | ||
Used to collect data for pre-training vision models. | ||
* `save_embedded_obs.py` passes all images from the aforementioned trajectories | ||
through embedding models. This way, when we run BC we just load these files | ||
rather than raw trajectories (unless embeddings are trained / fine-tuned). | ||
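
As a rough sketch of what "loading these files" could look like, the snippet below reads one embedded-observation file. It assumes the `<env>_<embedding_name>.pickle` naming and the dictionary keys written by `save_embedded_obs.py` in this commit; `load_embedded` is a hypothetical helper for illustration, not a function from the repository.

```python
import os
import pickle


def load_embedded(data_path, env, embedding_name):
    """Load `<data_path>/<env>_<embedding_name>.pickle` (hypothetical helper)."""
    file_name = os.path.join(data_path, env + '_' + embedding_name + '.pickle')
    with open(file_name, 'rb') as handle:
        data = pickle.load(handle)

    # Keys that save_embedded_obs.py stores in the saved dictionary
    expected = ('obs', 'action', 'reward', 'done', 'true_state')
    missing = [k for k in expected if k not in data]
    if missing:
        raise KeyError('missing keys: %s' % missing)
    return data
```

Checking the keys up front gives an early, readable error if a file was produced by a different script or an older version.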
|
||
To recap, the steps to run BC are: | ||
1. For every scene you want to do BC on, generate optimal trajectories using | ||
`save_opt_trajectories.py` or `save_opt_trajectories_png.py`. | ||
2. Pass the images through the desired embedding using `save_embedded_obs.py`. | ||
3. Run one of the `main_bc_*.py` scripts (in root folder), passing the right scenes and embedding. | ||
|
||
## Example | ||
|
||
From the root folder, run: | ||
``` | ||
python behavioral_cloning/save_opt_trajectories.py --env=HabitatImageNav-apartment_0 | ||
python behavioral_cloning/save_opt_trajectories.py --env=HabitatImageNav-frl_apartment_0 | ||
python behavioral_cloning/save_embedded_obs.py --env=HabitatImageNav-apartment_0 --embedding_name=resnet50 --source=pickle | ||
python behavioral_cloning/save_embedded_obs.py --env=HabitatImageNav-frl_apartment_0 --embedding_name=resnet50 --source=pickle | ||
python main_bc_2.py --env=HabitatImageNav-apartment_0,HabitatImageNav-frl_apartment_0 --embedding_name=resnet50 | ||
``` |
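
For more than a couple of scenes, the same three steps could be generated programmatically. The following is a minimal sketch, assuming only the script names and flags shown in the example above; `bc_commands` and `run_all` are hypothetical helpers, not part of the repository.

```python
import subprocess


def bc_commands(envs, embedding_name, source='pickle'):
    """Build the command list for the three BC steps (hypothetical helper)."""
    cmds = []
    # Step 1: generate optimal trajectories for every scene
    for env in envs:
        cmds.append(['python', 'behavioral_cloning/save_opt_trajectories.py',
                     '--env=' + env])
    # Step 2: pass the saved images through the chosen embedding
    for env in envs:
        cmds.append(['python', 'behavioral_cloning/save_embedded_obs.py',
                     '--env=' + env,
                     '--embedding_name=' + embedding_name,
                     '--source=' + source])
    # Step 3: run BC on all scenes at once (comma-separated list of envs)
    cmds.append(['python', 'main_bc_2.py',
                 '--env=' + ','.join(envs),
                 '--embedding_name=' + embedding_name])
    return cmds


def run_all(cmds):
    """Run the commands in order, stopping at the first failure."""
    for cmd in cmds:
        subprocess.run(cmd, check=True)
```

Calling `run_all(bc_commands(['HabitatImageNav-apartment_0', 'HabitatImageNav-frl_apartment_0'], 'resnet50'))` from the root folder would issue the same five commands as the example above, in the same order.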
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
python behavioral_cloning/save_opt_trajectories.py --n_trajectories=10000 --env=HabitatImageNav-apartment_0 | ||
python behavioral_cloning/save_opt_trajectories.py --n_trajectories=10000 --env=HabitatImageNav-frl_apartment_0 | ||
python behavioral_cloning/save_opt_trajectories.py --n_trajectories=10000 --env=HabitatImageNav-room_0 | ||
python behavioral_cloning/save_opt_trajectories.py --n_trajectories=10000 --env=HabitatImageNav-hotel_0 | ||
python behavioral_cloning/save_opt_trajectories.py --n_trajectories=10000 --env=HabitatImageNav-office_0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/bin/bash | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-apartment_0 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-apartment_1 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-apartment_2 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_0 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_1 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_2 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_3 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_4 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-frl_apartment_5 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-room_0 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-room_1 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-room_2 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-hotel_0 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-office_0 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-office_1 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-office_2 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-office_3 | ||
python behavioral_cloning/save_opt_trajectories_jpeg.py --n_trajectories=20000 --env=HabitatImageNav-office_4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
python behavioral_cloning/save_opt_trajectories_png.py --n_trajectories=10000 --env=HabitatImageNav-apartment_0 | ||
python behavioral_cloning/save_opt_trajectories_png.py --n_trajectories=10000 --env=HabitatImageNav-frl_apartment_0 | ||
python behavioral_cloning/save_opt_trajectories_png.py --n_trajectories=10000 --env=HabitatImageNav-room_0 | ||
python behavioral_cloning/save_opt_trajectories_png.py --n_trajectories=10000 --env=HabitatImageNav-hotel_0 | ||
python behavioral_cloning/save_opt_trajectories_png.py --n_trajectories=10000 --env=HabitatImageNav-office_0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
import re | ||
import numpy as np | ||
import torch | ||
import itertools | ||
import pickle | ||
from tqdm import tqdm | ||
from torch.nn import functional as F | ||
from torch import nn | ||
import random | ||
import cv2 | ||
|
||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
|
||
from src.embeddings import EmbeddingNet | ||
from src.arguments import parser | ||
|
||
parser.add_argument('--n_trajectories', type=int, default=-1) | ||
parser.add_argument('--source', type=str, default='png', choices=['png', 'pickle']) | ||
|
||
|
||
def read_habitat_data_from_pickle(data_path, n_trajectories=-1): | ||
    print('loading %s ...' % data_path) | ||
|
||
    # Use a context manager so the file handle is closed after loading | ||
    with open(data_path + '.pickle', 'rb') as handle: | ||
        data = pickle.load(handle) | ||
    if n_trajectories == -1: | ||
        n_trajectories = len(data['reward']) | ||
|
||
    # Merge trajectories | ||
    data['obs'] = np.concatenate(data['obs'][:n_trajectories]) | ||
    data['action'] = np.concatenate(data['action'][:n_trajectories]) | ||
    data['reward'] = np.concatenate(data['reward'][:n_trajectories]) | ||
    data['done'] = np.concatenate(data['done'][:n_trajectories]) | ||
    data['true_state'] = np.concatenate(data['true_state'][:n_trajectories]) | ||
|
||
    n_samples = len(data['reward']) | ||
    print(' ', '%d trajectories for a total of %d samples' % (n_trajectories, n_samples)) | ||
    print(' ', 'avg. return is', data['reward'].sum() / n_trajectories) | ||
|
||
    return data | ||
|
||
|
||
def read_habitat_data_from_png(data_path, model=None, n_trajectories=-1): | ||
    print('loading %s ...' % data_path) | ||
    data = dict(obs=[], action=[], reward=[], done=[], true_state=[], png=[]) | ||
|
||
    if n_trajectories == -1: | ||
        # n_trajectories = len([f for f in os.listdir(data_path) if f.endswith('_0.png')]) # too slow | ||
        n_trajectories = 100000 | ||
|
||
    # Merge trajectories | ||
    # `except Exception` (rather than a bare `except`) so that Ctrl-C still interrupts loading | ||
    for t in tqdm(range(n_trajectories)): | ||
        try: | ||
            with open(os.path.join(data_path, str(t) + '.pickle'), 'rb') as handle: | ||
                tmp = pickle.load(handle) | ||
            for k in data.keys(): | ||
                try: | ||
                    data[k].append(tmp[k]) | ||
                except Exception: | ||
                    pass | ||
            goal = cv2.imread(os.path.join(data_path, str(t) + '_goal' + '.png')) | ||
            if model is not None: | ||
                goal = model(torch.from_numpy(goal[None,:])).reshape(-1,) | ||
        except Exception: | ||
            break  # no more trajectories on disk | ||
        for s in range(500): # 500 is the max step per trajectory according to Habitat's YAML config | ||
            try: | ||
                obs = cv2.imread(os.path.join(data_path, str(t) + '_' + str(s) + '.png')) | ||
                if model is not None: | ||
                    obs = model(torch.from_numpy(obs[None,:])).reshape(-1,) | ||
                data['obs'].append(np.concatenate((obs, goal), -1)) | ||
                data['png'] += [os.path.join(data_path, str(t) + '_' + str(s)) + '.png'] | ||
            except Exception: | ||
                break  # no more steps in this trajectory | ||
|
||
    n_trajectories = t | ||
    data['obs'] = np.stack(data['obs']) | ||
    data['action'] = np.concatenate(data['action']) | ||
    data['reward'] = np.concatenate(data['reward']) | ||
    data['done'] = np.concatenate(data['done']) | ||
    data['true_state'] = np.concatenate(data['true_state']) | ||
|
||
    n_samples = len(data['reward']) | ||
    print(' ', '%d trajectories for a total of %d samples' % (n_trajectories, n_samples)) | ||
    print(' ', 'avg. return is', data['reward'].sum() / n_trajectories) | ||
|
||
    return data | ||
|
||
|
||
def run(flags): | ||
    save_name = os.path.join(flags.data_path, | ||
                             flags.env + '_' + | ||
                             flags.embedding_name + '.pickle') | ||
    # Skip if the embedded observations have already been saved | ||
    if os.path.isfile(save_name): | ||
        return | ||
|
||
    # Fix seeds | ||
    torch.manual_seed(flags.run_id) | ||
    torch.cuda.manual_seed(flags.run_id) | ||
    np.random.seed(flags.run_id) | ||
    random.seed(flags.run_id) | ||
|
||
    # Device setup | ||
    flags.device = None | ||
    if torch.cuda.is_available() and not flags.disable_cuda: | ||
        print('Using CUDA.') | ||
        flags.device = torch.device('cuda') | ||
    else: | ||
        print('Not using CUDA.') | ||
        flags.device = torch.device('cpu') | ||
|
||
    # Init models, env, optimizer, ... | ||
    embedding_model = EmbeddingNet(flags.embedding_name, | ||
                                   in_channels=3, | ||
                                   pretrained=flags.pretrained_embedding, | ||
                                   train=flags.train_embedding, | ||
                                   disable_cuda=flags.disable_cuda) # Always on GPU, unless CUDA is disabled | ||
|
||
    # Save model that will be used in main_bc | ||
    emb_path = os.path.join(flags.data_path, flags.embedding_name) | ||
    if flags.embedding_name == 'random': | ||
        emb_path += '_' + str(flags.run_id) | ||
    torch.save({ | ||
        'embedding_model_state_dict': embedding_model.state_dict(), | ||
    }, emb_path + '.tar') | ||
|
||
    print('=== Loading trajectories ===') | ||
|
||
    if flags.source == 'png': | ||
        data = read_habitat_data_from_png( | ||
            os.path.join(flags.data_path, flags.env), | ||
            embedding_model, | ||
            flags.n_trajectories | ||
        ) | ||
|
||
    if flags.source == 'pickle': | ||
        # Forward --n_trajectories here too; the default (-1) still loads everything | ||
        data = read_habitat_data_from_pickle( | ||
            os.path.join(flags.data_path, flags.env), | ||
            flags.n_trajectories | ||
        ) | ||
|
||
    print(' ', 'passing observations through embedding model') | ||
    n_samples = data['obs'].shape[0] | ||
    n_frames = max(data['obs'].shape[3] // 3, 1) | ||
    obs_scene = [] | ||
    for i in tqdm(range(0, n_samples, flags.batch_size)): # To avoid OutOfMemory we loop through mini-batches | ||
        o = data['obs'][i:i+flags.batch_size] | ||
        o = np.concatenate(np.split(o, n_frames, axis=3), axis=0) # (N, H, W, n_frames * 3) -> (N * n_frames, H, W, 3) | ||
        o = embedding_model(torch.from_numpy(o)) # (N * n_frames, O) | ||
        o = np.concatenate(np.split(o, n_frames, axis=0), axis=-1) # (N, O * n_frames) | ||
        obs_scene.append(o) | ||
    obs_scene = np.concatenate(obs_scene)[:n_samples] | ||
|
||
    obs = np.array(obs_scene) | ||
    true_state = data['true_state'][:n_samples] | ||
    action = data['action'][:n_samples] | ||
    reward = data['reward'][:n_samples] | ||
    done = data['done'][:n_samples] | ||
|
||
    data = dict(obs=obs, action=action, reward=reward, done=done, true_state=true_state) | ||
|
||
    n_samples = len(data['reward']) | ||
    assert n_samples > 0, 'no data found' | ||
    print(' ', 'total number of samples', n_samples) | ||
|
||
    with open(save_name, 'wb') as handle: | ||
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
||
|
||
if __name__ == '__main__': | ||
    flags = parser.parse_args() | ||
    run(flags) |
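
The mini-batch loop in `run` splits channel-stacked frames, embeds each frame separately, and then concatenates the per-frame embeddings. The sketch below reproduces only that reshaping logic in plain NumPy so the shapes can be checked in isolation. It assumes observations shaped `(N, H, W, n_frames * 3)`; `fake_embedding` is a hypothetical stand-in for `EmbeddingNet` (it just averages each channel), and `embed_stacked_frames` is an illustrative name, not a function from the repository.

```python
import numpy as np


def fake_embedding(images):
    """Hypothetical stand-in for the embedding model: (B, H, W, 3) -> (B, 3)."""
    return images.mean(axis=(1, 2))


def embed_stacked_frames(obs, embed, batch_size):
    """Embed channel-stacked frames batch by batch: (N, H, W, n_frames * 3) -> (N, O * n_frames)."""
    n_samples = obs.shape[0]
    n_frames = max(obs.shape[3] // 3, 1)
    out = []
    for i in range(0, n_samples, batch_size):
        o = obs[i:i + batch_size]
        # Move the frames from the channel axis to the batch axis
        o = np.concatenate(np.split(o, n_frames, axis=3), axis=0)   # (B * n_frames, H, W, 3)
        o = embed(o)                                                # (B * n_frames, O)
        # Bring the frames back next to each other along the feature axis
        o = np.concatenate(np.split(o, n_frames, axis=0), axis=-1)  # (B, O * n_frames)
        out.append(o)
    return np.concatenate(out)
```

For example, an input of shape `(5, 4, 4, 6)` (two stacked RGB frames) yields an output of shape `(5, 6)` with this stand-in, with the features of the first frame followed by those of the second.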