# recycledataset.py
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import cv2
import os
import numpy as np
# Class names in mask-label order: the position of a name in this list is the
# pixel value written into the segmentation mask for that class (index 0 is the
# background entry). Names are matched against the COCO category names via
# ``category_names.index(...)``, so the strings must not be edited casually.
# NOTE(review): 'Backgroud' looks like a misspelling of 'Background'; it is
# presumably never looked up because background pixels are not annotated -- confirm.
category_names = ['Backgroud', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']
def get_classname(classID, cats):
    """Return the name of the category whose ``'id'`` equals ``classID``.

    Args:
        classID: Category id to look up.
        cats: Iterable of category dicts, each providing ``'id'`` and
            ``'name'`` keys.

    Returns:
        The ``'name'`` of the first matching category. When nothing matches,
        the *string* ``"None"`` is returned (not the ``None`` object).
    """
    for cat in cats:
        if cat['id'] == classID:
            return cat['name']
    return "None"
class CustomDataLoader(Dataset):
    """Semantic-segmentation dataset backed by a COCO-format annotation file.

    Each item is an RGB image scaled to ``[0, 1]`` as ``float32``. In
    ``'train'`` / ``'val'`` mode a per-pixel class mask is returned as well;
    in ``'test'`` mode only the image and its metadata are returned.

    Args:
        data_path: Directory that contains the annotation file and against
            which each image's ``file_name`` is resolved.
        ann: Annotation file name, joined onto ``data_path``.
        mode: One of ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` (or ``transform(image=...)`` in
            test mode) and expected to return a mapping with ``"image"`` (and
            ``"mask"``) keys, i.e. the albumentations calling convention.
    """

    _MODES_WITH_MASK = ('train', 'val')
    _MODES = _MODES_WITH_MASK + ('test',)

    def __init__(self, data_path, ann, mode='train', transform=None):
        super().__init__()
        self.mode = mode
        self.transform = transform
        self.data_path = data_path
        self.coco = COCO(os.path.join(data_path, ann))

    def __getitem__(self, index: int):
        """Load the sample at position ``index``.

        Returns:
            ``(image, mask, image_info)`` in ``'train'`` / ``'val'`` mode, or
            ``(image, image_info)`` in ``'test'`` mode. ``image_info`` is the
            COCO image record for the sample.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
            FileNotFoundError: If the image file cannot be read.
        """
        # Fail loudly instead of silently returning None for an unknown mode.
        if self.mode not in self._MODES:
            raise ValueError(
                "Unsupported mode {!r}; expected one of {}".format(
                    self.mode, self._MODES
                )
            )

        # Index positionally into the annotation file's image list.
        image_id = self.coco.dataset['images'][index]['id']
        image_infos = self.coco.loadImgs(image_id)[0]

        # Read with OpenCV (BGR), convert to RGB and scale to [0, 1].
        image_path = os.path.join(self.data_path, image_infos['file_name'])
        images = cv2.imread(image_path)
        if images is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(
                "Could not read image: {}".format(image_path)
            )
        images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
        images /= 255.0

        if self.mode == 'test':
            if self.transform is not None:
                transformed = self.transform(image=images)
                images = transformed["image"]
            return images, image_infos

        masks = self._build_mask(image_infos)
        if self.transform is not None:
            transformed = self.transform(image=images, mask=masks)
            images = transformed["image"]
            masks = transformed["mask"]
        return images, masks, image_infos

    def _build_mask(self, image_infos):
        """Rasterize all annotations of one image into a 2D int8 class mask."""
        ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
        anns = self.coco.loadAnns(ann_ids)
        cats = self.coco.loadCats(self.coco.getCatIds())

        # (height x width) mask; every pixel holds an index into
        # category_names. Background = 0, General trash = 1, ..., Clothing = 10.
        masks = np.zeros(
            (image_infos["height"], image_infos["width"]), dtype=np.int8
        )

        # Paint the largest regions first so that smaller objects lying on top
        # of them overwrite, rather than get hidden by, the larger ones.
        for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
            class_name = get_classname(ann['category_id'], cats)
            pixel_value = category_names.index(class_name)
            masks[self.coco.annToMask(ann) == 1] = pixel_value
        return masks

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.coco.getImgIds())