Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of non-Container-typed positional arguments in SupervisedDatasetMeta #1663

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,27 @@ def __call__(cls, *args: Any, **kwargs: Any):
r"""Converts Tensor-valued fields to DenseContainer under the assumption
that said fields house collections of feature vectors."""
hints = get_type_hints(cls)
f_iter = filter(
lambda f: f.init and issubclass(hints[f.name], BotorchContainer),
fields(cls),
)
f_iter = filter(lambda f: f.init, fields(cls))
f_dict = {}
for obj, f in chain(
zip(args, f_iter), ((kwargs.pop(f.name, MISSING), f) for f in f_iter)
):
if obj is MISSING:
if f.default is not MISSING:

obj = f.default
elif f.default_factory is not MISSING:
obj = f.default_factory()
else:
raise RuntimeError(f"Missing required field `{f.name}`.")

if isinstance(obj, Tensor):
obj = DenseContainer(obj, event_shape=obj.shape[-1:])
elif not isinstance(obj, BotorchContainer):
raise TypeError(
f"Expected <BotorchContainer | Tensor> for field `{f.name}` "
f"but was {type(obj)}."
)
if issubclass(hints[f.name], BotorchContainer):
if isinstance(obj, Tensor):
obj = DenseContainer(obj, event_shape=obj.shape[-1:])
elif not isinstance(obj, BotorchContainer):
raise TypeError(
f"Expected <BotorchContainer | Tensor> for field `{f.name}` "
f"but was {type(obj)}."
)
f_dict[f.name] = obj

return super().__call__(**f_dict, **kwargs)
Expand Down
16 changes: 14 additions & 2 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SupervisedDataset,
)
from botorch.utils.testing import BotorchTestCase
from torch import rand, randperm, Size, stack, tensor
from torch import rand, randperm, Size, stack, tensor, Tensor


class TestDatasets(BotorchTestCase):
Expand All @@ -31,13 +31,15 @@ def test_supervised_meta(self):
X = rand(3, 2)
Y = rand(3, 1)
A = DenseContainer(rand(3, 5), event_shape=Size([5]))
B = rand(2, 1)

SupervisedDatasetWithDefaults = make_dataclass(
cls_name="SupervisedDatasetWithDefaults",
bases=(SupervisedDataset,),
fields=[
("default", DenseContainer, field(default=A)),
("factory", DenseContainer, field(default_factory=lambda: A)),
("other", Tensor, field(default_factory=lambda: B)),
],
)

Expand All @@ -55,13 +57,23 @@ def test_supervised_meta(self):
dataset = SupervisedDatasetWithDefaults(X=X, Y=Y)
self.assertEqual(dataset.default, A)
self.assertEqual(dataset.factory, A)
self.assertTrue(dataset.other is B)

# Check type coercion
dataset = SupervisedDatasetWithDefaults(X=X, Y=Y, default=X, factory=Y)
dataset = SupervisedDatasetWithDefaults(X=X, Y=Y, default=X, factory=Y, other=B)
self.assertIsInstance(dataset.X, DenseContainer)
self.assertIsInstance(dataset.Y, DenseContainer)
self.assertEqual(dataset.default, dataset.X)
self.assertEqual(dataset.factory, dataset.Y)
self.assertTrue(dataset.other is B)

# Check handling of positional arguments
dataset = SupervisedDatasetWithDefaults(X, Y, X, Y, X)
self.assertIsInstance(dataset.X, DenseContainer)
self.assertIsInstance(dataset.Y, DenseContainer)
self.assertEqual(dataset.default, dataset.X)
self.assertEqual(dataset.factory, dataset.Y)
self.assertTrue(dataset.other is X)

def test_supervised(self):
# Generate some data
Expand Down