Kitti Datamodule (#248)
* kitti dataset

* kitti dm

* imports

* kitti
annikabrundyn authored Sep 27, 2020
1 parent c60e142 commit 32139f4
Showing 3 changed files with 194 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pl_bolts/datamodules/__init__.py
@@ -18,5 +18,8 @@
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule

from pl_bolts.datamodules.kitti_dataset import KittiDataset
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ImportError:
pass
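
Because these imports are added to the package `__init__` (inside the same try/except guard as the other datamodules), the new classes can be imported straight from the package root. A one-line sketch of what this enables:

from pl_bolts.datamodules import KittiDataset, KittiDataModule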
99 changes: 99 additions & 0 deletions pl_bolts/datamodules/kitti_datamodule.py
@@ -0,0 +1,99 @@
import os
import torch

from pytorch_lightning import LightningDataModule
from pl_bolts.datamodules.kitti_dataset import KittiDataset

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data.dataset import random_split


class KittiDataModule(LightningDataModule):

name = 'kitti'

def __init__(
self,
data_dir: str,
val_split: float = 0.2,
test_split: float = 0.1,
num_workers: int = 16,
batch_size: int = 32,
seed: int = 42,
*args,
**kwargs,
):
"""
Kitti train, validation and test dataloaders.
Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
Specs:
- 200 samples
- Each image is (3 x 1242 x 376)
In total there are 34 classes but some of these are not useful so by default we use only 19 of the classes
specified by the `valid_labels` parameter.
Example::
from pl_bolts.datamodules import KittiDataModule
dm = KittiDataModule(PATH)
model = LitModel()
Trainer().fit(model, dm)
Args::
data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
val_split: size of validation test (default 0.2)
test_split: size of test set (default 0.1)
num_workers: how many workers to use for loading data
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
super().__init__(*args, **kwargs)
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.batch_size = batch_size
self.num_workers = num_workers
self.seed = seed

self.default_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])

# split into train, val, test
kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)

val_len = round(val_split * len(kitti_dataset))
test_len = round(test_split * len(kitti_dataset))
train_len = len(kitti_dataset) - val_len - test_len

self.trainset, self.valset, self.testset = random_split(kitti_dataset,
lengths=[train_len, val_len, test_len],
generator=torch.Generator().manual_seed(self.seed))

def train_dataloader(self):
loader = DataLoader(self.trainset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers)
return loader

def val_dataloader(self):
loader = DataLoader(self.valset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)
return loader

def test_dataloader(self):
loader = DataLoader(self.testset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)
return loader
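
A quick usage sketch of the new datamodule, assuming the KITTI semantics data has been downloaded and unpacked (the data path and the segmentation model below are placeholders, not part of this commit). With the 200-sample dataset and the default splits this yields 140 train, 40 validation and 20 test samples:

import pytorch_lightning as pl
from pl_bolts.datamodules import KittiDataModule

dm = KittiDataModule(data_dir='/path/to/data_semantics', batch_size=4, num_workers=4)

model = MyLitSegmentationModel()  # placeholder: any LightningModule that consumes (img, mask) batches
trainer = pl.Trainer(max_epochs=30)
trainer.fit(model, dm)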
92 changes: 92 additions & 0 deletions pl_bolts/datamodules/kitti_dataset.py
@@ -0,0 +1,92 @@
import os
import numpy as np
from PIL import Image

from torch.utils.data import Dataset

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KittiDataset(Dataset):
"""
Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
in `valid_labels`.
The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
(250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
`len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
the loss function when comparing with the output.
Args:
data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
img_size: image dimensions (width, height)
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')

def __init__(
self,
data_dir: str,
img_size: tuple = (1242, 376),
void_labels: list = DEFAULT_VOID_LABELS,
valid_labels: list = DEFAULT_VALID_LABELS,
transform=None
):
self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.transform = transform

self.data_dir = data_dir
self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH)
self.mask_path = os.path.join(self.data_dir, self.MASK_PATH)
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

def __len__(self):
return len(self.img_list)

def __getitem__(self, idx):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
img = np.array(img)

mask = Image.open(self.mask_list[idx]).convert('L')
# use nearest-neighbour resampling so that label values are not interpolated when resizing the mask
mask = mask.resize(self.img_size, resample=Image.NEAREST)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
img = self.transform(img)

return img, mask

def encode_segmap(self, mask):
"""
Sets void classes to zero so they won't be considered for training
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# any leftover label values (everything not remapped to 0-18 above) are also ignored
mask[mask > 18] = self.ignore_index
return mask

def get_filenames(self, path):
    """
    Returns a sorted list of absolute paths to the images inside the given `path`.
    """
    # sorting keeps the image and mask lists aligned, since os.listdir order is arbitrary
    return [os.path.join(path, filename) for filename in sorted(os.listdir(path))]
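
To make the remapping concrete, here is a small hand-checked sketch that replays the `encode_segmap` logic on a toy mask, using the default label lists defined at the top of this file (the toy values are made up for illustration):

import numpy as np

from pl_bolts.datamodules.kitti_dataset import DEFAULT_VALID_LABELS, DEFAULT_VOID_LABELS

# toy mask: 6 is a void label, 7 and 8 are valid labels, 34 is in neither list
toy_mask = np.array([[6, 7], [8, 34]], dtype=np.uint8)

encoded = toy_mask.copy()
for voidc in DEFAULT_VOID_LABELS:
    encoded[encoded == voidc] = 250  # ignore_index
class_map = dict(zip(DEFAULT_VALID_LABELS, range(len(DEFAULT_VALID_LABELS))))
for validc in DEFAULT_VALID_LABELS:
    encoded[encoded == validc] = class_map[validc]
encoded[encoded > 18] = 250

print(encoded)  # void 6 -> 250, valid 7 -> 0, valid 8 -> 1, out-of-range 34 -> 250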
