Skip to content

Commit

Permalink
Add kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Scienfitz committed May 17, 2024
1 parent c353728 commit d3b4423
Showing 1 changed file with 174 additions and 0 deletions.
174 changes: 174 additions & 0 deletions baybe/kernels/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,62 @@
from baybe.utils.validation import finite_float


@define(frozen=True)
class CosineKernel(Kernel):
    """A cosine kernel."""

    period_length_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel periodic length."""

    period_length_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel periodic length."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)

        # Overwrite the default starting point of the hyperparameter
        # optimization if the user provided an explicit value.
        if self.period_length_initial_value is not None:
            kernel.period_length = torch.tensor(
                self.period_length_initial_value, dtype=DTypeFloatTorch
            )
        return kernel


@define(frozen=True)
class LinearKernel(Kernel):
    """A linear kernel."""

    variance_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the variance parameter."""

    variance_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the variance parameter."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)

        # Overwrite the default starting point of the hyperparameter
        # optimization if the user provided an explicit value.
        if self.variance_initial_value is not None:
            kernel.variance = torch.tensor(
                self.variance_initial_value, dtype=DTypeFloatTorch
            )
        return kernel


@define(frozen=True)
class MaternKernel(Kernel):
"""A Matern kernel using a smoothness parameter."""
Expand All @@ -34,3 +90,121 @@ class MaternKernel(Kernel):
default=None, converter=optional_c(float), validator=optional_v(finite_float)
)
"""An optional initial value for the kernel lengthscale."""


@define(frozen=True)
class PeriodicKernel(Kernel):
    """A periodic kernel."""

    lengthscale_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel lengthscale."""

    lengthscale_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel lengthscale."""

    period_length_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel periodic length."""

    period_length_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel periodic length."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)

        # The lengthscale initial value is already applied by the base class;
        # only the period length requires manual handling here.
        if self.period_length_initial_value is not None:
            kernel.period_length = torch.tensor(
                self.period_length_initial_value, dtype=DTypeFloatTorch
            )
        return kernel


@define(frozen=True)
class PiecewisePolynomialKernel(Kernel):
    """A piecewise polynomial kernel."""

    # Annotated as ``int`` (not ``float``): the converter coerces to int and
    # the validator restricts the value to the discrete set {0, 1, 2, 3}.
    q: int = field(converter=int, validator=in_([0, 1, 2, 3]), default=2)
    """A smoothness parameter."""

    lengthscale_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel lengthscale."""

    lengthscale_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel lengthscale."""


@define(frozen=True)
class PolynomialKernel(Kernel):
    """A polynomial kernel."""

    # Annotated as ``int`` (not ``float``): the converter coerces every
    # input to an integer, so the stored value is always an int.
    power: int = field(converter=int, default=2)
    """The power of the polynomial term."""

    offset_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel offset."""

    offset_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel offset."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        gpytorch_kernel = super().to_gpytorch(*args, **kwargs)

        # Overwrite the default starting point of the hyperparameter
        # optimization if the user provided an explicit value.
        if (initial_value := self.offset_initial_value) is not None:
            gpytorch_kernel.offset = torch.tensor(initial_value, dtype=DTypeFloatTorch)
        return gpytorch_kernel


@define(frozen=True)
class RBFKernel(Kernel):
    """A radial basis function (RBF) kernel.

    Only the lengthscale settings are configurable. Unlike some sibling
    kernels, this class does not override ``to_gpytorch``, so translation to
    the backing GPyTorch kernel is handled entirely by the ``Kernel`` base
    class (presumably mapping the fields below by name — confirm there).
    """

    lengthscale_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel lengthscale."""

    lengthscale_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel lengthscale."""


@define(frozen=True)
class RQKernel(Kernel):
    """A rational quadratic (RQ) kernel.

    Only the lengthscale settings are configurable. Unlike some sibling
    kernels, this class does not override ``to_gpytorch``, so translation to
    the backing GPyTorch kernel is handled entirely by the ``Kernel`` base
    class (presumably mapping the fields below by name — confirm there).
    """

    lengthscale_prior: Optional[Prior] = field(
        default=None, validator=optional_v(instance_of(Prior))
    )
    """An optional prior on the kernel lengthscale."""

    lengthscale_initial_value: Optional[float] = field(
        default=None, converter=optional_c(float), validator=optional_v(finite_float)
    )
    """An optional initial value for the kernel lengthscale."""

0 comments on commit d3b4423

Please sign in to comment.