Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ali/genrl pr #159

Merged
merged 11 commits into from
Sep 14, 2023
Empty file added emote/algorithms/__init__.py
Empty file.
Empty file.
48 changes: 48 additions & 0 deletions emote/algorithms/genrl/proxies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn

from emote.memory.memory import TableMemoryProxy
from emote.memory.table import Table
from emote.typing import AgentId, DictObservation, DictResponse


class MemoryProxyWithEncoder(TableMemoryProxy):
    """Memory proxy that stores encoder outputs in place of the raw actions.

    Before the responses are handed to ``TableMemoryProxy.add``, each agent's
    action list is passed through ``encoder`` together with that agent's
    observation, and the encoder output is stored under ``action_key``
    instead of the original actions.

    Args:
        table: Forwarded unchanged to the parent proxy.
        encoder: Module called as ``encoder(actions, obs)`` with two float
            tensors; its return value must be a tensor.
        minimum_length_threshold: Forwarded unchanged to the parent proxy.
        use_terminal: Forwarded unchanged to the parent proxy.
        input_key: Key into ``DictObservation.array_data`` selecting the
            observation passed to the encoder.
        action_key: Key into ``DictResponse.list_data`` selecting the actions
            to encode (and the key the encoded result is stored under).
    """

    def __init__(
        self,
        table: Table,
        encoder: nn.Module,
        minimum_length_threshold: Optional[int] = None,
        use_terminal: bool = False,
        input_key: str = "obs",
        action_key: str = "actions",
    ):
        super().__init__(table, minimum_length_threshold, use_terminal)
        self.encoder = encoder
        self._input_key = input_key
        self._action_key = action_key

    def add(
        self,
        observations: Dict[AgentId, DictObservation],
        responses: Dict[AgentId, DictResponse],
    ):
        """Encode every agent's actions, then delegate to the parent ``add``.

        Args:
            observations: Per-agent observations; forwarded unchanged.
            responses: Per-agent responses. A response whose action list is
                empty is forwarded as-is; every other response is replaced by
                a new ``DictResponse`` holding only the encoded actions.
        """
        updated_responses = {}
        for agent_id, response in responses.items():
            actions = np.array(response.list_data[self._action_key])
            if np.size(actions) == 0:
                # Nothing to encode; pass the response through untouched.
                updated_responses[agent_id] = response
                continue

            actions = torch.from_numpy(actions).to(torch.float)
            obs = torch.from_numpy(
                observations[agent_id].array_data[self._input_key]
            ).to(torch.float)

            # The result is detached and converted to numpy right away, so
            # there is no reason to record an autograd graph for this call.
            with torch.no_grad():
                latent = self.encoder(actions, obs).detach().cpu().numpy()

            # NOTE(review): only the action key is carried over; any other
            # list_data / scalar_data entries of the original response are
            # dropped here. Confirm that is intended.
            updated_responses[agent_id] = DictResponse(
                list_data={self._action_key: latent}, scalar_data={}
            )
        super().add(observations, updated_responses)
82 changes: 82 additions & 0 deletions emote/algorithms/genrl/vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Callable

import torch
import torch.nn.functional as F

from torch import nn, optim

from emote.callbacks import LossCallback
from emote.nn.initialization import normal_init_


class VariationalAutoencoder(nn.Module):
    """Conditional variational autoencoder built from a supplied encoder/decoder.

    Both sub-networks are re-initialised with ``normal_init_`` on construction.

    Args:
        encoder: Module called as ``encoder(x, condition)`` and returning a
            ``(mu, log_std)`` pair; must expose ``output_size``.
        decoder: Module called as ``decoder(latent, condition)``; must expose
            ``input_size`` equal to ``encoder.output_size``.
        device: Stored as ``self.device``.
        beta: Weight of the KL term in :meth:`loss`.

    Raises:
        ValueError: If ``encoder.output_size`` differs from
            ``decoder.input_size``.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        device: torch.device,
        beta: float = 0.01,
    ):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if encoder.output_size != decoder.input_size:
            raise ValueError(
                "encoder.output_size "
                f"({encoder.output_size}) must equal decoder.input_size "
                f"({decoder.input_size})"
            )
        self.latent_size = encoder.output_size
        self.device = device
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        self.encoder.apply(normal_init_)
        self.decoder.apply(normal_init_)

    def forward(self, x, condition=None):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Input to reconstruct.
            condition: Optional conditioning input handed to both the encoder
                and the decoder.

        Returns:
            Tuple ``(x_hat, mu, log_std, latent)`` where ``x_hat`` is the
            decoder output viewed with the same size as ``x``.
        """
        mu, log_std = self.encoder(x, condition)
        std = torch.exp(log_std)
        # ``randn_like`` already matches the device and dtype of ``std``, so
        # the noise can never end up on a different device than the tensor it
        # is multiplied with.
        eps = torch.randn_like(std)
        # Reparameterisation: latent = mu + eps * std.
        latent = eps.mul(std).add(mu)
        x_hat = self.decoder(latent, condition)
        x_hat = x_hat.view(x.size())
        return x_hat, mu, log_std, latent

    def loss(self, x, x_hat, mu, log_std):
        """Return the weighted sum of reconstruction and KL terms.

        Args:
            x: Original input.
            x_hat: Reconstruction of ``x``.
            mu: Latent mean from the encoder.
            log_std: Latent log standard deviation from the encoder.

        Returns:
            Tuple ``(loss, info)`` where ``loss`` is
            ``restore_loss + beta * kl_loss`` and ``info`` is a dict with the
            keys ``"restore_loss"`` and ``"kl_loss"``.
        """
        restore_loss = F.mse_loss(x_hat, x)
        std = torch.exp(log_std)
        # NOTE(review): per latent dimension this evaluates
        #   -log_std + 0.5 * mu**2 + std - 1,
        # whereas the closed-form KL(N(mu, std^2) || N(0, 1)) is
        #   -log_std + 0.5 * std**2 + 0.5 * mu**2 - 0.5.
        # Both vanish at mu = 0, std = 1. The expression is deliberately left
        # numerically unchanged here; confirm whether the exact form is wanted.
        kld = torch.sum(-log_std + (mu**2) * 0.5 + std, 1) - self.latent_size
        kl_loss = kld.mean()
        info = {"restore_loss": restore_loss, "kl_loss": kl_loss}
        loss = restore_loss + self.beta * kl_loss
        return loss, info


class VAELoss(LossCallback):
    """Loss callback that trains a ``VariationalAutoencoder`` on actions.

    Args:
        vae: Network to optimise; also passed to the parent as ``network``.
        opt: Optimizer, passed to the parent as ``optimizer``.
        lr_schedule: Passed through to the parent callback.
        max_grad_norm: Passed through to the parent callback.
        name: Passed through to the parent callback.
        data_group: Passed through to the parent callback.
        input_key: Key of the observation entry handed to
            ``conditioning_func``.
        conditioning_func: Maps the selected observation to the VAE
            condition; the default yields ``None`` (no conditioning).
    """

    def __init__(
        self,
        *,
        vae: VariationalAutoencoder,
        opt: optim.Optimizer,
        lr_schedule=None,
        max_grad_norm: float = 10.0,
        name: str = "vae",
        data_group: str = "default",
        input_key: str = "obs",
        conditioning_func: Callable = lambda _: None,
    ):
        parent_kwargs = dict(
            name=name,
            optimizer=opt,
            lr_schedule=lr_schedule,
            network=vae,
            max_grad_norm=max_grad_norm,
            data_group=data_group,
        )
        super().__init__(**parent_kwargs)
        self._input_key = input_key
        self.conditioning_func = conditioning_func
        self.vae = vae

    def loss(self, observation, actions):
        """Compute the VAE loss for ``actions`` and log its two components.

        Args:
            observation: Mapping indexed with ``input_key`` to obtain the
                input of ``conditioning_func``.
            actions: Data the VAE reconstructs.

        Returns:
            The total loss returned by ``vae.loss``.
        """
        cond = self.conditioning_func(observation[self._input_key])
        reconstruction, mu, log_std, _latent = self.vae.forward(actions, cond)
        total, terms = self.vae.loss(actions, reconstruction, mu, log_std)

        # Log each component under its fixed scalar name.
        logged = (
            ("vae/restore_loss", "restore_loss"),
            ("vae/kl_loss", "kl_loss"),
        )
        for scalar_name, term_key in logged:
            self.log_scalar(scalar_name, torch.mean(terms[term_key]))

        return total
128 changes: 128 additions & 0 deletions emote/algorithms/genrl/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Callable

import torch

from torch import Tensor, nn

from emote.nn.gaussian_policy import GaussianMlpPolicy


class DecoderWrapper(nn.Module):
    """Wrap a decoder so it can be called with a latent and a raw observation.

    The wrapper scales the latent, moves the inputs to the decoder's device,
    derives the condition from the observation and then runs the decoder.

    Args:
        decoder: Module called as ``decoder.forward(latent, condition)``; must
            expose ``device``, ``input_size``, ``output_size`` and
            ``condition_size``.
        condition_fn: Maps an observation tensor to the decoder's condition.
        latent_multiplier: Factor applied to the latent before decoding
            (presumably to stretch a bounded latent to the range the decoder
            was trained on — TODO confirm).
    """

    def __init__(
        self,
        decoder: nn.Module,
        condition_fn: Callable,
        latent_multiplier: float = 3.0,
    ):
        super().__init__()
        self.device = decoder.device
        self._latent_multiplier = latent_multiplier
        self.latent_size = decoder.input_size
        self.output_size = decoder.output_size
        self.condition_size = decoder.condition_size
        self.condition_fn = condition_fn
        self.decoder = decoder

    def forward(
        self, latent: torch.Tensor, observation: torch.Tensor = None
    ) -> torch.Tensor:
        """Run the decoder.

        Arguments:
            latent (torch.Tensor): batch x latent_size
            observation (torch.Tensor): batch x obs_size; when ``None`` the
                decoder is called with ``condition=None``

        Returns:
            torch.Tensor: the sample (batch x data_size)
        """
        latent = latent * self._latent_multiplier
        latent = latent.to(self.device)

        condition = None
        if observation is not None:
            observation = observation.to(self.device)
            condition = self.condition_fn(observation)

        return self.decoder.forward(latent, condition)

    def load_state_dict(self, state_dict, strict=True):
        """Load the entries of ``state_dict`` that belong to this wrapper.

        Keys of ``state_dict`` that are not part of this module's own state
        dict are ignored, and parameters absent from ``state_dict`` keep their
        current values, so a checkpoint containing more than the decoder can
        be loaded directly.

        Args:
            state_dict: Mapping from parameter names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``. Because the
                dict that is finally loaded always contains exactly this
                module's keys, it cannot cause a key-mismatch error here.

        Returns:
            The result of ``nn.Module.load_state_dict``.

        Raises:
            ValueError: If no key of ``state_dict`` matches this module.
        """
        own_state = self.state_dict()
        matching = {k: v for k, v in state_dict.items() if k in own_state}
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not matching:
            raise ValueError(
                "state_dict shares no keys with this module; nothing to load"
            )
        own_state.update(matching)
        return super().load_state_dict(own_state, strict=strict)


class EncoderWrapper(nn.Module):
    """Wrap an encoder so it can be called with an action and a raw observation.

    The wrapper moves the inputs to the encoder's device, derives the
    condition from the observation, runs the encoder and returns only the
    first element of its output pair.

    Args:
        encoder: Module called as ``encoder.forward(action, condition)`` and
            returning a 2-tuple; must expose ``device``, ``input_size``,
            ``output_size`` and ``condition_size``.
        condition_fn: Maps an observation tensor to the encoder's condition.
    """

    def __init__(
        self,
        encoder: nn.Module,
        condition_fn: Callable,
    ):
        super().__init__()
        self.encoder = encoder
        self.device = encoder.device
        self.action_size = encoder.input_size
        self.latent_size = encoder.output_size
        self.condition_size = encoder.condition_size

        self.condition_fn = condition_fn

    def forward(
        self, action: torch.Tensor, observation: torch.Tensor = None
    ) -> torch.Tensor:
        """Run the encoder.

        Arguments:
            action (torch.Tensor): batch x data_size
            observation (torch.Tensor): batch x obs_size; when ``None`` the
                encoder is called with ``condition=None``

        Returns:
            torch.Tensor: the mean, i.e. the first element of the encoder's
            output (presumably batch x latent_size — TODO confirm)
        """
        action = action.to(self.device)

        condition = None
        if observation is not None:
            observation = observation.to(self.device)
            condition = self.condition_fn(observation)

        # The second element of the encoder output is not used here.
        mean, _ = self.encoder.forward(action, condition)
        return mean

    def load_state_dict(self, state_dict, strict=True):
        """Load the entries of ``state_dict`` that belong to this wrapper.

        Keys of ``state_dict`` that are not part of this module's own state
        dict are ignored, and parameters absent from ``state_dict`` keep their
        current values, so a checkpoint containing more than the encoder can
        be loaded directly.

        Args:
            state_dict: Mapping from parameter names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``. Because the
                dict that is finally loaded always contains exactly this
                module's keys, it cannot cause a key-mismatch error here.

        Returns:
            The result of ``nn.Module.load_state_dict``.

        Raises:
            ValueError: If no key of ``state_dict`` matches this module.
        """
        own_state = self.state_dict()
        matching = {k: v for k, v in state_dict.items() if k in own_state}
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not matching:
            raise ValueError(
                "state_dict shares no keys with this module; nothing to load"
            )
        own_state.update(matching)
        return super().load_state_dict(own_state, strict=strict)


class PolicyWrapper(nn.Module):
    """Chain a latent-space policy with a decoder.

    The policy is queried with the observation and its output is passed
    through the decoder together with the same observation.

    Args:
        decoder: Wrapped decoder; its ``latent_size`` fixes how many epsilon
            columns are forwarded to the policy.
        policy: Policy called as ``policy.forward(obs, epsilon)``.
    """

    def __init__(
        self,
        decoder: DecoderWrapper,
        policy: GaussianMlpPolicy,
    ):
        super().__init__()
        self.latent_size = decoder.latent_size
        self.decoder = decoder
        self.policy = policy

    def forward(self, obs: Tensor, epsilon: Tensor = None):
        """Return the decoded policy output for ``obs``.

        Args:
            obs: Observation passed to both the policy and the decoder.
            epsilon: Optional noise sized for the original action space.

        Returns:
            In training mode, ``(action, log_prob)``; otherwise the decoder
            output alone.
        """
        # Epsilon arrives sized for the original action space while the policy
        # works in latent space, so only the leading latent_size columns are
        # kept.
        if epsilon is not None:
            epsilon = epsilon[:, : self.latent_size]

        policy_out = self.policy.forward(obs, epsilon)

        if not self.training:
            return self.decoder(policy_out, obs)

        # In training mode the policy output is unpacked as (sample, log_prob).
        latent_action, log_prob = policy_out
        decoded = self.decoder(latent_action, obs)
        return decoded, log_prob
2 changes: 2 additions & 0 deletions emote/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
LoggingProxyWrapper,
MemoryExporterProxyWrapper,
MemoryLoader,
MemoryWarmup,
TableMemoryProxy,
)
from .table import Table
Expand All @@ -19,4 +20,5 @@
"MemoryExporterProxyWrapper",
"MemoryImporterCallback",
"LoggingProxyWrapper",
"MemoryWarmup",
]
15 changes: 15 additions & 0 deletions emote/nn/initialization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import torch

from torch import nn

Expand All @@ -17,3 +18,17 @@ def xavier_uniform_init_(m, gain):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, gain)
nn.init.constant_(m.bias, 0.0)


def normal_init_(m: nn.Module):
    """Initialise ``m`` in place; intended for use with ``nn.Module.apply``.

    * ``nn.Conv1d``: weights ~ N(0, 0.01^2), bias set to 0.
    * ``nn.BatchNorm1d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights ~ N(0, 0.001^2), bias set to 0.

    Any other module type is left untouched. Parameters that are ``None``
    (a layer built with ``bias=False``, or a batch norm built with
    ``affine=False``) are skipped.

    Args:
        m: Module to initialise.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm1d):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=1e-3)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
Loading