Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk committed Nov 19, 2023
1 parent e79aee7 commit 114b2cd
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions analog/storage/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def build_log_dataloader(self, batch_size=16, num_workers=0):

class DefaultLogDataset(Dataset):
def __init__(self, mmap_handler):
self.schemas = []
self.memmaps = []
self.data_id_to_chunk = OrderedDict()
self.mmap_handler = mmap_handler
Expand Down Expand Up @@ -230,39 +229,40 @@ def _add_schema_and_mmap(
# Load the memmap file
mmap, schema = self.mmap_handler.read(mmap_filename)
self.memmaps.append(mmap)
self.schemas.append(schema)

# Update the mapping from data_id to chunk
for entry in schema:
self.data_id_to_chunk[entry["data_id"]] = chunk_index
data_id = entry["data_id"]

if data_id in self.data_id_to_chunk:
# Append to the existing list for this data_id
self.data_id_to_chunk[data_id][1].append(entry)
continue
self.data_id_to_chunk[data_id] = (chunk_index, [entry])
def __getitem__(self, index):
data_id = list(self.data_id_to_chunk.keys())[index]
chunk_idx = self.data_id_to_chunk[data_id]
chunk_idx, entries = self.data_id_to_chunk[data_id]

nested_dict = {}

mmap = self.memmaps[chunk_idx]
schema = self.schemas[chunk_idx]
for entry in schema:
if entry["data_id"] == data_id:
# Read the data and put it into the nested dictionary
path = entry["path"]
offset = entry["offset"]
shape = tuple(entry["shape"])
dtype = np.dtype(entry["dtype"])

array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
tensor = torch.Tensor(array)

# Place the tensor in the correct location within the nested dictionary
current_level = nested_dict
for key in path[:-1]:
if key not in current_level:
current_level[key] = {}
current_level = current_level[key]
current_level[path[-1]] = tensor

for entry in entries:
# Read the data and put it into the nested dictionary
path = entry["path"]
offset = entry["offset"]
shape = tuple(entry["shape"])
dtype = np.dtype(entry["dtype"])

array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
tensor = torch.Tensor(array)

# Place the tensor in the correct location within the nested dictionary
current_level = nested_dict
for key in path[:-1]:
if key not in current_level:
current_level[key] = {}
current_level = current_level[key]
current_level[path[-1]] = tensor
return data_id, nested_dict

def __len__(self):
Expand Down

0 comments on commit 114b2cd

Please sign in to comment.