From 1ce05a11e27eee679cddbec423a0d4b811cfcf39 Mon Sep 17 00:00:00 2001 From: sparisi Date: Sun, 24 Apr 2022 14:42:24 -0400 Subject: [PATCH] cleanup --- README.md | 4 +- src/algos/__init__.py | 5 - src/algos/torchbeast.py | 233 ---------------------------------- src/arguments.py | 34 ----- src/core/__init__.py | 5 - src/core/file_writer.py | 180 -------------------------- src/core/prof.py | 69 ---------- src/core/vtrace.py | 125 ------------------ src/init_models_and_states.py | 101 --------------- src/losses.py | 46 ------- src/utils.py | 153 ---------------------- 11 files changed, 2 insertions(+), 953 deletions(-) delete mode 100644 src/algos/__init__.py delete mode 100644 src/algos/torchbeast.py delete mode 100644 src/core/__init__.py delete mode 100644 src/core/file_writer.py delete mode 100644 src/core/prof.py delete mode 100644 src/core/vtrace.py delete mode 100644 src/init_models_and_states.py delete mode 100644 src/losses.py delete mode 100644 src/utils.py diff --git a/README.md b/README.md index 8fdf86e..12e09da 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,6 @@ Implementation of **PVR for Control**, as presented in [The (Un)Surprising Effectiveness of Pre-Trained Vision Models for Control](https://arxiv.org/abs/2203.03580). -Part of the code was built on the [RIDE repository](https://github.com/facebookresearch/impact-driven-exploration). - ## Codebase Installation ``` conda create -n pvr python=3.8 @@ -48,3 +46,5 @@ through the embedding, in order to save time. For more details on how to generate trajectories and pickles, see the README in the `behavioral_cloning` folder. + +Pre-trained models can be downloaded [here](xxx). diff --git a/src/algos/__init__.py b/src/algos/__init__.py deleted file mode 100644 index 1f7739a..0000000 --- a/src/algos/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/src/algos/torchbeast.py b/src/algos/torchbeast.py deleted file mode 100644 index 76b1ec2..0000000 --- a/src/algos/torchbeast.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import threading -import time -import timeit -import pprint -import itertools -import numpy as np - -import torch -from torch import multiprocessing as mp -from torch import nn -from torch.nn import functional as F - -from src.core import file_writer -from src.core import prof -from src.core import vtrace - -import src.models as models -import src.losses as losses - -from src.utils import get_batch, log, create_buffers, act -from src.init_models_and_states import init_models_and_states - - -def learn(actor_model, - learner_model, - batch, - initial_agent_state, - optimizer, - scheduler, - flags, - lock=threading.Lock()): - """Performs a learning (optimization) step.""" - with lock: - learner_outputs, unused_state = learner_model(batch, initial_agent_state) - - bootstrap_value = learner_outputs['baseline'][-1] - - batch = {key: tensor[1:] for key, tensor in batch.items()} - learner_outputs = { - key: tensor[:-1] - for key, tensor in learner_outputs.items() - } - - rewards = batch['reward'] - if flags.clip_reward: - clipped_rewards = torch.clamp(rewards, -1, 1) - else: - clipped_rewards = rewards / flags.max_reward - - discounts = (1 - batch['done'].float()).abs() * flags.discounting - - vtrace_returns = vtrace.from_logits( - behavior_policy_logits=batch['policy_logits'], - target_policy_logits=learner_outputs['policy_logits'], - actions=batch['action'], - discounts=discounts, - rewards=clipped_rewards, - values=learner_outputs['baseline'], - bootstrap_value=bootstrap_value - ) - - pg_loss = losses.compute_policy_gradient_loss( - learner_outputs['policy_logits'], - batch['action'], - vtrace_returns.pg_advantages - ) - baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( - vtrace_returns.vs - learner_outputs['baseline'] - ) - entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( - learner_outputs['policy_logits'] - ) - - total_loss = pg_loss + baseline_loss + entropy_loss - - episode_returns = batch['episode_return'][batch['done']] - episode_successes = batch['episode_success'][batch['done']] - episode_lengths = batch['episode_step'][batch['done']] - stats = { - 'total_episodes': torch.sum(batch['done'].float()).item(), - 'mean_episode_length': torch.mean(episode_lengths.float()).item(), - 'mean_episode_return': torch.mean(episode_returns).item(), - 'mean_rewards': torch.mean(batch['reward']).item(), - 'max_rewards': torch.max(batch['reward']).item(), - 'min_rewards': torch.min(batch['reward']).item(), - 'mean_episode_success': torch.mean(episode_successes).item(), - 'total_loss': total_loss.item(), - 'pg_loss': pg_loss.item(), - 'baseline_loss': baseline_loss.item(), - 'entropy_loss': entropy_loss.item(), - } - - scheduler.step() - optimizer.zero_grad() - total_loss.backward() - nn.utils.clip_grad_norm_(learner_model.parameters(), flags.max_grad_norm) - optimizer.step() - - actor_model.load_state_dict(learner_model.state_dict()) - return stats - - -def train(flags): - if flags.xpid is None: - flags.xpid = 'vanilla-%s' % time.strftime('%Y%m%d-%H%M%S') - plogger = file_writer.FileWriter( - xpid=flags.xpid, - xp_args=flags.__dict__, - rootdir=flags.savedir, - ) - - env_iterator = itertools.cycle(flags.env.split(',')) - - flags.env = next(env_iterator) - models_and_states = init_models_and_states(flags) - - actor_model = models_and_states['actor_model'] - learner_model = models_and_states['learner_model'] - embedding_model = models_and_states['embedding_model'] - initial_agent_state_buffers = models_and_states['initial_agent_state_buffers'] - 
learner_model_optimizer = models_and_states['learner_model_optimizer'] - scheduler = models_and_states['scheduler'] - buffers = models_and_states['buffers'] - - actor_processes = [] - ctx = mp.get_context(flags.mp_start) - free_queue = ctx.SimpleQueue() - full_queue = ctx.SimpleQueue() - - for i in range(flags.num_actors): - flags.env = next(env_iterator) - actor = ctx.Process( - target=act, - args=(i, free_queue, full_queue, actor_model, embedding_model, - buffers, initial_agent_state_buffers, flags)) - actor.start() - actor_processes.append(actor) - - frames, stats = 0, {} - - def batch_and_learn(i, lock=threading.Lock()): - """Thread target for the learning process.""" - nonlocal frames, stats - timings = prof.Timings() - while frames < flags.total_frames: - timings.reset() - batch, agent_state = get_batch(free_queue, full_queue, buffers, - initial_agent_state_buffers, flags, timings) - stats = learn(actor_model, learner_model, - batch, agent_state, - learner_model_optimizer, - scheduler, flags) - timings.time('learn') - with lock: - to_log = dict(frames=frames) - to_log.update(stats) - plogger.log(to_log) - frames += flags.unroll_length * flags.batch_size - - if i == 0: - log.info('Batch and learn: %s', timings.summary()) - - for m in range(flags.num_buffers): - free_queue.put(m) - - threads = [] - for i in range(flags.num_threads): - thread = threading.Thread( - target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) - thread.start() - threads.append(thread) - - def checkpoint(frames): - if flags.disable_checkpoint: - return - checkpointpath = os.path.expandvars(os.path.expanduser( - '%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar'))) - log.info('Saving checkpoint to %s', checkpointpath) - torch.save({ - 'actor_model_state_dict': actor_model.state_dict(), - 'embedding_model_state_dict': embedding_model.state_dict(), - 'learner_model_optimizer_state_dict': learner_model_optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'flags': vars(flags), - }, checkpointpath) - - timer = timeit.default_timer - - try: - last_checkpoint_time = timer() - while frames < flags.total_frames: - start_frames = frames - start_time = timer() - time.sleep(5) - - if timer() - last_checkpoint_time > flags.save_interval * 60: - checkpoint(frames) - last_checkpoint_time = timer() - - fps = (frames - start_frames) / (timer() - start_time) - if stats.get('episode_returns', None): - mean_return = 'Return per episode: %.1f. ' % stats[ - 'mean_episode_return'] - else: - mean_return = '' - total_loss = stats.get('total_loss', float('inf')) - log.info('After %i frames: loss %f @ %.1f fps. 
%sStats:\n%s', - frames, total_loss, fps, mean_return, - pprint.pformat(stats)) - - except KeyboardInterrupt: - return - - else: - for thread in threads: - thread.join() - log.info('Learning finished after %d frames.', frames) - - finally: - for _ in range(flags.num_actors): - free_queue.put(None) - for actor in actor_processes: - actor.join(timeout=1) - - checkpoint(frames) - plogger.close() diff --git a/src/arguments.py b/src/arguments.py index 004f120..448d80e 100644 --- a/src/arguments.py +++ b/src/arguments.py @@ -40,50 +40,20 @@ (instead of a different random seed since torchbeast does not accept this).') parser.add_argument('--seed', default=1, type=int, help='Random seed.') -parser.add_argument('--save_interval', default=10, type=int, - help='Time interval (in minutes) at which to save the model.') -parser.add_argument('--checkpoint_num_frames', default=10000000, type=int, - help='Number of frames for checkpoint to load.') -parser.add_argument('--checkpoint', default=None, - help='Path to model.tar for loading checkpoint from past run.') # Training settings. -parser.add_argument('--disable_checkpoint', action='store_true', - help='Disable saving checkpoint.') -parser.add_argument('--savedir', default='logs', - help='Root dir where experiment data will be saved.') -parser.add_argument('--num_actors', default=40, type=int, - help='Number of actors.') parser.add_argument('--total_frames', default=50000000, type=int, help='Total environment frames to train for.') parser.add_argument('--batch_size', default=32, type=int, help='Learner batch size.') parser.add_argument('--unroll_length', default=100, type=int, help='The unroll length (time dimension).') -parser.add_argument('--queue_timeout', default=1, type=int, - help='Error timeout for queue.') -parser.add_argument('--num_buffers', default=40, type=int, - help='Number of shared-memory buffers.') -parser.add_argument('--num_threads', default=4, type=int, - help='Number learner threads.') parser.add_argument('--mp_start', default='spawn', type=str, help='Start method of multiprocesses. \ Depending on your machine, there can be problems between CUDA \ with some environments. To avoid them, use `spawn`.') parser.add_argument('--disable_cuda', action='store_true', help='Disable CUDA.') -parser.add_argument('--clip_reward', action='store_true', - help='If True, rewards are clipped in [-1,1].') -parser.add_argument('--max_reward', default=1.0, type=float, - help='To normalize rewards (use 1 to keep default rewards).') - -# Loss settings. -parser.add_argument('--entropy_cost', default=0.0005, type=float, - help='Entropy cost/multiplier.') -parser.add_argument('--baseline_cost', default=0.05, type=float, - help='Baseline cost/multiplier.') -parser.add_argument('--discounting', default=0.99, type=float, - help='Discounting factor.') # Optimizer settings. parser.add_argument('--learning_rate', default=0.0001, type=float, @@ -96,7 +66,3 @@ help='RMSProp epsilon.') parser.add_argument('--max_grad_norm', default=40., type=float, help='Max norm of gradients.') - -# Training Models. -parser.add_argument('--algorithm_name', default='vanilla', - help='Algorithm used for training the agent.') diff --git a/src/core/__init__.py b/src/core/__init__.py deleted file mode 100644 index 1f7739a..0000000 --- a/src/core/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
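For reference, the objective assembled by the removed `learn()` step reduces to three terms defined in the (also removed) `src/losses.py`. Below is a self-contained sketch of that combination; tensor shapes and the default cost weights follow the deleted code, while the random tensors are illustrative placeholders rather than project API.

```
# Reference sketch (not part of the patch): the three loss terms combined by the
# removed learn() step. Logits are [T, B, num_actions]; actions and advantages
# are [T, B]. The weights stand in for the removed --baseline_cost and
# --entropy_cost flags (defaults 0.05 and 0.0005).
import torch
import torch.nn.functional as F


def policy_gradient_loss(logits, actions, advantages):
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction='none').view_as(advantages)
    # Advantages are treated as constants, as in the deleted helper.
    return torch.sum(torch.mean(cross_entropy * advantages.detach(), dim=1))


def baseline_loss(advantages):
    return 0.5 * torch.sum(torch.mean(advantages ** 2, dim=1))


def entropy_loss(logits):
    policy = F.softmax(logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1)
    return -torch.sum(torch.mean(entropy_per_timestep, dim=1))


T, B, A = 100, 32, 4  # unroll length, batch size, number of actions (placeholders)
logits = torch.randn(T, B, A)
actions = torch.randint(A, (T, B))
advantages = torch.randn(T, B)
total_loss = (policy_gradient_loss(logits, actions, advantages)
              + 0.05 * baseline_loss(advantages)
              + 0.0005 * entropy_loss(logits))
```

The only behavioural difference from the deleted helpers is that `advantages.detach()` replaces the in-place `advantages.requires_grad = False`.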
diff --git a/src/core/file_writer.py b/src/core/file_writer.py deleted file mode 100644 index f9b29ae..0000000 --- a/src/core/file_writer.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import datetime -import csv -import json -import logging -import os -import time -from typing import Dict - -import git - - -def gather_metadata() -> Dict: - date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') - # gathering git metadata - try: - repo = git.Repo(search_parent_directories=True) - git_sha = repo.commit().hexsha - git_data = dict( - commit=git_sha, - branch=repo.active_branch.name, - is_dirty=repo.is_dirty(), - path=repo.git_dir, - ) - except:# git.InvalidGitRepositoryError: - git_data = None - # gathering slurm metadata - if 'SLURM_JOB_ID' in os.environ: - slurm_env_keys = [k for k in os.environ if k.startswith('SLURM')] - slurm_data = {} - for k in slurm_env_keys: - d_key = k.replace('SLURM_', '').replace('SLURMD_', '').lower() - slurm_data[d_key] = os.environ[k] - else: - slurm_data = None - return dict( - date_start=date_start, - date_end=None, - successful=False, - git=git_data, - slurm=slurm_data, - env=os.environ.copy(), - ) - - -class FileWriter: - def __init__(self, - xpid: str = None, - xp_args: dict = None, - rootdir: str = '~/palaas'): - if not xpid: - # make unique id - xpid = '{proc}_{unixtime}'.format( - proc=os.getpid(), unixtime=int(time.time())) - self.xpid = xpid - self._tick = 0 - - # metadata gathering - if xp_args is None: - xp_args = {} - self.metadata = gather_metadata() - # we need to copy the args, otherwise when we close the file writer - # (and rewrite the args) we might have non-serializable objects (or - # other nasty stuff). - self.metadata['args'] = copy.deepcopy(xp_args) - self.metadata['xpid'] = self.xpid - - formatter = logging.Formatter('%(message)s') - self._logger = logging.getLogger('palaas/out') - - # to stdout handler - shandle = logging.StreamHandler() - shandle.setFormatter(formatter) - self._logger.addHandler(shandle) - self._logger.setLevel(logging.INFO) - - rootdir = os.path.expandvars(os.path.expanduser(rootdir)) - # to file handler - self.basepath = os.path.join(rootdir, self.xpid) - - if not os.path.exists(self.basepath): - self._logger.info('Creating log directory: %s', self.basepath) - os.makedirs(self.basepath, exist_ok=True) - else: - self._logger.info('Found log directory: %s', self.basepath) - - # NOTE: remove latest because it creates errors when running on slurm - # multiple jobs trying to write to latest but cannot find it - # Add 'latest' as symlink unless it exists and is no symlink. - # symlink = os.path.join(rootdir, 'latest') - # if os.path.islink(symlink): - # os.remove(symlink) - # if not os.path.exists(symlink): - # os.symlink(self.basepath, symlink) - # self._logger.info('Symlinked log directory: %s', symlink) - - self.paths = dict( - msg='{base}/out.log'.format(base=self.basepath), - logs='{base}/logs.csv'.format(base=self.basepath), - fields='{base}/fields.csv'.format(base=self.basepath), - meta='{base}/meta.json'.format(base=self.basepath), - ) - - self._logger.info('Saving arguments to %s', self.paths['meta']) - if os.path.exists(self.paths['meta']): - self._logger.warning('Path to meta file already exists. 
' - 'Not overriding meta.') - else: - self._save_metadata() - - self._logger.info('Saving messages to %s', self.paths['msg']) - if os.path.exists(self.paths['msg']): - self._logger.warning('Path to message file already exists. ' - 'New data will be appended.') - - fhandle = logging.FileHandler(self.paths['msg']) - fhandle.setFormatter(formatter) - self._logger.addHandler(fhandle) - - self._logger.info('Saving logs data to %s', self.paths['logs']) - self._logger.info('Saving logs\' fields to %s', self.paths['fields']) - if os.path.exists(self.paths['logs']): - self._logger.warning('Path to log file already exists. ' - # 'New data will be appended.') - 'Old data will be deleted.') - os.remove(self.paths['logs']) - with open(self.paths['fields'], 'r') as csvfile: - reader = csv.reader(csvfile) - self.fieldnames = list(reader)[0] - else: - self.fieldnames = ['_tick', '_time'] - - def log(self, to_log: Dict, tick: int = None, - verbose: bool = False) -> None: - if tick is not None: - raise NotImplementedError - else: - to_log['_tick'] = self._tick - self._tick += 1 - to_log['_time'] = time.time() - - old_len = len(self.fieldnames) - for k in to_log: - if k not in self.fieldnames: - self.fieldnames.append(k) - if old_len != len(self.fieldnames): - with open(self.paths['fields'], 'w') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(self.fieldnames) - self._logger.info('Updated log fields: %s', self.fieldnames) - - if to_log['_tick'] == 0: - # print("\ncreating logs file ") - with open(self.paths['logs'], 'a') as f: - f.write('# %s\n' % ','.join(self.fieldnames)) - - if verbose: - self._logger.info('LOG | %s', ', '.join( - ['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)])) - - with open(self.paths['logs'], 'a') as f: - writer = csv.DictWriter(f, fieldnames=self.fieldnames) - writer.writerow(to_log) - # print("\nadded to log file") - - def close(self, successful: bool = True) -> None: - self.metadata['date_end'] = datetime.datetime.now().strftime( - '%Y-%m-%d %H:%M:%S.%f') - self.metadata['successful'] = successful - self._save_metadata() - - def _save_metadata(self) -> None: - with open(self.paths['meta'], 'w') as jsonfile: - json.dump(self.metadata, jsonfile, indent=4, sort_keys=True) diff --git a/src/core/prof.py b/src/core/prof.py deleted file mode 100644 index 82a5689..0000000 --- a/src/core/prof.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -"""Naive profiling using timeit.""" - -import collections -import timeit - - -class Timings: - """Not thread-safe.""" - - def __init__(self): - self._means = collections.defaultdict(int) - self._vars = collections.defaultdict(int) - self._counts = collections.defaultdict(int) - self.reset() - - def reset(self): - self.last_time = timeit.default_timer() - - def time(self, name): - """Save an update for event `name`. - - Nerd alarm: We could just store a - collections.defaultdict(list) - and compute means and standard deviations at the end. But thanks to the - clever math in Sutton-Barto - (http://www.incompleteideas.net/book/first/ebook/node19.html) and - https://math.stackexchange.com/a/103025/5051 we can update both the - means and the stds online. O(1) FTW! 
- """ - now = timeit.default_timer() - x = now - self.last_time - self.last_time = now - - n = self._counts[name] - - mean = self._means[name] + (x - self._means[name]) / (n + 1) - var = (n * self._vars[name] + n * (self._means[name] - mean)**2 + - (x - mean)**2) / (n + 1) - - self._means[name] = mean - self._vars[name] = var - self._counts[name] += 1 - - def means(self): - return self._means - - def vars(self): - return self._vars - - def stds(self): - return {k: v**0.5 for k, v in self._vars.items()} - - def summary(self, prefix=''): - means = self.means() - stds = self.stds() - total = sum(means.values()) - - result = prefix - for k in sorted(means, key=means.get, reverse=True): - result += f'\n %s: %.6fms +- %.6fms (%.2f%%) ' % ( - k, 1000 * means[k], 1000 * stds[k], 100 * means[k] / total) - result += '\nTotal: %.6fms' % (1000 * total) - return result diff --git a/src/core/vtrace.py b/src/core/vtrace.py deleted file mode 100644 index 7ee34dc..0000000 --- a/src/core/vtrace.py +++ /dev/null @@ -1,125 +0,0 @@ -# This file taken from -# https://github.com/deepmind/scalable_agent/blob/ -# cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py -# and modified. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions to compute V-trace off-policy actor critic targets. - -For details and theory see: - -"IMPALA: Scalable Distributed Deep-RL with -Importance Weighted Actor-Learner Architectures" -by Espeholt, Soyer, Munos et al. - -See https://arxiv.org/abs/1802.01561 for the full paper. 
-""" - -import collections - -import torch -import torch.nn.functional as F - -VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [ - 'vs', 'pg_advantages', 'log_rhos', 'behavior_action_log_probs', - 'target_action_log_probs' -]) - -VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages') - - -def action_log_probs(policy_logits, actions): - return -F.nll_loss( - F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1), - torch.flatten(actions, 0, 1), - reduction='none').view_as(actions) - - -def from_logits(behavior_policy_logits, - target_policy_logits, - actions, - discounts, - rewards, - values, - bootstrap_value, - clip_rho_threshold=1.0, - clip_pg_rho_threshold=1.0): - """V-trace for softmax policies.""" - - target_action_log_probs = action_log_probs(target_policy_logits, actions) - behavior_action_log_probs = action_log_probs(behavior_policy_logits, - actions) - log_rhos = target_action_log_probs - behavior_action_log_probs - vtrace_returns = from_importance_weights( - log_rhos=log_rhos, - discounts=discounts, - rewards=rewards, - values=values, - bootstrap_value=bootstrap_value, - clip_rho_threshold=clip_rho_threshold, - clip_pg_rho_threshold=clip_pg_rho_threshold) - return VTraceFromLogitsReturns( - log_rhos=log_rhos, - behavior_action_log_probs=behavior_action_log_probs, - target_action_log_probs=target_action_log_probs, - **vtrace_returns._asdict()) - - -@torch.no_grad() -def from_importance_weights(log_rhos, - discounts, - rewards, - values, - bootstrap_value, - clip_rho_threshold=1.0, - clip_pg_rho_threshold=1.0): - """V-trace from log importance weights.""" - with torch.no_grad(): - rhos = torch.exp(log_rhos) - if clip_rho_threshold is not None: - clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) - else: - clipped_rhos = rhos - - cs = torch.clamp(rhos, max=1.0) - # Append bootstrapped value to get [v1, ..., v_t+1] - values_t_plus_1 = torch.cat( - [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0) - deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) - - acc = torch.zeros_like(bootstrap_value) - result = [] - for t in range(discounts.shape[0] - 1, -1, -1): - acc = deltas[t] + discounts[t] * cs[t] * acc - result.append(acc) - result.reverse() - vs_minus_v_xs = torch.stack(result) - - # Add V(x_s) to get v_s. - vs = torch.add(vs_minus_v_xs, values) - - # Advantage for policy gradient. - vs_t_plus_1 = torch.cat( - [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0) - if clip_pg_rho_threshold is not None: - clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) - else: - clipped_pg_rhos = rhos - pg_advantages = (clipped_pg_rhos * - (rewards + discounts * vs_t_plus_1 - values)) - - # Make sure no gradients backpropagated through the returned values. 
- return VTraceReturns(vs=vs, pg_advantages=pg_advantages) diff --git a/src/init_models_and_states.py b/src/init_models_and_states.py deleted file mode 100644 index 5d28d3b..0000000 --- a/src/init_models_and_states.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -import os -import sys -import threading -import time - -from copy import deepcopy -import numpy as np - -import torch -from torch import multiprocessing as mp -from torch import nn -from torch.nn import functional as F - -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False - -from src.core import file_writer -from src.models import PolicyNet -from src.embeddings import EmbeddingNet -from src.gym_wrappers import make_gym_env -from src.utils import get_batch, log, create_buffers - - -def init_models_and_states(flags): - """Initialize models and LSTM states for all algorithms.""" - torch.manual_seed(flags.run_id) - torch.cuda.manual_seed(flags.run_id) - np.random.seed(flags.run_id) - - # Set device - flags.device = None - if not flags.disable_cuda and torch.cuda.is_available(): - log.info('Using CUDA.') - flags.device = torch.device('cuda') - else: - log.info('Not using CUDA.') - flags.device = torch.device('cpu') - - # Init embedding - embedding_model = EmbeddingNet(flags.embedding_name, - in_channels=3, - pretrained=flags.pretrained_embedding, - train=flags.train_embedding) - - # Retrieve action_space and observation_space shapes - env = make_gym_env(flags, embedding_model) - obs_shape = env.observation_space.shape - n_actions = env.action_space.n - env.close() - - # Init policy models - actor_model = PolicyNet(obs_shape, n_actions) - learner_model = PolicyNet(obs_shape, n_actions).to(device=flags.device) - - # Load models (if there is a checkpoint) - if flags.checkpoint: - log.info(' ... loading model from %s %s ', flags.checkpoint, '...') - checkpoint = torch.load(flags.checkpoint) - actor_model.load_state_dict(checkpoint["actor_model_state_dict"]) - learner_model = deepcopy(actor_model).to(device=flags.device) - embedding_model.load_state_dict(checkpoint["embedding_model_state_dict"]) - - # Actors will run across multiple processes - actor_model.share_memory() - embedding_model.share_memory() - - # Init LSTM states - initial_agent_state_buffers = [] - for _ in range(flags.num_buffers): - state = actor_model.initial_state(batch_size=1) - for t in state: - t.share_memory_() - initial_agent_state_buffers.append(state) - - # Init optimizers - learner_model_optimizer = torch.optim.RMSprop( - learner_model.parameters(), - lr=flags.learning_rate, - momentum=flags.momentum, - eps=flags.epsilon, - alpha=flags.alpha) - - # LR scheduler - def lr_lambda(epoch): - x = np.maximum(flags.total_frames, 5e6) - return 1 - min(epoch * flags.unroll_length * flags.batch_size, x) / x - scheduler = torch.optim.lr_scheduler.LambdaLR(learner_model_optimizer, lr_lambda) - - # Buffer - buffers = create_buffers(obs_shape, n_actions, flags) - - return dict( - actor_model=actor_model, - learner_model=learner_model, - embedding_model=embedding_model, - initial_agent_state_buffers=initial_agent_state_buffers, - learner_model_optimizer=learner_model_optimizer, - scheduler=scheduler, - buffers=buffers, - ) diff --git a/src/losses.py b/src/losses.py deleted file mode 100644 index 9ccb0db..0000000 --- a/src/losses.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
- -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch import nn -from torch.nn import functional as F -import numpy as np - - -def compute_baseline_loss(advantages): - return 0.5 * torch.sum(torch.mean(advantages**2, dim=1)) - - -def compute_entropy_loss(logits): - policy = F.softmax(logits, dim=-1) - log_policy = F.log_softmax(logits, dim=-1) - entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1) - return -torch.sum(torch.mean(entropy_per_timestep, dim=1)) - - -def compute_policy_gradient_loss(logits, actions, advantages): - cross_entropy = F.nll_loss( - F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), - target=torch.flatten(actions, 0, 1), - reduction='none') - cross_entropy = cross_entropy.view_as(advantages) - advantages.requires_grad = False - policy_gradient_loss_per_timestep = cross_entropy * advantages - return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1)) - - -def compute_forward_dynamics_loss(pred_next_emb, next_emb): - forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2) - return torch.sum(torch.mean(forward_dynamics_loss, dim=1)) - - -def compute_inverse_dynamics_loss(pred_actions, true_actions): - inverse_dynamics_loss = F.nll_loss( - F.log_softmax(torch.flatten(pred_actions, 0, 1), dim=-1), - target=torch.flatten(true_actions, 0, 1), - reduction='none') - inverse_dynamics_loss = inverse_dynamics_loss.view_as(true_actions) - return torch.sum(torch.mean(inverse_dynamics_loss, dim=1)) diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index 4fa485b..0000000 --- a/src/utils.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import typing -import gym -import threading -from torch import multiprocessing as mp -import logging -import traceback -import os -import numpy as np -from copy import deepcopy - -from src.core import prof -from src.env_utils import make_environment - - -shandle = logging.StreamHandler() -shandle.setFormatter( - logging.Formatter( - '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] ' - '%(message)s')) -log = logging.getLogger('torchbeast') -log.propagate = False -log.addHandler(shandle) -log.setLevel(logging.INFO) - -Buffers = typing.Dict[str, typing.List[torch.Tensor]] - - -def get_batch(free_queue: mp.SimpleQueue, - full_queue: mp.SimpleQueue, - buffers: Buffers, - agent_state_buffers, - flags, - timings, - lock=threading.Lock()): - with lock: - timings.time('lock') - indices = [full_queue.get() for _ in range(flags.batch_size)] - timings.time('dequeue') - - batch = { - key: torch.stack([buffers[key][m] for m in indices], dim=1) - for key in buffers - } - agent_state = ( - torch.cat(ts, dim=1) - for ts in zip(*[agent_state_buffers[m] for m in indices]) - ) - timings.time('batch') - - for m in indices: - free_queue.put(m) - timings.time('enqueue') - - batch = { - k: t.to(device=flags.device, non_blocking=True) - for k, t in batch.items() - } - agent_state = tuple(t.to(device=flags.device, non_blocking=True) - for t in agent_state) - timings.time('device') - - return batch, agent_state - - -def create_buffers(obs_shape, num_actions, flags) -> Buffers: - T = flags.unroll_length - specs = dict( - obs=dict(size=(T + 1, *obs_shape), dtype=torch.float32), - action=dict(size=(T + 1,), dtype=torch.int64), - reward=dict(size=(T + 1,), dtype=torch.float32), - done=dict(size=(T + 1,), dtype=torch.bool), - episode_return=dict(size=(T + 1,), dtype=torch.float32), - episode_success=dict(size=(T + 1,), dtype=torch.float32), - episode_step=dict(size=(T + 1,), dtype=torch.int32), - policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32), - baseline=dict(size=(T + 1,), dtype=torch.float32), - ) - buffers: Buffers = {key: [] for key in specs} - for _ in range(flags.num_buffers): - for key in buffers: - buffers[key].append(torch.empty(**specs[key]).share_memory_()) - return buffers - - -def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue, - actor_model: torch.nn.Module, embedding_model: torch.nn.Module, - buffers: Buffers, initial_agent_state_buffers, flags): - try: - timings = prof.Timings() - - env = make_environment(flags, embedding_model, i) - - log.info('Actor %i started on environment %s ...', i, flags.env) - - env_output = env.initial() - - agent_state = actor_model.initial_state(batch_size=1) - agent_output, unused_state = actor_model(env_output, agent_state) - - while True: - index = free_queue.get() - if index is None: - break - - # Write old rollout end - for key in env_output: - buffers[key][index][0, ...] = env_output[key] - for key in agent_output: - buffers[key][index][0, ...] = agent_output[key] - for key, tensor in enumerate(agent_state): - initial_agent_state_buffers[index][key][...] = tensor - - # Do new rollout - for t in range(flags.unroll_length): - timings.reset() - - with torch.no_grad(): - agent_output, agent_state = actor_model(env_output, agent_state) - - timings.time('actor_model') - - env_output = env.step(agent_output['action']) - - timings.time('step') - - for key in env_output: - buffers[key][index][t + 1, ...] = env_output[key] - - for key in agent_output: - buffers[key][index][t + 1, ...] 
= agent_output[key] - - timings.time('write') - - full_queue.put(index) - - if i == 0: - log.info('Actor %i: %s', i, timings.summary()) - - except KeyboardInterrupt: - pass - - except Exception as e: - logging.error('Exception in worker process %i', i) - traceback.print_exc() - print() - raise e
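
For readers tracking what this cleanup drops, the core of the deleted `src/core/vtrace.py` is the backward V-trace recursion of Espeholt et al. (https://arxiv.org/abs/1802.01561). A minimal sketch follows, assuming `[T, B]`-shaped inputs as in the deleted code; the example tensors at the bottom are placeholders, not project API.

```
# Reference sketch (not part of the patch) of the V-trace target computation
# removed with src/core/vtrace.py. All inputs are [T, B] except bootstrap_value,
# which is [B].
import torch


@torch.no_grad()
def vtrace_targets(log_rhos, discounts, rewards, values, bootstrap_value,
                   clip_rho=1.0, clip_c=1.0):
    rhos = torch.exp(log_rhos)
    clipped_rhos = torch.clamp(rhos, max=clip_rho)   # rho_t for the deltas
    cs = torch.clamp(rhos, max=clip_c)               # c_t for the recursion
    values_t1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t1 - values)

    # Backward recursion over time: acc_t = delta_t + gamma_t * c_t * acc_{t+1}
    acc = torch.zeros_like(bootstrap_value)
    out = []
    for t in reversed(range(discounts.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        out.append(acc)
    out.reverse()
    vs = torch.stack(out) + values                   # v_s = V(x_s) + correction

    # Policy-gradient advantages use the one-step-shifted targets.
    vs_t1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    pg_advantages = clipped_rhos * (rewards + discounts * vs_t1 - values)
    return vs, pg_advantages


T, B = 100, 32  # placeholder unroll length and batch size
vs, pg_adv = vtrace_targets(log_rhos=torch.zeros(T, B),
                            discounts=torch.full((T, B), 0.99),
                            rewards=torch.randn(T, B),
                            values=torch.randn(T, B),
                            bootstrap_value=torch.randn(B))
```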