Skip to content

Commit

Permalink
[Datumaro] Fix project loading (cvat-ai#1013)
Browse files Browse the repository at this point in the history
* Fix occasional infinite loop in project loading

* Fix project import source options saving

* Fix project import .git dir placement

* Make code aware of grayscale images
  • Loading branch information
zhiltsov-max authored and Chris Lee-Messer committed Mar 5, 2020
1 parent 2878bc9 commit bb5ee67
Show file tree
Hide file tree
Showing 14 changed files with 35 additions and 19 deletions.
6 changes: 3 additions & 3 deletions datumaro/datumaro/components/algorithms/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def normalize_hmaps(self, heatmaps, counts):
def apply(self, image, progressive=False):
import cv2

assert len(image.shape) == 3, \
assert len(image.shape) in [2, 3], \
"Expected an input image in (H, W, C) format"
assert image.shape[2] in [3, 4], \
"Expected BGR or BGRA input"
if len(image.shape) == 3:
assert image.shape[2] in [3, 4], "Expected BGR or BGRA input"
image = image[:, :, :3].astype(np.float32)

model = self.model
Expand Down
6 changes: 3 additions & 3 deletions datumaro/datumaro/components/converters/ms_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def is_empty(self):

def save_image_info(self, item, filename):
if item.has_image:
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
else:
h = 0
w = 0
Expand Down Expand Up @@ -187,7 +187,7 @@ def save_annotations(self, item):
p.label == ann.label]
if polygons:
segmentation = [p.get_points() for p in polygons]
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
rles = mask_utils.frPyObjects(segmentation, h, w)
rle = mask_utils.merge(rles)
area = mask_utils.area(rle)
Expand All @@ -211,7 +211,7 @@ def save_annotations(self, item):
area = ann.area()

if self._context._merge_polygons:
h, w, _ = item.image.shape
h, w = item.image.shape[:2]
rles = mask_utils.frPyObjects(segmentation, h, w)
rle = mask_utils.merge(rles)
area = mask_utils.area(rle)
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/converters/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def float_list_feature(value):
if not item.has_image:
raise Exception(
"Failed to export dataset item '%s': item has no image" % item.id)
height, width, _ = item.image.shape
height, width = item.image.shape[:2]

features.update({
'image/height': int64_feature(height),
Expand Down
4 changes: 3 additions & 1 deletion datumaro/datumaro/components/converters/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def save_subsets(self):
ET.SubElement(source_elem, 'image').text = 'Unknown'

if item.has_image:
h, w, c = item.image.shape
image_shape = item.image.shape
h, w = image_shape[:2]
c = 1 if len(image_shape) == 2 else image_shape[2]
size_elem = ET.SubElement(root_elem, 'size')
ET.SubElement(size_elem, 'width').text = str(w)
ET.SubElement(size_elem, 'height').text = str(h)
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/converters/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(self, extractor, save_dir):
if not osp.exists(image_path):
save_image(image_path, item.image)

height, width, _ = item.image.shape
height, width = item.image.shape[:2]

yolo_annotation = ''
for bbox in item.annotations:
Expand Down
3 changes: 2 additions & 1 deletion datumaro/datumaro/components/dataset_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def encode_item(self, item):
def encode_image(cls, image):
image_elem = ET.Element('image')

h, w, c = image.shape
h, w = image.shape[:2]
c = 1 if len(image.shape) == 2 else image.shape[2]
ET.SubElement(image_elem, 'width').text = str(w)
ET.SubElement(image_elem, 'height').text = str(h)
ET.SubElement(image_elem, 'depth').text = str(c)
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/importers/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, path, **extra_params):
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})

return project
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/importers/datumaro.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, path, **extra_params):
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})

return project
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/importers/ms_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __call__(self, path, **extra_params):
project.add_source(source_name, {
'url': ann_file,
'format': self._COCO_EXTRACTORS[ann_type],
'options': extra_params,
'options': dict(extra_params),
})

return project
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/importers/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __call__(self, path, **extra_params):
project.add_source(subset_name, {
'url': subset_path,
'format': self.EXTRACTOR_NAME,
'options': extra_params,
'options': dict(extra_params),
})

return project
Expand Down
4 changes: 2 additions & 2 deletions datumaro/datumaro/components/importers/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __call__(self, path, **extra_params):
project.add_source(task.name, {
'url': path,
'format': extractor_type,
'options': extra_params,
'options': dict(extra_params),
})

if len(project.config.sources) == 0:
Expand Down Expand Up @@ -69,7 +69,7 @@ def __call__(self, path, **extra_params):
project.add_source(task_name, {
'url': task_dir,
'format': extractor_type,
'options': extra_params,
'options': dict(extra_params),
})

if len(project.config.sources) == 0:
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/importers/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __call__(self, path, **extra_params):
project.add_source(source_name, {
'url': config_path,
'format': 'yolo',
'options': extra_params,
'options': dict(extra_params),
})

return project
4 changes: 2 additions & 2 deletions datumaro/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class GitWrapper:
def __init__(self, config=None):
self.repo = None

if config is not None:
if config is not None and osp.isdir(config.project_dir):
self.init(config.project_dir)

@staticmethod
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(self, project):

own_source = None
own_source_dir = osp.join(config.project_dir, config.dataset_dir)
if osp.isdir(own_source_dir):
if osp.isdir(config.project_dir) and osp.isdir(own_source_dir):
log.disable(log.INFO)
own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \
.make_dataset()
Expand Down
13 changes: 13 additions & 0 deletions datumaro/tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ def __iter__(self):

self.assertEqual(5, len(dataset))

def test_can_save_and_load_own_dataset(self):
with TestDir() as test_dir:
src_project = Project()
src_dataset = src_project.make_dataset()
item = DatasetItem(id=1)
src_dataset.put(item)
src_dataset.save(test_dir.path)

loaded_project = Project.load(test_dir.path)
loaded_dataset = loaded_project.make_dataset()

self.assertEqual(list(src_dataset), list(loaded_dataset))

def test_project_own_dataset_can_be_modified(self):
project = Project()
dataset = project.make_dataset()
Expand Down

0 comments on commit bb5ee67

Please sign in to comment.