Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expanding on docstrings in BoTorch Model #1496

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

r"""Abstract base module for all BoTorch models.

Contains `Model`, the abstract base class for all BoTorch models, and
`ModelList`, a container for a list of Models.
This module contains `Model`, the abstract base class for all BoTorch models,
and `ModelList`, a container for a list of Models.
"""

from __future__ import annotations
Expand Down Expand Up @@ -46,8 +46,18 @@
class Model(Module, ABC):
r"""Abstract base class for BoTorch models.

Model cannot be used directly; it only defines an API for other BoTorch
models.
The `Model` base class cannot be used directly; it only defines an API for other
BoTorch models.

`Model` subclasses `torch.nn.Module`. While a `Module` is most typically
encountered as a representation of a neural network layer, it can be used more
generally: see
`documentation <https://pytorch.org/tutorials/beginner/examples_nn/polynomial_module.html>`_
on custom NN Modules.

`Module` provides several pieces of useful functionality: A `Model`'s attributes of
`Tensor` or `Module` type are automatically registered so they can be moved and/or
cast with the `to` method, automatically differentiated, and used with CUDA.

Args:
_has_transformed_inputs: A boolean denoting whether `train_inputs` are currently
Expand All @@ -56,7 +66,7 @@ class Model(Module, ABC):
`_revert_to_original_inputs`. Note that this is necessary since
transform / untransform cycle introduces numerical errors which lead
to upstream errors during training.
"""
""" # noqa: E501

_has_transformed_inputs: bool = False
_original_train_inputs: Optional[Tensor] = None
Expand Down Expand Up @@ -215,7 +225,8 @@ def eval(self) -> Model:
return super().eval()

def train(self, mode: bool = True) -> Model:
r"""Puts the model in `train` mode and reverts to the original inputs.
r"""Put the model in `train` mode. Reverts to the original inputs if in `train`
mode (`mode=True`) or sets transformed inputs if in `eval` mode (`mode=False`).

Args:
mode: A boolean denoting whether to put in `train` or `eval` mode.
Expand Down