diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index ef68274e58..c33a040d32 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -37,30 +37,27 @@ def __call__(cls, *args: Any, **kwargs: Any): r"""Converts Tensor-valued fields to DenseContainer under the assumption that said fields house collections of feature vectors.""" hints = get_type_hints(cls) - f_iter = filter( - lambda f: f.init and issubclass(hints[f.name], BotorchContainer), - fields(cls), - ) + f_iter = filter(lambda f: f.init, fields(cls)) f_dict = {} for obj, f in chain( zip(args, f_iter), ((kwargs.pop(f.name, MISSING), f) for f in f_iter) ): if obj is MISSING: if f.default is not MISSING: - obj = f.default elif f.default_factory is not MISSING: obj = f.default_factory() else: raise RuntimeError(f"Missing required field `{f.name}`.") - if isinstance(obj, Tensor): - obj = DenseContainer(obj, event_shape=obj.shape[-1:]) - elif not isinstance(obj, BotorchContainer): - raise TypeError( - f"Expected for field `{f.name}` " - f"but was {type(obj)}." - ) + if issubclass(hints[f.name], BotorchContainer): + if isinstance(obj, Tensor): + obj = DenseContainer(obj, event_shape=obj.shape[-1:]) + elif not isinstance(obj, BotorchContainer): + raise TypeError( + f"Expected for field `{f.name}` " + f"but was {type(obj)}." + ) f_dict[f.name] = obj return super().__call__(**f_dict, **kwargs) diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index 0eb088047c..fd550a2846 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -15,7 +15,7 @@ SupervisedDataset, ) from botorch.utils.testing import BotorchTestCase -from torch import rand, randperm, Size, stack, tensor +from torch import rand, randperm, Size, stack, tensor, Tensor class TestDatasets(BotorchTestCase): @@ -31,6 +31,7 @@ def test_supervised_meta(self): X = rand(3, 2) Y = rand(3, 1) A = DenseContainer(rand(3, 5), event_shape=Size([5])) + B = rand(2, 1) SupervisedDatasetWithDefaults = make_dataclass( cls_name="SupervisedDatasetWithDefaults", @@ -38,6 +39,7 @@ def test_supervised_meta(self): fields=[ ("default", DenseContainer, field(default=A)), ("factory", DenseContainer, field(default_factory=lambda: A)), + ("other", Tensor, field(default_factory=lambda: B)), ], ) @@ -55,13 +57,23 @@ def test_supervised_meta(self): dataset = SupervisedDatasetWithDefaults(X=X, Y=Y) self.assertEqual(dataset.default, A) self.assertEqual(dataset.factory, A) + self.assertTrue(dataset.other is B) # Check type coercion - dataset = SupervisedDatasetWithDefaults(X=X, Y=Y, default=X, factory=Y) + dataset = SupervisedDatasetWithDefaults(X=X, Y=Y, default=X, factory=Y, other=B) self.assertIsInstance(dataset.X, DenseContainer) self.assertIsInstance(dataset.Y, DenseContainer) self.assertEqual(dataset.default, dataset.X) self.assertEqual(dataset.factory, dataset.Y) + self.assertTrue(dataset.other is B) + + # Check handling of positional arguments + dataset = SupervisedDatasetWithDefaults(X, Y, X, Y, X) + self.assertIsInstance(dataset.X, DenseContainer) + self.assertIsInstance(dataset.Y, DenseContainer) + self.assertEqual(dataset.default, dataset.X) + self.assertEqual(dataset.factory, dataset.Y) + self.assertTrue(dataset.other is X) def test_supervised(self): # Generate some data