Skip to content

Commit

Permalink
Addressing issues (#25)
Browse files Browse the repository at this point in the history
* fix

* unittest fix
  • Loading branch information
eatpk authored Dec 1, 2023
1 parent e3e5e1b commit 6130ab5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
8 changes: 5 additions & 3 deletions analog/storage/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def _flush_unsafe(self, buffer, push_count) -> str:
"""
_flush_unsafe is thread unsafe flush of current buffer. No shared variable must be allowed.
"""
save_path = str(os.path.join(self.log_dir, f"data_{push_count}.pt"))
save_path = self.file_prefix + f"{push_count}.mmap"
torch.save(buffer, save_path)
buffer_list = [(k, v) for k, v in buffer]
self.mmap_handler.write(buffer_list, save_path)
return save_path

def _flush_safe(self) -> str:
Expand Down Expand Up @@ -227,8 +229,8 @@ def _find_chunk_indices(self, directory):

def _add_metadata_and_mmap(self, mmap_filename, chunk_index):
# Load the memmap file
mmap, metadata = self.mmap_handler.read(mmap_filename)
self.memmaps.append(mmap)
with self.mmap_handler.read(mmap_filename) as (mmap, metadata):
self.memmaps.append(mmap)

# Update the mapping from data_id to chunk
for entry in metadata:
Expand Down
9 changes: 7 additions & 2 deletions analog/storage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import numpy as np

from einops import rearrange
from contextlib import contextmanager


def extract_arrays(obj, base_path=()):
Expand Down Expand Up @@ -67,6 +67,7 @@ def write(self, data_buffer, filename):
with open(metadata_filename, "w") as f:
json.dump(metadata, f, indent=2)

@contextmanager
def read(self, filename):
"""
read reads the file by chunk index, it will return the data_buffer with metadata.
Expand All @@ -84,7 +85,11 @@ def read(self, filename):
os.path.join(self.save_path, filename), dtype=self.mmap_dtype, mode="r"
)
metadata = self.read_metafile(file_root + "_metadata.json")
return mmap, metadata

try:
yield mmap, metadata
finally:
del mmap

def read_metafile(self, meta_filename):
file_root, file_ext = os.path.splitext(meta_filename)
Expand Down
38 changes: 22 additions & 16 deletions tests/storage/test_util_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,20 @@ def test_write_and_read(self):

handler.write(data_buffer, filename)

mmap, metadata = handler.read(filename)
for item in metadata:
offset = item["offset"]
size = item["size"]
shape = tuple(item["shape"])
dtype = np.dtype(item["dtype"])
expected_data = data_buffer[item["data_id"]][1]["dummy_data"]
read_data = np.frombuffer(
mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
).reshape(shape)
# Test if expected value and read value equals
self.assertTrue(np.array_equal(read_data, expected_data), "Data mismatch")
with handler.read(filename) as (mmap, metadata):
for item in metadata:
offset = item["offset"]
size = item["size"]
shape = tuple(item["shape"])
dtype = np.dtype(item["dtype"])
expected_data = data_buffer[item["data_id"]][1]["dummy_data"]
read_data = np.frombuffer(
mmap, dtype=dtype, count=size // dtype.itemsize, offset=offset
).reshape(shape)
# Test if expected value and read value equals
self.assertTrue(
np.array_equal(read_data, expected_data), "Data mismatch"
)

def test_read(self):
expected_files_path = os.path.join(
Expand All @@ -82,11 +84,15 @@ def test_read(self):
]

handler.write(data_buffer, filename)
mmap, metadata = handler.read(filename)
mmap, metadata, expected_mmap = None, None, None

with handler.read(filename) as (mm, md):
mmap = mm
metadata = md

with handler.read("expected_data.mmap") as (em, _):
expected_mmap = em

expected_mmap, _ = handler.read(
"expected_data.mmap"
) # Using same metadata file.
for item in metadata:
offset = item["offset"]
size = item["size"]
Expand Down

0 comments on commit 6130ab5

Please sign in to comment.