From 3e3644e5344ed89550188a511f7bb3937fe75d36 Mon Sep 17 00:00:00 2001
From: Roman Shapovalov
Date: Thu, 4 May 2023 03:32:27 -0700
Subject: [PATCH] More tests for SQL Dataset

Summary:
I forgot to include these tests in D45086611 when transferring code from the
pixar_replay repo. They test the new ORM types used in the SQL dataset and are
SQLAlchemy 2.0 specific.

The test for extending the types is a proof of concept for the generality of
SQL Dataset: the idea is to extend FrameAnnotation and FrameData in parallel.

Reviewed By: bottler

Differential Revision: D45529284

fbshipit-source-id: 2a634e518f580c312602107c85fc320db43abcf5
---
 tests/implicitron/test_extending_orm_types.py | 230 ++++++++++++++++++
 tests/implicitron/test_orm_types.py           |  37 +++
 2 files changed, 267 insertions(+)
 create mode 100644 tests/implicitron/test_extending_orm_types.py
 create mode 100644 tests/implicitron/test_orm_types.py

diff --git a/tests/implicitron/test_extending_orm_types.py b/tests/implicitron/test_extending_orm_types.py
new file mode 100644
index 000000000..abfb3aabf
--- /dev/null
+++ b/tests/implicitron/test_extending_orm_types.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+import logging
+import os
+import tempfile
+import unittest
+from typing import ClassVar, Optional, Type
+
+import pandas as pd
+import pkg_resources
+import sqlalchemy as sa
+
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
+from pytorch3d.implicitron.dataset.orm_types import (
+    SqlFrameAnnotation,
+    SqlSequenceAnnotation,
+)
+from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
+from pytorch3d.implicitron.dataset.utils import GenericWorkaround
+from pytorch3d.implicitron.tools.config import registry
+from sqlalchemy.orm import composite, Mapped, mapped_column, Session
+
+NO_BLOBS_KWARGS = {
+    "dataset_root": "",
+    "load_images": False,
+    "load_depths": False,
+    "load_masks": False,
+    "load_depth_masks": False,
+    "box_crop": False,
+}
+
+DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
+METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
+
+logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
+sh = logging.StreamHandler()
+logger.addHandler(sh)
+logger.setLevel(logging.DEBUG)
+
+
+@dataclasses.dataclass
+class MagneticFieldAnnotation:
+    path: str
+    average_flux_density: Optional[float] = None
+
+
+class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
+    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)
+
+    magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
+        mapped_column("_magnetic_field_path", nullable=True),
+        mapped_column("_magnetic_field_average_flux_density", nullable=True),
+        default_factory=lambda: None,
+    )
+
+
+class ExtendedSqlIndexDataset(SqlIndexDataset):
+    frame_annotations_type: ClassVar[
+        Type[SqlFrameAnnotation]
+    ] = ExtendedSqlFrameAnnotation
+
+
+class CanineFrameData(FrameData):
+    num_dogs: Optional[int] = None
+    magnetic_field_average_flux_density: Optional[float] = None
+
+
+@registry.register
+class CanineFrameDataBuilder(
+    GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
+):
+    """
+    A concrete class to build an extended FrameData object
+    """
+
+    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData
+
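+    # NOTE: build() is the extension hook: it delegates to the generic
+    # builder first, then fills the custom fields (num_dogs, defaulting to
+    # 101 when unset, and the magnetic-field flux density) from the
+    # extended annotation.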
+ def build( + self, + frame_annotation: ExtendedSqlFrameAnnotation, + sequence_annotation: types.SequenceAnnotation, + load_blobs: bool = True, + ) -> CanineFrameData: + frame_data = super().build(frame_annotation, sequence_annotation, load_blobs) + frame_data.num_dogs = frame_annotation.num_dogs or 101 + frame_data.magnetic_field_average_flux_density = ( + frame_annotation.magnetic_field.average_flux_density + ) + return frame_data + + +class CanineSqlIndexDataset(SqlIndexDataset): + frame_annotations_type: ClassVar[ + Type[SqlFrameAnnotation] + ] = ExtendedSqlFrameAnnotation + + frame_data_builder_class_type: str = "CanineFrameDataBuilder" + + +class TestExtendingOrmTypes(unittest.TestCase): + def setUp(self): + # create a temporary copy of the DB with an extended schema + engine = sa.create_engine(f"sqlite:///{METADATA_FILE}") + with Session(engine) as session: + extended_annots = [ + ExtendedSqlFrameAnnotation( + **{ + k: v + for k, v in frame_annot.__dict__.items() + if not k.startswith("_") # remove mapped fields and SA metadata + } + ) + for frame_annot in session.scalars(sa.select(SqlFrameAnnotation)) + ] + seq_annots = session.scalars( + sa.select(SqlSequenceAnnotation), + execution_options={"prebuffer_rows": True}, + ) + session.expunge_all() + + self._temp_db = tempfile.NamedTemporaryFile(delete=False) + engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}") + ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True) + with Session(engine_ext, expire_on_commit=False) as session_ext: + session_ext.add_all(extended_annots) + for instance in seq_annots: + session_ext.merge(instance) + session_ext.commit() + + # check the setup is correct + with engine_ext.connect() as connection_ext: + df = pd.read_sql_query( + sa.select(ExtendedSqlFrameAnnotation), connection_ext + ) + self.assertEqual(len(df), 100) + self.assertIn("_magnetic_field_average_flux_density", df.columns) + + df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext) + self.assertEqual(len(df_seq), 10) + + def tearDown(self): + self._temp_db.close() + os.remove(self._temp_db.name) + + def test_basic(self, sequence="cat1_seq2", frame_number=4): + dataset = ExtendedSqlIndexDataset( + sqlite_metadata_file=self._temp_db.name, + remove_empty_masks=False, + frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 100) + + # check the items are consecutive + past_sequences = set() + last_frame_number = -1 + last_sequence = "" + for i in range(len(dataset)): + item = dataset[i] + + if item.frame_number == 0: + self.assertNotIn(item.sequence_name, past_sequences) + past_sequences.add(item.sequence_name) + last_sequence = item.sequence_name + else: + self.assertEqual(item.sequence_name, last_sequence) + self.assertEqual(item.frame_number, last_frame_number + 1) + + last_frame_number = item.frame_number + + # test indexing + with self.assertRaises(IndexError): + dataset[len(dataset) + 1] + + # test sequence-frame indexing + item = dataset[sequence, frame_number] + self.assertEqual(item.sequence_name, sequence) + self.assertEqual(item.frame_number, frame_number) + + with self.assertRaises(IndexError): + dataset[sequence, 13] + + def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4): + dataset = CanineSqlIndexDataset( + sqlite_metadata_file=self._temp_db.name, + remove_empty_masks=False, + frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS, + ) + + self.assertEqual(len(dataset), 100) + + # check the items are 
consecutive + past_sequences = set() + last_frame_number = -1 + last_sequence = "" + for i in range(len(dataset)): + item = dataset[i] + self.assertIsInstance(item, CanineFrameData) + self.assertEqual(item.num_dogs, 101) + self.assertIsNone(item.magnetic_field_average_flux_density) + + if item.frame_number == 0: + self.assertNotIn(item.sequence_name, past_sequences) + past_sequences.add(item.sequence_name) + last_sequence = item.sequence_name + else: + self.assertEqual(item.sequence_name, last_sequence) + self.assertEqual(item.frame_number, last_frame_number + 1) + + last_frame_number = item.frame_number + + # test indexing + with self.assertRaises(IndexError): + dataset[len(dataset) + 1] + + # test sequence-frame indexing + item = dataset[sequence, frame_number] + self.assertIsInstance(item, CanineFrameData) + self.assertEqual(item.sequence_name, sequence) + self.assertEqual(item.frame_number, frame_number) + self.assertEqual(item.num_dogs, 101) + + with self.assertRaises(IndexError): + dataset[sequence, 13] diff --git a/tests/implicitron/test_orm_types.py b/tests/implicitron/test_orm_types.py new file mode 100644 index 000000000..7570b002b --- /dev/null +++ b/tests/implicitron/test_orm_types.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import numpy as np + +from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory + + +class TestOrmTypes(unittest.TestCase): + def test_tuple_serialization_none(self): + ttype = TupleTypeFactory()() + output = ttype.process_bind_param(None, None) + self.assertIsNone(output) + output = ttype.process_result_value(output, None) + self.assertIsNone(output) + + def test_tuple_serialization_1d(self): + for input_tuple in [(1, 2, 3), (4.5, 6.7)]: + ttype = TupleTypeFactory(type(input_tuple[0]), (len(input_tuple),))() + output = ttype.process_bind_param(input_tuple, None) + input_hat = ttype.process_result_value(output, None) + self.assertEqual(type(input_hat[0]), type(input_tuple[0])) + np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6) + + def test_tuple_serialization_2d(self): + input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6)) + ttype = TupleTypeFactory(type(input_tuple[0][0]), (2, 3))() + output = ttype.process_bind_param(input_tuple, None) + input_hat = ttype.process_result_value(output, None) + self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0])) + # we use float32 to serialise + np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
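
Usage note: a minimal sketch of how the extended dataset defined above might
be consumed, assuming a database with the extended schema already exists and
that the repo root is on the import path. The constructor arguments mirror
the ones used in test_extending_frame_data; "extended_annotations.sqlite" is
a hypothetical path.

    # Hypothetical usage of the classes added by this patch.
    from tests.implicitron.test_extending_orm_types import (
        CanineSqlIndexDataset,
        NO_BLOBS_KWARGS,
    )

    dataset = CanineSqlIndexDataset(
        sqlite_metadata_file="extended_annotations.sqlite",  # hypothetical path
        remove_empty_masks=False,
        frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    # Sequence-frame indexing, as exercised in the tests.
    frame = dataset["cat1_seq2", 4]
    assert frame.num_dogs == 101  # default filled in by CanineFrameDataBuilder.build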