Skip to content

Commit

Permalink
Accelerated BDD100K (lava-nc#341)
Browse files Browse the repository at this point in the history
Signed-off-by: bamsumit <bam_sumit@hotmail.com>
  • Loading branch information
bamsumit authored Jul 23, 2024
1 parent 16f4c1c commit 348a6ee
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 47 deletions.
113 changes: 66 additions & 47 deletions src/lava/lib/dl/slayer/object_detection/dataset/bdd100k.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os
from typing import Any, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
Expand Down Expand Up @@ -43,8 +44,12 @@ class _BDD(Dataset):
def __init__(self,
root: str = '.',
dataset: str = '.',
train: bool = False) -> None:
train: bool = False,
seq_len: int = 32,
randomize_seq: bool = False) -> None:
super().__init__()
self.seq_len = seq_len
self.randomize_seq = randomize_seq

image_set = 'train' if train else 'val'
self.label_path = root + os.sep + \
Expand Down Expand Up @@ -76,6 +81,24 @@ def __init__(self,
self.cat_name = sorted(list(categories))
self.idx_map = {name: idx for idx, name in enumerate(self.cat_name)}

def _get_frame(self, path, labels):
image = Image.open(path).convert('RGB')
width, height = image.size
size = {'height': height, 'width': width}
objects = []
for ann in labels:
name = ann['category']
bndbox = {'xmin': ann['box2d']['x1'],
'ymin': ann['box2d']['y1'],
'xmax': ann['box2d']['x2'],
'ymax': ann['box2d']['y2']}
objects.append({'id': self.idx_map[name],
'name': name,
'bndbox': bndbox})

annotation = {'annotation': {'size': size, 'object': objects}}
return image, annotation

def __getitem__(self, index: int) -> Tuple[torch.tensor, Dict[Any, Any]]:
id = self.ids[index]
img_path = self.image_path + os.sep + id + os.sep
Expand All @@ -84,25 +107,25 @@ def __getitem__(self, index: int) -> Tuple[torch.tensor, Dict[Any, Any]]:
annotations = []
with open(self.label_path + os.sep + id + '.json') as file:
data = json.load(file)
for img in data:
image = Image.open(
img_path + os.sep + img['name']).convert('RGB')
width, height = image.size
size = {'height': height, 'width': width}
objects = []
for ann in img['labels']:
name = ann['category']
bndbox = {'xmin': ann['box2d']['x1'],
'ymin': ann['box2d']['y1'],
'xmax': ann['box2d']['x2'],
'ymax': ann['box2d']['y2']}
objects.append({'id': self.idx_map[name],
'name': name,
'bndbox': bndbox})

annotation = {'size': size, 'object': objects}
images.append(image)
annotations.append({'annotation': annotation})
num_seq = len(data)
if self.randomize_seq:
start_idx = np.random.randint(max(num_seq - self.seq_len, 0))
else:
start_idx = 0
stop_idx = start_idx + self.seq_len
data = data[start_idx:stop_idx]

with ThreadPoolExecutor() as pool:
path = map(lambda img: img_path + os.sep + img['name'], data)
labels = map(lambda img: img['labels'], data)
for image, annotation in pool.map(self._get_frame,
path, labels):
images.append(image)
annotations.append(annotation)
if len(images) != self.seq_len:
delta = self.seq_len - len(images)
images = images + [images[-1]] * delta
annotations = annotations + [annotations[-1]] * delta

return images, annotations

Expand Down Expand Up @@ -152,14 +175,18 @@ def __init__(self,
lambda x: bbutils.resize_bounding_boxes(x, size),
])

self.datasets = [_BDD(root=root, dataset=dataset, train=train)]
self.datasets = [_BDD(root=root, dataset=dataset, train=train,
seq_len=seq_len, randomize_seq=randomize_seq)]

self.classes = self.datasets[0].cat_name
self.idx_map = self.datasets[0].idx_map
self.augment_prob = augment_prob
self.seq_len = seq_len
self.randomize_seq = randomize_seq

def flip_lr(self, img) -> Image:
return Image.Image.transpose(img, Transpose.FLIP_LEFT_RIGHT)

def __getitem__(self, index: int) -> Tuple[torch.tensor, Dict[Any, Any]]:
"""Get a sample video sequence of BDD100K dataset.
Expand All @@ -179,38 +206,30 @@ def __getitem__(self, index: int) -> Tuple[torch.tensor, Dict[Any, Any]]:

# flip left right
if np.random.random() < self.augment_prob:
for idx in range(len(images)):
images[idx] = Image.Image.transpose(
images[idx], Transpose.FLIP_LEFT_RIGHT)
annotations[idx] = bbutils.fliplr_bounding_boxes(
annotations[idx])
with ThreadPoolExecutor() as pool:
images = pool.map(self.flip_lr, images)
with ThreadPoolExecutor() as pool:
annotations = pool.map(bbutils.fliplr_bounding_boxes,
annotations)
# blur
if np.random.random() < self.augment_prob:
for idx in range(len(images)):
images[idx] = self.blur(images[idx])
with ThreadPoolExecutor() as pool:
images = pool.map(self.blur, images)
# color jitter
if np.random.random() < self.augment_prob:
for idx in range(len(images)):
images[idx] = self.color_jitter(images[idx])
with ThreadPoolExecutor() as pool:
images = pool.map(self.color_jitter, images)
# grayscale
if np.random.random() < self.augment_prob:
for idx in range(len(images)):
images[idx] = self.grayscale(images[idx])

image = torch.cat([torch.unsqueeze(self.img_transform(img), -1)
for img in images], dim=-1)
annotations = [self.bb_transform(ann) for ann in annotations]

# [C, H, W, T], [bbox] * T
num_seq = image.shape[-1]
if self.randomize_seq:
start_idx = np.random.randint(num_seq - self.seq_len)
else:
start_idx = 0
stop_idx = start_idx + self.seq_len

# list in time
return image[..., start_idx:stop_idx], annotations[start_idx:stop_idx]
with ThreadPoolExecutor() as pool:
images = pool.map(self.grayscale, images)

with ThreadPoolExecutor() as pool:
results = pool.map(self.img_transform, images)
image = torch.stack(list(results), dim=-1)
annotations = list(map(self.bb_transform, annotations))

return image, annotations

def __len__(self) -> int:
"""Number of samples in the dataset.
Expand Down
1 change: 1 addition & 0 deletions src/lava/lib/dl/slayer/utils/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def forward(ctx, z, neuron, recurrent_mat):
z = z.detach().requires_grad_()
x = torch.zeros_like(z).to(z.device)
recurrent_mat_T = recurrent_mat.transpose(0, 1).clone().detach()
recurrent_mat_T = recurrent_mat_T.to(z.device)

ctx.dend_sums = []
ctx.spikes = []
Expand Down

0 comments on commit 348a6ee

Please sign in to comment.