
[Bug] Using batched_to_model_list on a model with AppendFeatures input transform #1273

Closed
benmltu opened this issue Jun 29, 2022 · 3 comments
Labels
bug Something isn't working

Comments


benmltu commented Jun 29, 2022

🐛 Bug

Hello there, I ran into an issue when trying to sample from GPs that use the AppendFeatures input transform. The issue arises when calling batched_to_model_list.

To reproduce

# System info:
# botorch 0.6.4
# gpytorch 1.6.0
# pytorch 1.11.0

from botorch.models.gp_regression import SingleTaskGP
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_model
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.utils.sampling import draw_sobol_samples, draw_sobol_normal_samples
from botorch.models.transforms.input import AppendFeatures
from botorch.models.converter import batched_to_model_list

def fit_model(tx, ty, intf):
    model = SingleTaskGP(tx, ty, input_transform=intf)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)
    
    return model

problem = BraninCurrin(negate=True)
bounds = problem.bounds
d = problem.dim
m = problem.num_objectives
n_w = 5
n = 10

train_X = draw_sobol_samples(bounds=bounds, n=n, q=1, seed=123).squeeze(-2)
train_Y = problem(train_X)
perturbation_set = draw_sobol_normal_samples(d=1, n=n_w)
input_transform = AppendFeatures(feature_set=perturbation_set)

model = fit_model(train_X, train_Y, input_transform)
model_list = batched_to_model_list(model)  # raises the RuntimeError below

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [5], in <cell line: 1>()
----> 1 model_list = batched_to_model_list(model)

File C:\ProgramData\Anaconda3\envs\scalarize\lib\site-packages\botorch\models\converter.py:274, in batched_to_model_list(batch_model)
    272         kwargs["outcome_transform"] = None
    273     model = batch_model.__class__(input_transform=input_transform, **kwargs)
--> 274     model.load_state_dict(sd)
    275     models.append(model)
    277 return ModelListGP(*models)

File C:\ProgramData\Anaconda3\envs\scalarize\lib\site-packages\torch\nn\modules\module.py:1497, in Module.load_state_dict(self, state_dict, strict)
   1492         error_msgs.insert(
   1493             0, 'Missing key(s) in state_dict: {}. '.format(
   1494                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1496 if len(error_msgs) > 0:
-> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1498                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SingleTaskGP:
	size mismatch for covar_module.base_kernel.raw_lengthscale: copying a param with shape torch.Size([1, 2]) from checkpoint, the shape in current model is torch.Size([1, 3]).

Something about this particular input_transform appears to change after training:

input_transform = AppendFeatures(feature_set=perturbation_set)

m1 = SingleTaskGP(train_X, train_Y[:, 0].unsqueeze(-1), input_transform=input_transform)
print("raw_lengthscale={}".format(m1.covar_module.base_kernel.raw_lengthscale))

m2 = fit_model(train_X, train_Y[:, 0].unsqueeze(-1), input_transform)

m3 = SingleTaskGP(train_X, train_Y[:, 0].unsqueeze(-1), input_transform=input_transform)
print("raw_lengthscale={}".format(m3.covar_module.base_kernel.raw_lengthscale))

raw_lengthscale=Parameter containing:
tensor([[0., 0.]], requires_grad=True)

raw_lengthscale=Parameter containing:
tensor([[0., 0., 0.]], requires_grad=True)

I have checked that this error does not occur when using InputPerturbation instead.
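For reference, the dimension growth can be sketched in plain torch (this mimics, but is not, BoTorch's AppendFeatures implementation; shapes follow the repro above with n=10, d=2, n_w=5):

```python
import torch

# Sketch of AppendFeatures' eval-mode behavior: every input row is repeated
# once per feature-set row and the feature column is appended, so the input
# dimension grows from d=2 to d=3. A kernel constructed against these
# transformed inputs gets a 3-dimensional lengthscale, hence the size mismatch.
X = torch.rand(10, 2)            # n=10 training points, d=2
feature_set = torch.rand(5, 1)   # n_w=5 appended features

expanded = X.unsqueeze(-2).expand(-1, 5, -1)        # (10, 5, 2)
features = feature_set.expand(10, -1, -1)           # (10, 5, 1)
appended = torch.cat([expanded, features], dim=-1)  # (10, 5, 3)
flat = appended.reshape(-1, 3)                      # (50, 3)
print(flat.shape)  # torch.Size([50, 3])
```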

@benmltu benmltu added the bug Something isn't working label Jun 29, 2022
@saitcakmak commented:

Hi @benmltu! Thanks for flagging the issue and the detailed repro. I haven't yet looked into the root cause, but this works if the model is in train mode, i.e., the following should work:

model.train()
model_list = batched_to_model_list(model)

Putting the model in train mode effectively disables the AppendFeatures transform. The transform probably comes into effect somewhere in the batched_to_model_list call when the model is in eval mode and spuriously adds an extra input dimension to the hyperparameters, which would explain the shape mismatch.
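The train/eval gating can be illustrated with a toy stand-in (not BoTorch's actual class) that mirrors AppendFeatures' default of transform_on_train=False:

```python
import torch
from torch import nn

class ToyAppendFeatures(nn.Module):
    """Toy stand-in mimicking AppendFeatures' transform_on_train gating."""

    def __init__(self, feature_set, transform_on_train=False):
        super().__init__()
        self.feature_set = feature_set
        self.transform_on_train = transform_on_train

    def forward(self, X):
        # In train mode the transform is skipped (default transform_on_train=False),
        # so the inputs pass through unchanged.
        if self.training and not self.transform_on_train:
            return X
        n_w = self.feature_set.shape[0]
        expanded = X.unsqueeze(-2).expand(-1, n_w, -1)
        features = self.feature_set.expand(X.shape[0], -1, -1)
        return torch.cat([expanded, features], dim=-1).reshape(
            -1, X.shape[-1] + self.feature_set.shape[-1]
        )

tf = ToyAppendFeatures(torch.rand(5, 1))
X = torch.rand(10, 2)
tf.train()
print(tf(X).shape)  # torch.Size([10, 2]) -- transform skipped
tf.eval()
print(tf(X).shape)  # torch.Size([50, 3]) -- features appended
```

This is why `model.train()` before `batched_to_model_list(model)` sidesteps the error: the transformed training inputs keep their original dimension.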

@saitcakmak commented:

Also, I am curious to hear about your use case for the AppendFeatures transform (if you don't mind sharing).

@saitcakmak saitcakmak self-assigned this Jun 30, 2022
saitcakmak added a commit to saitcakmak/botorch that referenced this issue Jun 30, 2022
Summary:
Fixes pytorch#1273

During model construction, input transforms should be in `train` mode (so that they only apply if `transform_on_train` is true).
Having the input transforms in eval mode leads to buggy behavior due to `transformed_X` getting transformed when it shouldn't.

Differential Revision: D37542474

fbshipit-source-id: f871ca15743030d6ccdbb251bb52ab9fe5f62333

benmltu commented Jun 30, 2022

Thanks for the fix!

I was just looking at some of your recent work on risk measures and multi-objective BO and thought it would be cool to test these transformations in order to improve my own understanding.

saitcakmak added a commit to saitcakmak/botorch that referenced this issue Jun 30, 2022
Summary:
Pull Request resolved: pytorch#1283

Fixes pytorch#1273

During model construction, input transforms should be in `train` mode (so that they only apply if `transform_on_train` is true).
Having the input transforms in eval mode leads to buggy behavior due to `transformed_X` getting transformed when it shouldn't.

Differential Revision: D37542474

fbshipit-source-id: f4278294de5d83d967f3d21c312370e562cf372c