Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IQN #139

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft

IQN #139

wants to merge 6 commits into from

Conversation

qgallouedec
Copy link
Contributor

@qgallouedec qgallouedec commented Jan 26, 2023

Description

Context

  • I have raised an issue to propose this change (required)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have included an example of using the feature (required for new features).
  • I have included baseline results (required for new training algorithms or training-related features).
  • I have updated the documentation accordingly.
  • I have updated the changelog accordingly (required).
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

@qgallouedec qgallouedec changed the title Feat/iqn IQN Jan 26, 2023
@qgallouedec
Copy link
Contributor Author

qgallouedec commented Jan 27, 2023

Results comparison

Current implementation Reference
(6 seeds) iqn from https://github.com/toshikwa/fqf-iqn-qrdqn.pytorch (1 seed, same parameters) image
W B Chart 28_01_2023, 23_13_44 (2 seeds, same parameters) image
https://di-engine-docs.readthedocs.io/en/latest/12_policies/iqn.html#benchmark

@emrul
Copy link

emrul commented Feb 25, 2023

@qgallouedec Thank you for adding this. I wanted to report that for me it works well and I was able to adapt it to implement the paper Self-Imitation Advantage Learning
. I'm not sure how useful it is for you but I'm happy to share my modifications to add SAIL-IQN to your IQN implementation (I don't have the resources right now to submit this as a separate PR):

New replay buffer to store discounted returns (G):

import warnings
import itertools
from typing import Generator, Optional, Union, NamedTuple, List, Dict, Any
import numpy as np
import torch as th
from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.vec_env import VecNormalize
from gymnasium import spaces

PLACEHOLDER_RETURN_VALUE = np.finfo(np.float32).min

class SAILReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
    returns: th.Tensor

class SAILReplayBuffer(ReplayBuffer):
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        gamma: float = 0.99
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage)
        ## TODO: Haven't looked at supporting optimize_memory_usage true yet
        # assert optimize_memory_usage == False, 'optimize_memory_usage does not work with SAIL currently'
        self.gamma = gamma
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # For each env store where the episode starts (0 for all envs at the beginning)
        # but will vary as each episode can end at a different point
        self.episode_start_indices = np.zeros(self.n_envs, dtype=np.int32)

    def update_episodic_return(self, completed_env_indices: np.ndarray, episode_end_idx: int):
        # completed_env_indices - indices of envs with completed episode
        # For all episodes that have ended episode_end_pos will contain the end position though
        # it may be infrequent for multiple episodes to end at the same position
        for env_idx in completed_env_indices:
            # episode_start_idx can be > episode_end_idx due to buffer wrap-around
            episode_start_idx = self.episode_start_indices[env_idx]
            G = 0
            x = 0
            i = episode_end_idx # index used to calculate discounted return
            if episode_start_idx < episode_end_idx:
                max_episode_steps = episode_end_idx - episode_start_idx
            else:
                # This won't be accurate if we've wrapped around more than once but we should somewhere require
                # max_episode_steps to be less than buffer size to prevent that from happening.
                max_episode_steps = self.buffer_size - episode_end_idx + episode_start_idx
            while x <= max_episode_steps:
                G = self.rewards[i, env_idx] + self.gamma * G
                self.returns[i, env_idx] = G
                i = (i - 1) % self.buffer_size
                x += 1
                pass

        pass
    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # we want position before it gets updated
        pos = self.pos
        super().add(obs=obs, next_obs=next_obs, action=action, reward=reward, done=done, infos=infos)
        self.returns[pos] = np.repeat(PLACEHOLDER_RETURN_VALUE, repeats=self.n_envs)
        if np.any(done):
            # Only use dones that are not due to timeouts
            true_dones = done * (1 - self.timeouts[pos])
            if np.any(true_dones):
                self.update_episodic_return(np.flatnonzero(true_dones), pos)
            # Update episode start indices (whether due to timeout or not) to the current start index
            # of the next episode (self.pos)
            np.put_along_axis(self.episode_start_indices, np.flatnonzero(done), self.pos, axis=0)
            pass
        return

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> SAILReplayBufferSamples:
        # noinspection PyTypeChecker
        return super().sample(batch_size=batch_size, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> SAILReplayBufferSamples:
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
            self.returns[batch_inds, env_indices]
        )
        return SAILReplayBufferSamples(*tuple(map(self.to_torch, data)))

