CO3Dv2 multi-category extension
Summary:
Allows loading of multiple categories.
The categories are provided as a comma-separated list of category names.
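
For illustration, a minimal sketch of how the extended provider might be used (the category string, subset name, and printed statistic are hypothetical; assumes `CO3DV2_DATASET_ROOT` points at a local CO3Dv2 download):

```python
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Implicitron configurables need their dataclass fields expanded before use.
expand_args_fields(JsonIndexDatasetMapProviderV2)

# A comma-separated category string loads every listed category and joins
# their train/val/test splits into a single DatasetMap.
provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",
    subset_name="fewview_test",
)
dataset_map = provider.get_dataset_map()
print(len(dataset_map.train))  # total frames across all three categories
```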

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
davnov134 authored and facebook-github-bot committed Nov 2, 2022
1 parent c54e048 commit e4a3298
Showing 9 changed files with 272 additions and 25 deletions.
1 change: 1 addition & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
@@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  num_load_workers: 4
   n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
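
The new `num_load_workers` field caps the pool used for the parallel per-category load; the effective process count is the smaller of this value and the number of categories. A minimal sketch of that arithmetic (category names hypothetical):

```python
# Mirrors min(self.num_load_workers, len(categories)) in the provider below.
num_load_workers = 4  # the default recorded in experiment.yaml
categories = ["apple", "car"]
pool_size = min(num_load_workers, len(categories))
assert pool_size == 2  # bounded by the number of categories here
```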
23 changes: 23 additions & 0 deletions pytorch3d/implicitron/dataset/dataset_base.py
@@ -9,6 +9,7 @@
 from typing import (
     Any,
     ClassVar,
+    Dict,
     Iterable,
     Iterator,
     List,
@@ -259,6 +260,12 @@ def get_frame_numbers_and_timestamps(
         """
         raise ValueError("This dataset does not contain videos.")
 
+    def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
+        """
+        Joins the current dataset with a list of other datasets of the same type.
+        """
+        raise NotImplementedError()
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return None
 
@@ -267,6 +274,22 @@ def sequence_names(self) -> Iterable[str]:
         # pyre-ignore[16]
         return self._seq_to_idx.keys()
 
+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a dict mapping from each dataset category to a list of its
+        sequence names.
+
+        Returns:
+            category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
+        """
+        c2seq = defaultdict(list)
+        for sequence_name in self.sequence_names():
+            first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
+            # crashes without overriding __getitem__
+            sequence_category = self[first_frame_idx].sequence_category
+            c2seq[sequence_category].append(sequence_name)
+        return dict(c2seq)
+
     def sequence_frames_in_order(
         self, seq_name: str
     ) -> Iterator[Tuple[float, int, int]]:
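
A hedged usage sketch of the new default method (`dataset` stands for any already-constructed `DatasetBase` subclass; the default implementation reads the category from each sequence's first frame via `__getitem__`):

```python
# Hypothetical usage: group a loaded dataset's sequences by category.
cat2seq = dataset.category_to_sequence_names()
for category, seq_names in cat2seq.items():
    print(f"{category}: {len(seq_names)} sequences")
```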
30 changes: 29 additions & 1 deletion pytorch3d/implicitron/dataset/dataset_map_provider.py
@@ -7,7 +7,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional
 
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ def iter_datasets(self) -> Iterator[DatasetBase]:
         if self.test is not None:
             yield self.test
 
+    def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
+        """
+        Joins the current DatasetMap with other dataset maps from the input list.
+
+        For each subset of each dataset map (train/val/test), the function
+        omits joining the subsets that are None.
+
+        Note the train/val/test datasets of the current dataset map will be
+        modified in-place.
+
+        Args:
+            other_dataset_maps: The list of dataset maps to be joined into the
+                current dataset map.
+        """
+        for set_ in ["train", "val", "test"]:
+            dataset_list = [
+                getattr(self, set_),
+                *[getattr(dmap, set_) for dmap in other_dataset_maps],
+            ]
+            dataset_list = [d for d in dataset_list if d is not None]
+            if len(dataset_list) == 0:
+                setattr(self, set_, None)
+                continue
+            d0 = dataset_list[0]
+            if len(dataset_list) > 1:
+                d0.join(dataset_list[1:])
+            setattr(self, set_, d0)
+
+
 class DatasetMapProviderBase(ReplaceableBase):
     """
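
Because `DatasetMap.join` only checks its members for None and calls their `join`, its None-handling can be illustrated with a tiny duck-typed stand-in (a toy sketch, not real dataset usage; `ListDataset` is invented for illustration):

```python
from typing import Iterable, List

from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap


class ListDataset:
    """Duck-typed stand-in for DatasetBase: only `join` is needed here."""

    def __init__(self, frames: List[int]) -> None:
        self.frames = frames

    def join(self, other_datasets: Iterable["ListDataset"]) -> None:
        for d in other_datasets:
            self.frames.extend(d.frames)


m0 = DatasetMap(train=ListDataset([0, 1]), val=None, test=ListDataset([2]))
m1 = DatasetMap(train=ListDataset([3]), val=None, test=None)
m0.join([m1])
assert m0.train.frames == [0, 1, 3]  # joined in-place into the first map
assert m0.val is None                # None everywhere stays None
assert m0.test.frames == [2]         # a single non-None subset is kept as-is
```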
72 changes: 70 additions & 2 deletions pytorch3d/implicitron/dataset/json_index_dataset.py
@@ -19,6 +19,8 @@
 from typing import (
     Any,
     ClassVar,
+    Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -188,7 +190,44 @@ def _extract_and_set_eval_batches(self):
             self.eval_batch_index
         )
 
-    def is_filtered(self):
+    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+        """
+        Join the dataset with other JsonIndexDataset objects.
+
+        Args:
+            other_datasets: A list of JsonIndexDataset objects to be joined
+                into the current dataset.
+        """
+        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+            raise ValueError("This function can only join a list of JsonIndexDataset")
+        # pyre-ignore[16]
+        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+        # pyre-ignore[16]
+        self.seq_annots.update(
+            # https://gist.github.com/treyhunner/f35292e676efa0be1728
+            functools.reduce(
+                lambda a, b: {**a, **b},
+                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
+            )
+        )
+        all_eval_batches = [
+            self.eval_batches,
+            # pyre-ignore
+            *[d.eval_batches for d in other_datasets],
+        ]
+        if not (
+            all(ba is None for ba in all_eval_batches)
+            or all(ba is not None for ba in all_eval_batches)
+        ):
+            raise ValueError(
+                "When joining datasets, either all joined datasets have to have their"
+                " eval_batches defined, or all should have their eval batches undefined."
+            )
+        if self.eval_batches is not None:
+            self.eval_batches = sum(all_eval_batches, [])
+        self._invalidate_indexes(filter_seq_annots=True)
+
+    def is_filtered(self) -> bool:
         """
         Returns `True` in case the dataset has been filtered and thus some frame annotations
         stored on the disk might be missing in the dataset object.
@@ -211,6 +250,7 @@ def seq_frame_index_to_dataset_index(
         seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
         allow_missing_indices: bool = False,
         remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
     ) -> List[List[Union[Optional[int], int]]]:
         """
         Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@
                 If `False`, returns `None` in place of `seq_frame_index` entries that
                 are not present in the dataset.
                 If `True`, removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases, e.g. when setting
+                `self.remove_empty_masks=True`).
 
         Returns:
             dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@@ -254,7 +299,8 @@ def _get_dataset_idx(
             )
             if not allow_missing_indices:
                 raise IndexError(msg)
-            warnings.warn(msg)
+            if not suppress_missing_index_warning:
+                warnings.warn(msg)
             return idx
         if path is not None:
             # Check that the loaded frame path is consistent
@@ -288,6 +334,21 @@ def subset_from_frame_index(
         frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
         allow_missing_indices: bool = True,
     ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name: str, frame_number: int]]`.
+                Optionally, image paths relative to the dataset_root can be
+                specified as well:
+                `List[Tuple[sequence_name: str, frame_number: int, image_path: str]]`;
+                in the latter case, an error is raised if the image paths do not
+                match the stored paths.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching
+                the first entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset consisting of the frame entries that
+                actually exist in the dataset.
+        """
         # Get the indices into the frame annots.
         dataset_indices = self.seq_frame_index_to_dataset_index(
             [frame_index],
@@ -838,6 +899,13 @@ def get_frame_numbers_and_timestamps(
         )
         return out
 
+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        c2seq = defaultdict(list)
+        # pyre-ignore
+        for sequence_name, sa in self.seq_annots.items():
+            c2seq[sa.category].append(sequence_name)
+        return dict(c2seq)
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return self.eval_batches
 
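
The sequence-annotation merge in `join` reduces a list of dicts into one, as in the linked gist; a runnable toy sketch of that step (plain strings stand in for the real `SequenceAnnotation` objects):

```python
import functools

# Later datasets win on duplicate sequence names, mirroring the
# functools.reduce over seq_annots dicts in JsonIndexDataset.join.
seq_annots_per_dataset = [
    {"apple_001": "annot_a1", "apple_002": "annot_a2"},
    {"car_001": "annot_c1"},
]
merged = functools.reduce(lambda a, b: {**a, **b}, seq_annots_per_dataset)
assert merged == {
    "apple_001": "annot_a1",
    "apple_002": "annot_a2",
    "car_001": "annot_c1",
}
```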
54 changes: 39 additions & 15 deletions pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py
@@ -8,6 +8,7 @@
 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict
@@ -30,6 +31,7 @@
 )
 
 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm
 
 
 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         (test frames can repeat across batches).
 
     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load, expressed as a string of
+            comma-separated category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4
 
     n_known_frames_for_test: int = 0
 
@@ -189,11 +193,33 @@ def __post_init__(self):
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")
 
-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )
 
         path_manager = self.path_manager_factory.get()
@@ -232,7 +258,7 @@
 
         dataset = dataset_type(**common_dataset_kwargs)
 
-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@
 
         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )
 
         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
+        else:
+            eval_batch_index = None
 
         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@
             )
             logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
 
-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
 
     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ def _load_annotation_json(self, json_filename: str):
             data = json.load(f)
         return data
 
-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )
 
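
The parallel per-category load is a standard `multiprocessing.Pool.imap` + `tqdm` pattern; below is a self-contained sketch in which a placeholder loader stands in for `_load_category` (which returns a `DatasetMap` in the real code). Since `imap` preserves input order, the first result can serve as the join target:

```python
import multiprocessing

from tqdm import tqdm


def load_category(category: str) -> dict:
    # Placeholder for JsonIndexDatasetMapProviderV2._load_category.
    return {"category": category}


if __name__ == "__main__":
    categories = ["apple", "car", "orange"]  # hypothetical category list
    num_load_workers = 4
    with multiprocessing.Pool(
        processes=min(num_load_workers, len(categories))
    ) as pool:
        maps = list(tqdm(pool.imap(load_category, categories), total=len(categories)))
    print(maps[0])  # first category's result, the join target in the provider
```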