-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* kitti dataset * kitti dataset * kitti dm * kitti dm * imports * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti
- Loading branch information
1 parent
c60e142
commit 32139f4
Showing
3 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import os | ||
import torch | ||
|
||
from pytorch_lightning import LightningDataModule | ||
from pl_bolts.datamodules.kitti_dataset import KittiDataset | ||
|
||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
from torch.utils.data.dataset import random_split | ||
|
||
|
||
class KittiDataModule(LightningDataModule):
    """
    LightningDataModule exposing train, validation and test dataloaders for the
    Kitti semantic segmentation dataset.

    The whole dataset is loaded through `KittiDataset` and partitioned once, at
    construction time, into three disjoint random subsets.
    """

    name = 'kitti'

    def __init__(
            self,
            data_dir: str,
            val_split: float = 0.2,
            test_split: float = 0.1,
            num_workers: int = 16,
            batch_size: int = 32,
            seed: int = 42,
            *args,
            **kwargs,
    ):
        """
        Kitti train, validation and test dataloaders.

        Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
        You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

        Specs:
            - 200 samples
            - Each image is resized to 1242 x 376 (width x height) by `KittiDataset`

        In total there are 34 classes but some of these are not useful so by default we use only 19 of the classes
        (see the `valid_labels` parameter of `KittiDataset`).

        Example::

            from pl_bolts.datamodules import KittiDataModule

            dm = KittiDataModule(PATH)
            model = LitModel()

            Trainer().fit(model, dm)

        Args:
            data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'.
                If `None`, the current working directory is used.
            val_split: fraction of the dataset used for the validation set (default 0.2)
            test_split: fraction of the dataset used for the test set (default 0.1)
            num_workers: how many workers to use for loading data
            batch_size: the batch size
            seed: random seed to be used for train/val/test splits

        Raises:
            ValueError: if a split fraction lies outside [0, 1] or the two fractions
                together leave a negative number of training samples.
        """
        super().__init__(*args, **kwargs)

        # Reject impossible fractions before touching the disk.
        for split_name, fraction in (('val_split', val_split), ('test_split', test_split)):
            if not 0 <= fraction <= 1:
                raise ValueError(f'`{split_name}` must be within [0, 1], got {fraction}')

        self.data_dir = data_dir if data_dir is not None else os.getcwd()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        self.default_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])

        # split into train, val, test
        kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)

        num_samples = len(kitti_dataset)
        val_len = round(val_split * num_samples)
        test_len = round(test_split * num_samples)
        train_len = num_samples - val_len - test_len

        # A negative length still sums up to `num_samples`, so `random_split` would
        # not complain and would hand back wrongly sized subsets; fail loudly instead.
        if train_len < 0:
            raise ValueError(
                f'`val_split` ({val_split}) and `test_split` ({test_split}) '
                f'leave no samples for training'
            )

        # A dedicated, seeded generator keeps the partition reproducible.
        self.trainset, self.valset, self.testset = random_split(
            kitti_dataset,
            lengths=[train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def _data_loader(self, dataset, shuffle: bool):
        """Build a DataLoader over `dataset` using the configured batch size and worker count."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training subset."""
        return self._data_loader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled DataLoader over the validation subset."""
        return self._data_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled DataLoader over the test subset."""
        return self._data_loader(self.testset, shuffle=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import os | ||
import numpy as np | ||
from PIL import Image | ||
|
||
from torch.utils.data import Dataset | ||
|
||
# Raw label IDs that are excluded from training (encoded as `ignore_index`).
DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
# Raw label IDs that are kept; each one is re-indexed by its position in this tuple.
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KittiDataset(Dataset):
    """
    Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
    You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
    useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
    in `valid_labels`.

    The `encode_segmap` function sets every pixel that does not carry one of the `valid_labels`
    (the `void_labels` included) to `ignore_index` (250). It also sets all of the valid pixels to the
    appropriate value between 0 and `len(valid_labels) - 1`, so it can be used properly by
    the loss function when comparing with the output.

    Args:
        data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
        img_size: image dimensions (width, height)
        void_labels: useless classes to be excluded from training
        valid_labels: useful classes to include
        transform: optional callable applied to the image only (the mask is returned untransformed)
    """
    IMAGE_PATH = os.path.join('training', 'image_2')
    MASK_PATH = os.path.join('training', 'semantic')

    def __init__(
            self,
            data_dir: str,
            img_size: tuple = (1242, 376),
            void_labels: list = DEFAULT_VOID_LABELS,
            valid_labels: list = DEFAULT_VALID_LABELS,
            transform=None
    ):
        self.img_size = img_size
        self.void_labels = void_labels
        self.valid_labels = valid_labels
        self.ignore_index = 250
        # raw label id -> contiguous class index (position inside `valid_labels`)
        self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
        self.transform = transform

        self.data_dir = data_dir
        self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH)
        self.mask_path = os.path.join(self.data_dir, self.MASK_PATH)
        # Both lists come back sorted, so index `i` refers to the same sample in each
        # (presumes image and mask files share their names -- verify against the data).
        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path)

    def __len__(self):
        """Number of image files found in the image folder."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """
        Load the `idx`-th sample.

        Returns:
            Tuple `(img, mask)`: the resized image (passed through `transform` when one is set)
            and the resized, encoded segmentation mask as a numpy array.
        """
        # Context managers release the underlying file handles right away.
        with Image.open(self.img_list[idx]) as img_file:
            img = np.array(img_file.resize(self.img_size))

        with Image.open(self.mask_list[idx]) as mask_file:
            # Nearest-neighbour keeps every pixel an existing label id; an interpolating
            # filter would blend neighbouring ids into meaningless in-between values.
            mask = mask_file.convert('L').resize(self.img_size, Image.NEAREST)
        mask = self.encode_segmap(np.array(mask))

        if self.transform:
            img = self.transform(img)

        return img, mask

    def encode_segmap(self, mask):
        """
        Re-index a raw label mask for training.

        Pixels holding one of `valid_labels` are replaced by that label's class index
        (its position in `valid_labels`); every other pixel, including all `void_labels`,
        is set to `ignore_index` so it won't be considered for training.

        The array is modified in place and also returned.
        """
        # Compare against an untouched copy: remapping in place step by step would let a
        # freshly written class index be mistaken for a raw label by a later iteration.
        raw = mask.copy()
        mask[...] = self.ignore_index
        for label, class_idx in self.class_map.items():
            mask[raw == label] = class_idx
        # Void labels take precedence should a label appear in both collections.
        for void_label in self.void_labels:
            mask[raw == void_label] = self.ignore_index
        return mask

    def get_filenames(self, path):
        """
        Returns a sorted list of paths to the entries inside given `path`
        (each entry name joined onto `path`).

        `os.listdir` yields entries in arbitrary order; sorting makes the order
        deterministic so that image and mask lists built from it line up by index.
        """
        return [os.path.join(path, filename) for filename in sorted(os.listdir(path))]