Summary: I forgot to include these tests in D45086611 when transferring code from the pixar_replay repo. They test the new ORM types used in the SQL dataset and are SQLAlchemy 2.0 specific. The test for extending types doubles as a proof of concept for the generality of the SQL dataset: the idea is to extend FrameAnnotation and FrameData in parallel.

Reviewed By: bottler

Differential Revision: D45529284

fbshipit-source-id: 2a634e518f580c312602107c85fc320db43abcf5
1 parent 178a777 · commit 3e3644e · 2 changed files: 267 additions, 0 deletions
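Distilled, the extension pattern the new tests exercise looks like this. This is a minimal sketch, not code from the commit; MyFrameAnnotation, MyDataset, and num_legs are illustrative names, while the pytorch3d imports are real:

from typing import ClassVar, Optional, Type

from pytorch3d.implicitron.dataset.orm_types import SqlFrameAnnotation
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from sqlalchemy.orm import Mapped, mapped_column


class MyFrameAnnotation(SqlFrameAnnotation):
    # add a new column to the frame annotation table
    num_legs: Mapped[Optional[int]] = mapped_column(default=None)


class MyDataset(SqlIndexDataset):
    # make the dataset load rows into the extended annotation type
    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = MyFrameAnnotation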
@@ -0,0 +1,230 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import logging
import os
import tempfile
import unittest
from typing import ClassVar, Optional, Type

import pandas as pd
import pkg_resources
import sqlalchemy as sa

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
from pytorch3d.implicitron.dataset.orm_types import (
    SqlFrameAnnotation,
    SqlSequenceAnnotation,
)
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
from pytorch3d.implicitron.tools.config import registry
from sqlalchemy.orm import composite, Mapped, mapped_column, Session

# Frame-data builder arguments that disable loading of all blob data
# (images, depths, masks), so the tests only exercise SQL metadata.
NO_BLOBS_KWARGS = {
    "dataset_root": "",
    "load_images": False,
    "load_depths": False,
    "load_masks": False,
    "load_depth_masks": False,
    "box_crop": False,
}

DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)


@dataclasses.dataclass
class MagneticFieldAnnotation:
    path: str
    average_flux_density: Optional[float] = None


class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    # a new scalar column appended to the frame annotation table
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)

    # a composite field: the MagneticFieldAnnotation dataclass is mapped
    # onto two underscore-prefixed columns of the same table
    magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
        mapped_column("_magnetic_field_path", nullable=True),
        mapped_column("_magnetic_field_average_flux_density", nullable=True),
        default_factory=lambda: None,
    )


class ExtendedSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation


class CanineFrameData(FrameData):
    num_dogs: Optional[int] = None
    magnetic_field_average_flux_density: Optional[float] = None


@registry.register
class CanineFrameDataBuilder(
    GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
):
    """
    A concrete class to build an extended FrameData object.
    """

    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData

    def build(
        self,
        frame_annotation: ExtendedSqlFrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
        load_blobs: bool = True,
    ) -> CanineFrameData:
        # build the standard FrameData first, then populate the extra
        # fields from the extended annotation
        frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
        frame_data.num_dogs = frame_annotation.num_dogs or 101
        frame_data.magnetic_field_average_flux_density = (
            frame_annotation.magnetic_field.average_flux_density
        )
        return frame_data


class CanineSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation

    frame_data_builder_class_type: str = "CanineFrameDataBuilder"


class TestExtendingOrmTypes(unittest.TestCase):
    def setUp(self):
        # create a temporary copy of the DB with an extended schema
        engine = sa.create_engine(f"sqlite:///{METADATA_FILE}")
        with Session(engine) as session:
            extended_annots = [
                ExtendedSqlFrameAnnotation(
                    **{
                        k: v
                        for k, v in frame_annot.__dict__.items()
                        if not k.startswith("_")  # remove mapped fields and SA metadata
                    }
                )
                for frame_annot in session.scalars(sa.select(SqlFrameAnnotation))
            ]
            seq_annots = session.scalars(
                sa.select(SqlSequenceAnnotation),
                execution_options={"prebuffer_rows": True},
            )
            session.expunge_all()

        self._temp_db = tempfile.NamedTemporaryFile(delete=False)
        engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}")
        ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True)
        with Session(engine_ext, expire_on_commit=False) as session_ext:
            session_ext.add_all(extended_annots)
            for instance in seq_annots:
                session_ext.merge(instance)
            session_ext.commit()

        # check the setup is correct
        with engine_ext.connect() as connection_ext:
            df = pd.read_sql_query(
                sa.select(ExtendedSqlFrameAnnotation), connection_ext
            )
            self.assertEqual(len(df), 100)
            self.assertIn("_magnetic_field_average_flux_density", df.columns)

            df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext)
            self.assertEqual(len(df_seq), 10)

    def tearDown(self):
        self._temp_db.close()
        os.remove(self._temp_db.name)

    def test_basic(self, sequence="cat1_seq2", frame_number=4):
        dataset = ExtendedSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]

    def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4):
        dataset = CanineSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]
            self.assertIsInstance(item, CanineFrameData)
            self.assertEqual(item.num_dogs, 101)
            self.assertIsNone(item.magnetic_field_average_flux_density)

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertIsInstance(item, CanineFrameData)
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)
        self.assertEqual(item.num_dogs, 101)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]
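As context for the composite mapping above: in SQLAlchemy 2.0, composite() presents several columns as a single dataclass-valued attribute, with the column types inferred from the dataclass annotations. A minimal, self-contained sketch, independent of pytorch3d (Point and Vertex are illustrative names):

import dataclasses
from typing import Optional

from sqlalchemy.orm import DeclarativeBase, Mapped, composite, mapped_column


@dataclasses.dataclass
class Point:
    x: Optional[float] = None
    y: Optional[float] = None


class Base(DeclarativeBase):
    pass


class Vertex(Base):
    # hypothetical table, for illustration only
    __tablename__ = "vertex"
    id: Mapped[int] = mapped_column(primary_key=True)
    # composite() packs the two columns into one Point-valued attribute;
    # the column types are derived from the Point dataclass annotations
    location: Mapped[Point] = composite(mapped_column("_x"), mapped_column("_y"))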
@@ -0,0 +1,37 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np

from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory


class TestOrmTypes(unittest.TestCase):
    def test_tuple_serialization_none(self):
        ttype = TupleTypeFactory()()
        output = ttype.process_bind_param(None, None)
        self.assertIsNone(output)
        output = ttype.process_result_value(output, None)
        self.assertIsNone(output)

    def test_tuple_serialization_1d(self):
        for input_tuple in [(1, 2, 3), (4.5, 6.7)]:
            ttype = TupleTypeFactory(type(input_tuple[0]), (len(input_tuple),))()
            output = ttype.process_bind_param(input_tuple, None)
            input_hat = ttype.process_result_value(output, None)
            self.assertEqual(type(input_hat[0]), type(input_tuple[0]))
            np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)

    def test_tuple_serialization_2d(self):
        input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6))
        ttype = TupleTypeFactory(type(input_tuple[0][0]), (2, 3))()
        output = ttype.process_bind_param(input_tuple, None)
        input_hat = ttype.process_result_value(output, None)
        self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
        # we use float32 to serialise
        np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
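For context, the process_bind_param/process_result_value hooks exercised above run automatically on INSERT and SELECT once a TupleTypeFactory-produced type is attached to a mapped column. A sketch under the assumption that the factory takes (dtype, shape) positionally, as the tests above suggest; the table and column names are hypothetical:

from typing import Optional, Tuple

from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleRow(Base):
    # hypothetical table, for illustration only
    __tablename__ = "example_row"
    id: Mapped[int] = mapped_column(primary_key=True)
    # a fixed-shape pair of floats, serialised via the type tested above
    size_hw: Mapped[Optional[Tuple[float, float]]] = mapped_column(
        TupleTypeFactory(float, (2,))
    )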