and updated training loop:

        # Sample replay buffer
        replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
        with th.no_grad():
            # BEGIN - SAIL addition.
            rewards = replay_data.rewards
            # Ref to https://github.com/google-research/google-research/blob/master/sail_rl/agents/sail_iqn.py
            # Calculates **current state** action-values.
            # Shape: batch_size x n_quantiles x num_actions.
            replay_target_net_outputs = self.quantile_net_target(replay_data.observations, self.n_quantiles)

            # Shape: batch_size x num_actions
            replay_target_q_values = replay_target_net_outputs.mean(dim=1)

            replay_action_one_hot = th.nn.functional.one_hot(replay_data.actions.squeeze(-1), self.action_space.n).type(th.float32)
            replay_target_q = th.max(replay_target_q_values, dim=1).values
            replay_target_q_al = th.sum(replay_action_one_hot * replay_target_q_values, dim=1)
            comp_value = th.max(replay_target_q_al, replay_data.returns)

            if self.clip > 0.:
                sil_bonus = self.alpha * th.clamp(comp_value - replay_target_q, min=-self.clip, max=self.clip)
            else:
                sil_bonus = self.alpha * (comp_value - replay_target_q)

            rewards = rewards + sil_bonus.unsqueeze(-1)
            # END - SAIL addition

            # Compute the quantiles of next observation
            next_quantiles = self.quantile_net_target(replay_data.next_observations, self.n_quantiles)

            # Shape of next_quantiles:
            # batch_size x n_quantiles x num_actions.
            # e.g. if num_actions is 2, it might look something like this:
            # Vals for Quantile .2  Vals for Quantile .4  Vals for Quantile .6
            #    [[0.1, 0.5],         [0.15, -0.3],          [0.15, -0.2]]
            # Q-values = [(0.1 + 0.15 + 0.15)/3, (0.5 + 0.15 + -0.2)/3].

            # Compute the greedy actions which maximize the next Q values
            next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True)

            # Make "num_tau_prime_samples" copies of actions, and reshape to (batch_size, num_tau_prime_samples, 1)
            next_greedy_actions = next_greedy_actions.expand(batch_size, self.num_tau_prime_samples, 1)

            # Compute the quantiles of next observation, but with another number of tau samples
            next_quantiles = self.quantile_net_target(replay_data.next_observations, self.num_tau_prime_samples)

            # Follow greedy policy: use the one with the highest Q values
            next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2)

            # 1-step TD target
            target_quantiles = rewards + (1 - replay_data.dones) * self.gamma * next_quantiles

        # Get current quantile estimates
        current_quantiles = self.quantile_net(replay_data.observations, self.num_tau_samples)

        # Make "num_tau_samples" copies of actions, and reshape to (batch_size, num_tau_samples, 1).
        actions = replay_data.actions[..., None].long().expand(batch_size, self.num_tau_samples, 1)

        # Retrieve the quantiles for the actions from the replay buffer
        current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

        # Compute Quantile Huber loss, summing over a quantile dimension as in the paper.
        loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
        return loss

The extra parameters alpha and clip are defaulted to 0.9 and 1.0.

I found immediately that SAIL-IQN performs nicely on sparse rewards so am quite happy with my initial results but by no means has my testing been thorough.

@qgallouedec
Copy link
Contributor Author

Thanks for your feedback @emrul! This PR is still draft because I can't replicate exactly the results of the paper for Qbert. I don't know if it's a hyperparameter problem or something else, I'm still looking.

I think SIL (and probably maybe SAIL) would fit in SB3-contrib. However, it would be best to discuss it in a dedicated issue. I'll open it right away.

@qgallouedec qgallouedec mentioned this pull request Feb 26, 2023
@emrul
Copy link

emrul commented Feb 26, 2023

Thanks @qgallouedec - I didn't know there's a reproduction issue, I will look into this also - I compared your implementation with the Dopamine one and the Medipexel/pytorch port of that and it looked quite different. I will dig in to see where they differ and feedback if I find anything to assist.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants