Add a joint memory loader capable of sampling data from multiple memory loaders #181

Merged · 7 commits · Feb 13, 2024
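
In brief, the new loaders are combined like this (a minimal usage sketch, not part of the diff; `table_a` and `table_b` are placeholder `ArrayTable`s with enough sampled data, built like the fixtures in the new tests below):

loader_a = MemoryLoader(
    table=table_a,
    rollout_count=1,
    rollout_length=1,
    size_key="batch_size",
    data_group="a",
)
loader_b = MemoryLoader(
    table=table_b,
    rollout_count=1,
    rollout_length=1,
    size_key="batch_size",
    data_group="b",
)

# Samples one batch from every wrapped loader and merges them into a single dict.
joint_loader = JointMemoryLoader(loaders=[loader_a, loader_b])
batch = next(iter(joint_loader))
# batch now holds one entry per data group ("a", "b") plus the summed "batch_size".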
52 changes: 52 additions & 0 deletions emote/memory/memory.py
@@ -7,6 +7,7 @@
"""
from __future__ import annotations

import collections
import inspect
import logging
import os
@@ -463,6 +464,57 @@ def __iter__(self):
yield {self.data_group: data, self.size_key: data[self.size_key]}


class JointMemoryLoader:
    """A memory loader capable of loading data from multiple `MemoryLoader`s."""

    def __init__(self, loaders: list[MemoryLoader], size_key: str = "batch_size"):
        self._loaders = loaders
        self._size_key = size_key

        counts = collections.Counter(loader.data_group for loader in loaders)
        counts_over_1 = {k: count for k, count in counts.items() if count > 1}
        if len(counts_over_1) != 0:
            raise ValueError(
                f"""JointMemoryLoader was provided MemoryLoaders that share the same datagroup. This will clobber the joint output data and is not allowed.
                Here is a dict of each datagroup encountered more than once, and its occurrence count: {counts_over_1}"""
            )

    def is_ready(self):
        return all(loader.is_ready() for loader in self._loaders)

    def __iter__(self):
        if not self.is_ready():
            raise RuntimeError(
                """Memory loader(s) in JointMemoryLoader do not have enough data. Check `is_ready()`
                before trying to iterate over data."""
            )

        while True:
            out = {self._size_key: 0}

            for loader in self._loaders:
                data = next(iter(loader))
                out[loader.data_group] = data[loader.data_group]
                # for joint memory loaders we sum up all individual loader sizes
                out[self._size_key] += data[loader.size_key]

            yield out


class JointMemoryLoaderWithDataGroup(JointMemoryLoader):
    """A JointMemoryLoader that places its data inside of a user-specified datagroup."""

    def __init__(self, loaders: list[MemoryLoader], data_group: str, size_key: str = "batch_size"):
        super().__init__(loaders, size_key)
        self._data_group = data_group

    def __iter__(self):
        # sample one joint batch from the parent and nest it under the configured data group
        data = next(super().__iter__())
        total_size = data.pop(self._size_key)

        yield {self._data_group: data, self._size_key: total_size}


class MemoryWarmup(Callback):
    """A blocker to ensure memory has data.

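For clarity, the two new loaders differ only in how the output is nested (a sketch assuming two wrapped loaders with data groups "a" and "another" and the default "batch_size" size key; the nested group name is whatever the user passes in):

# JointMemoryLoader: each wrapped loader's data group sits at the top level,
# and the per-loader batch sizes are summed under the size key.
# {"a": {...}, "another": {...}, "batch_size": <sum of per-loader sizes>}

# JointMemoryLoaderWithDataGroup: the same payload is nested one level deeper,
# under the user-specified data group, with the total size kept at the top level.
# {"joint_datagroup": {"a": {...}, "another": {...}}, "batch_size": <total size>}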
123 changes: 123 additions & 0 deletions tests/test_memory_loading.py
@@ -0,0 +1,123 @@
import numpy as np
import pytest

from emote.memory.column import Column
from emote.memory.fifo_strategy import FifoEjectionStrategy
from emote.memory.memory import JointMemoryLoader, JointMemoryLoaderWithDataGroup, MemoryLoader
from emote.memory.table import ArrayTable
from emote.memory.uniform_strategy import UniformSampleStrategy


@pytest.fixture
def a_dummy_table():
    tab = ArrayTable(
        columns=[Column("state", (), np.float32), Column("action", (), np.float32)],
        maxlen=1_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="action",
        device="cpu",
    )
    tab.add_sequence(
        0,
        {
            "state": [5.0, 6.0],
            "action": [1.0],
        },
    )

    return tab


@pytest.fixture
def another_dummy_table():
    tab = ArrayTable(
        columns=[Column("state", (), np.float32), Column("action", (), np.float32)],
        maxlen=1_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="action",
        device="cpu",
    )
    tab.add_sequence(
        0,
        {
            "state": [5.0, 6.0],
            "action": [1.0],
        },
    )

    return tab


def test_joint_memory_loader(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
    a_loader = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )
    another_loader = MemoryLoader(
        table=another_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="another",
    )

    joint_loader = JointMemoryLoader(loaders=[a_loader, another_loader])

    data = next(iter(joint_loader))
    assert "a" in data and "another" in data, "JointMemoryLoader did not yield expected memory data"


def test_joint_memory_loader_datagroup(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
    a_loader = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )
    another_loader = MemoryLoader(
        table=another_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="another",
    )

    joint_loader = JointMemoryLoaderWithDataGroup(
        loaders=[a_loader, another_loader], data_group="joint_datagroup"
    )

    encapsulated_data = next(iter(joint_loader))
    data = encapsulated_data["joint_datagroup"]

    assert (
        "joint_datagroup" in encapsulated_data
    ), "Expected joint dataloader to place data in its own datagroup, but it does not exist."
    assert (
        "a" in data and "another" in data
    ), "Expected joint dataloader to actually place data in its datagroup, but it is empty."


def test_joint_memory_loader_nonunique_loaders_trigger_exception(a_dummy_table: ArrayTable):
    loader1 = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )
    loader2 = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )

    with pytest.raises(Exception, match="JointMemoryLoader"):
        joint_loader = JointMemoryLoader([loader1, loader2])  # noqa