Update AlphaLoss to support entropy schedules #186

Merged
merged 10 commits on Nov 9, 2023
23 changes: 15 additions & 8 deletions emote/algorithms/sac.py
@@ -4,13 +4,13 @@

from typing import Any, Dict, Optional

import numpy as np
import torch

from torch import nn, optim

from emote.callback import Callback
from emote.callbacks.loss import LossCallback
from emote.extra.schedules import ConstantSchedule, Schedule
from emote.mixins.logging import LoggingMixin
from emote.proxies import AgentProxy, GenericAgentProxy
from emote.utils.deprecated import deprecated
@@ -245,17 +245,17 @@ class AlphaLoss(LossCallback):
probability given a state.
:param ln_alpha (torch.tensor): The current weight for the entropy part of the
soft Q.
:param lr_schedule (torch.optim.lr_scheduler._LRSchedule): Learning rate schedule
:param lr_schedule (torch.optim.lr_scheduler._LRSchedule | None): Learning rate schedule
for the optimizer of alpha.
:param opt (torch.optim.Optimizer): An optimizer for ln_alpha.
:param n_actions (int): The dimension of the action space. Scales the target
entropy.
:param max_grad_norm (float): Clip the norm of the gradient during backprop using
this value.
:param entropy_eps (float): Scaling value for the target entropy.
:param name (str): The name of the module. Used e.g. while logging.
:param data_group (str): The name of the data group from which this Loss takes its
data.
:param t_entropy (float | Schedule | None): Value or schedule for the target entropy.
"""

def __init__(
@@ -264,13 +264,13 @@ def __init__(
pi: nn.Module,
ln_alpha: torch.tensor,
opt: optim.Optimizer,
lr_schedule: Optional[optim.lr_scheduler._LRScheduler] = None,
lr_schedule: optim.lr_scheduler._LRScheduler | None = None,
n_actions: int,
max_grad_norm: float = 10.0,
entropy_eps: float = 0.089,
max_alpha: float = 0.2,
name: str = "alpha",
data_group: str = "default",
t_entropy: float | Schedule | None = None,
):
super().__init__(
name=name,
@@ -284,14 +284,20 @@ def __init__(
self._max_ln_alpha = torch.log(torch.tensor(max_alpha, device=ln_alpha.device))
# TODO(singhblom) Check number of actions
# self.t_entropy = -np.prod(self.env.action_space.shape).item() # Value from rlkit from Harnouja
self.t_entropy = n_actions * (1.0 + np.log(2.0 * np.pi * entropy_eps**2)) / 2.0
t_entropy = -n_actions if t_entropy is None else t_entropy
if not isinstance(t_entropy, (int, float, Schedule)):
raise TypeError("t_entropy must be a number or an instance of Schedule")

self.t_entropy = (
t_entropy if isinstance(t_entropy, Schedule) else ConstantSchedule(t_entropy)
)
self.ln_alpha = ln_alpha # This is log(alpha)

def loss(self, observation):
with torch.no_grad():
_, logp_pi = self.policy(**observation)
entropy = -logp_pi
error = entropy - self.t_entropy
error = entropy - self.t_entropy.value
alpha_loss = torch.mean(self.ln_alpha * error.detach())
assert alpha_loss.dim() == 0
self.log_scalar("loss/alpha_loss", alpha_loss)
@@ -304,7 +310,8 @@ def end_batch(self):
self.ln_alpha = torch.clamp_max_(self.ln_alpha, self._max_ln_alpha)
self.ln_alpha.requires_grad_(True)
self.log_scalar("training/alpha_value", torch.exp(self.ln_alpha).item())
self.log_scalar("training/target_entropy", self.t_entropy)
self.log_scalar("training/target_entropy", self.t_entropy.value)
self.t_entropy.step()

def state_dict(self):
state = super().state_dict()
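For context, a minimal sketch of how the new `t_entropy` argument can be passed when constructing `AlphaLoss`, following the constructor and the tests in this diff; the dimensions and schedule endpoints below are illustrative, not values from the PR:

```python
import torch

from emote.algorithms.sac import AlphaLoss
from emote.extra.schedules import CyclicSchedule
from emote.nn.gaussian_policy import GaussianMlpPolicy

# Illustrative sizes; any policy/optimizer wiring works the same way.
policy = GaussianMlpPolicy(3, 1, [16, 16])
ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
opt = torch.optim.Adam([ln_alpha])

# Sweep the target entropy between 0 and -1, with 1000 steps per half-cycle.
t_entropy = CyclicSchedule(0.0, -1.0, 1000, mode="triangular")

loss = AlphaLoss(
    pi=policy,
    ln_alpha=ln_alpha,
    opt=opt,
    n_actions=1,
    t_entropy=t_entropy,  # float, Schedule, or None (None falls back to -n_actions)
)
```

During training, each `end_batch()` call logs `training/target_entropy` from `loss.t_entropy.value` and then advances the schedule with `step()`, as the diff above shows.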
19 changes: 10 additions & 9 deletions emote/extra/schedules.py
@@ -25,10 +25,11 @@ def __init__(self, initial: float, final: float, steps: int):
self.steps = steps

self._step_count = 0
self._last_val = initial
self._current_value = initial

def get_last_val(self) -> float:
return self._last_val
@property
def value(self):
return self._current_value

def step(self):
pass
@@ -79,7 +80,7 @@ def step(self):
fraction = math.floor(fraction * self.staircase_steps) / self.staircase_steps
fraction = min(fraction, 1.0)

self._last_val = self.initial + fraction * (self.final - self.initial)
self._current_value = self.initial + fraction * (self.final - self.initial)

self._step_count += 1

@@ -127,7 +128,7 @@ def step(self):
cycle = math.floor(1 + self._step_count / (2 * self.steps))
x = math.fabs(self._step_count / self.steps - 2 * cycle + 1)

self._last_val = self.initial + (self.final - self.initial) * max(
self._current_value = self.initial + (self.final - self.initial) * max(
0, (1 - x)
) * self.scale_fn(cycle)

@@ -149,13 +150,13 @@ def __init__(self, initial: float, final: float, steps: int):
def step(self):
if self._step_count > 0:
if (self._step_count - 1 - self.steps) % (2 * self.steps) == 0:
self._last_val += (
self._current_value += (
(self.initial - self.final) * (1 - math.cos(math.pi / self.steps)) / 2
)
else:
self._last_val = (1 + math.cos(math.pi * self._step_count / self.steps)) / (
self._current_value = (1 + math.cos(math.pi * self._step_count / self.steps)) / (
1 + math.cos(math.pi * (self._step_count - 1) / self.steps)
) * (self._last_val - self.final) + self.final
) * (self._current_value - self.final) + self.final

self._step_count += 1

@@ -176,7 +177,7 @@ def step(self):
if self._step_count >= self.steps:
self._step_count %= self.steps

self._last_val = (
self._current_value = (
self.final
+ (self.initial - self.final)
* (1 + math.cos(math.pi * self._step_count / self.steps))
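The rename from `_last_val` / `get_last_val()` to `_current_value` with a `value` property pins down the interface `AlphaLoss` now relies on: read `.value`, call `.step()` once per batch. Below is a minimal sketch of a custom schedule written against that interface; the class, its decay shape, and the extra attributes are illustrative and not part of this PR:

```python
import math

from emote.extra.schedules import Schedule


class ExponentialDecay(Schedule):
    """Decays from `initial` towards `final` over `steps` calls to step().

    Illustrative only; it keeps its own copies of the endpoints instead of
    assuming the base class exposes them.
    """

    def __init__(self, initial: float, final: float, steps: int, rate: float = 5.0):
        super().__init__(initial, final, steps)
        self._initial = initial
        self._final = final
        self._rate = rate

    def step(self):
        fraction = min(self._step_count / self.steps, 1.0)
        # Exponential interpolation between the two endpoints.
        self._current_value = self._final + (self._initial - self._final) * math.exp(
            -self._rate * fraction
        )
        self._step_count += 1
```

Passing, e.g., `t_entropy=ExponentialDecay(0.0, -3.0, 10_000)` to `AlphaLoss` would then be logged and stepped in `end_batch()` like the built-in schedules.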
47 changes: 47 additions & 0 deletions tests/test_sac.py
@@ -6,6 +6,7 @@
import torch

from emote.algorithms.sac import AlphaLoss, FeatureAgentProxy
from emote.extra.schedules import ConstantSchedule, CyclicSchedule
from emote.nn.gaussian_policy import GaussianMlpPolicy
from emote.typing import DictObservation, EpisodeState

@@ -67,3 +68,49 @@ def test_alpha_value_ref_valid_after_load():
assert (
ln_alpha_before_load is ln_alpha_after_load
), "expected ln(alpha) to be the same python object after loading. The reference is used by other loss functions such as PolicyLoss!"


def test_target_entropy_schedules():
policy = GaussianMlpPolicy(IN_DIM, OUT_DIM, [16, 16])
init_ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
optim = torch.optim.Adam([init_ln_alpha])
loss = AlphaLoss(pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM)

# Check if default is set correctly when no t_entropy is passed
init_entropy = loss.t_entropy.value
assert init_entropy == -OUT_DIM
print(init_entropy)

# Check that default schedule is constant and doesn't update the value
assert isinstance(loss.t_entropy, ConstantSchedule)
for _ in range(5):
loss.end_batch()
assert init_entropy == loss.t_entropy.value

# Check that value is updated when using a schedule
start = 5
end = 0
steps = 5
schedule = CyclicSchedule(start, end, steps, mode="triangular")
loss = AlphaLoss(
pi=policy, ln_alpha=init_ln_alpha, opt=optim, n_actions=OUT_DIM, t_entropy=schedule
)

for _ in range(steps + 1):
loss.end_batch()
assert loss.t_entropy.value == end

for _ in range(steps):
loss.end_batch()
assert loss.t_entropy.value == start

# Check that invalid types are not accepted
invalid_t_entropy = torch.optim.lr_scheduler.LinearLR(optim, 1, end / start, steps)
with pytest.raises(TypeError):
AlphaLoss(
pi=policy,
ln_alpha=init_ln_alpha,
opt=optim,
n_actions=OUT_DIM,
t_entropy=invalid_t_entropy,
)
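Complementing the tests above, a short sketch of the wrapping behaviour introduced in the constructor: a plain float target entropy becomes a `ConstantSchedule`, so fixed and scheduled targets share the same `.value` / `.step()` path (the numbers here are illustrative):

```python
import torch

from emote.algorithms.sac import AlphaLoss
from emote.extra.schedules import ConstantSchedule
from emote.nn.gaussian_policy import GaussianMlpPolicy

policy = GaussianMlpPolicy(3, 1, [16, 16])
ln_alpha = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
opt = torch.optim.Adam([ln_alpha])

# A float target entropy is wrapped in a ConstantSchedule internally...
loss = AlphaLoss(pi=policy, ln_alpha=ln_alpha, opt=opt, n_actions=1, t_entropy=-1.5)
assert isinstance(loss.t_entropy, ConstantSchedule)
assert loss.t_entropy.value == -1.5

# ...and stepping a constant schedule leaves the value unchanged.
loss.t_entropy.step()
assert loss.t_entropy.value == -1.5
```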