Rename *diag() -> *diagonal()
- LinearOperator subclasses implement _diagonal() rather than _diag()
- The public diagonal() method matches the torch.diagonal() API (i.e. it has
    offset, dim1, and dim2 arguments, though it is only implemented for
    offset=0 with dim1 and dim2 referring to the last two dimensions).
- _approx_diag() -> _approx_diagonal()
- add_diag() -> add_diagonal()
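For illustration, a minimal sketch of the renamed public API (not part of the commit; it assumes to_linear_operator is importable from the package root, adjust the import if your install exposes it elsewhere):
>>> import torch
>>> from linear_operator import to_linear_operator  # assumed import path
>>> A = to_linear_operator(2.0 * torch.eye(3))
>>> A.diagonal()                                     # previously A.diag()
>>> A.add_diagonal(torch.tensor(1.0)).diagonal()     # previously A.add_diag(...)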
gpleiss committed May 25, 2022
1 parent 00077a2 commit caead4d
Showing 34 changed files with 225 additions and 234 deletions.
9 changes: 7 additions & 2 deletions linear_operator/functions/_pivoted_cholesky.py
@@ -20,8 +20,13 @@ def forward(ctx, representation_tree, max_iter, error_tol, *matrix_args):
error_tol = settings.preconditioner_tolerance.value()

# Need to get diagonals. This is easy if it's a LinearOperator, since
# LinearOperator.diag() operates in batch mode.
matrix_diag = matrix._approx_diag()
# LinearOperator.diagonal() operates in batch mode.
matrix_diag = matrix._approx_diagonal()
# NOTE: we will be performing in-place operations on matrix_diag
# Calling _approx_diagonal() may return a tensor that shares the same storage as the LinearOperator
# To ensure that we are not mutating any of the LinearOperator's entries, we need to clone
# the diagonal.
matrix_diag = matrix_diag.clone()

# Make sure max_iter isn't bigger than the matrix
max_iter = min(max_iter, matrix_shape[-1])
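
To illustrate why the clone matters, a standalone torch sketch (not library code): torch.Tensor.diagonal() returns a view, so in-place updates would otherwise write through to the original matrix.
>>> import torch
>>> base = torch.ones(3, 3)
>>> diag_view = base.diagonal()             # a view that shares storage with `base`
>>> diag_view.add_(1.0)                     # in-place update also mutates `base`
>>> assert base[0, 0] == 2.0
>>> safe = base.diagonal().clone()          # cloning decouples the storage
>>> safe.add_(1.0)
>>> assert base[0, 0] == 2.0                # `base` is untouched by the second update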
2 changes: 1 addition & 1 deletion linear_operator/functions/_root_decomposition.py
@@ -63,7 +63,7 @@ def forward(
t_mat = t_mat.unsqueeze(0)
n_probes = t_mat.size(0)

mins = to_linear_operator(t_mat).diag().min(dim=-1, keepdim=True)[0].unsqueeze(-1)
mins = to_linear_operator(t_mat)._diagonal().min(dim=-1, keepdim=True)[0].unsqueeze(-1)
jitter_mat = (settings.tridiagonal_jitter.value() * mins) * torch.eye(
t_mat.size(-1), device=t_mat.device, dtype=t_mat.dtype
).expand_as(t_mat)
51 changes: 35 additions & 16 deletions linear_operator/operators/_linear_operator.py
@@ -386,7 +386,7 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator"], ...]:
def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator"], ...]) -> None:
self._args_memo = args

def _approx_diag(self) -> torch.Tensor:
def _approx_diagonal(self) -> torch.Tensor:
"""
(Optional) returns an (approximate) diagonal of the matrix
@@ -398,7 +398,7 @@ def _approx_diag(self) -> torch.Tensor:
:return: the (batch of) diagonals (... x N)
"""
return self.diag()
return self._diagonal()

@cached(name="cholesky")
def _cholesky(self, upper: bool = False) -> "TriangularLinearOperator": # noqa F811
@@ -458,6 +458,19 @@ def _choose_root_method(self) -> str:
return "cholesky"
return "lanczos"

def _diagonal(self) -> torch.Tensor:
r"""
As :func:`torch.diagonal`, returns the diagonal of the matrix
:math:`\mathbf A` this LinearOperator represents as a vector.
.. note::
This method is used as an internal helper. Calling this method directly is discouraged.
:return: The diagonal (or batch of diagonals) of :math:`\mathbf A`.
"""
row_col_iter = torch.arange(0, self.matrix_shape[-1], dtype=torch.long, device=self.device)
return self[..., row_col_iter, row_col_iter]
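
The fallback above gathers the diagonal by advanced indexing with matching row and column indices, which batches naturally; the same pattern on a plain tensor (standalone sketch, not library code):
>>> import torch
>>> A = torch.rand(2, 4, 4)                 # a batch of two 4x4 matrices
>>> idx = torch.arange(4)
>>> assert torch.equal(A[..., idx, idx], A.diagonal(dim1=-2, dim2=-1))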

def _inv_matmul_preconditioner(self) -> Callable:
r"""
(Optional) define a preconditioner :math:`\mathbf P` that can be used for linear systems,
@@ -799,7 +812,7 @@ def add(self, other: Union[torch.Tensor, "LinearOperator"], alpha: float = None)
else:
return self + alpha * other

def add_diag(self, diag: torch.Tensor) -> LinearOperator:
def add_diagonal(self, diag: torch.Tensor) -> LinearOperator:
r"""
Adds an element to the diagonal of the matrix.
@@ -811,7 +824,7 @@ def add_diag(self, diag: torch.Tensor) -> LinearOperator:
from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator

if not self.is_square:
raise RuntimeError("add_diag only defined for square matrices")
raise RuntimeError("add_diagonal only defined for square matrices")

diag_shape = diag.shape
if len(diag_shape) == 0:
@@ -825,7 +838,7 @@ def add_diag(self, diag: torch.Tensor) -> LinearOperator:
expanded_diag = diag.expand(self.shape[:-1])
except RuntimeError:
raise RuntimeError(
"add_diag for LinearOperator of size {} received invalid diagonal of size {}.".format(
"add_diagonal for LinearOperator of size {} received invalid diagonal of size {}.".format(
self.shape, diag_shape
)
)
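
A sketch of the shapes add_diagonal accepts (again assuming to_linear_operator is importable from the package root): a scalar is broadcast across the whole diagonal, while a length-N vector supplies one value per diagonal entry.
>>> import torch
>>> from linear_operator import to_linear_operator  # assumed import path
>>> A = to_linear_operator(torch.eye(4))
>>> A.add_diagonal(torch.tensor(0.5)).diagonal()     # scalar: added to every diagonal entry
>>> A.add_diagonal(torch.arange(4.0)).diagonal()     # length-4 vector: one value per entry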
@@ -837,15 +850,15 @@ def add_jitter(self, jitter_val: float = 1e-3) -> LinearOperator:
r"""
Adds jitter (i.e., a small diagonal component) to the matrix this
LinearOperator represents.
This is equivalent to calling :meth:`~linear_operator.operators.LinearOperator.add_diag`
This is equivalent to calling :meth:`~linear_operator.operators.LinearOperator.add_diagonal`
with a scalar tensor.
:param jitter_val: The diagonal component to add
:return: :math:`\mathbf A + \alpha (\mathbf I)`, where :math:`\mathbf A` is the linear operator
and :math:`\alpha` is :attr:`jitter_val`.
"""
diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
return self.add_diag(diag)
return self.add_diagonal(diag)

def add_low_rank(
self,
@@ -1210,20 +1223,26 @@ def detach_(self) -> LinearOperator:
val.detach_()
return self

# TODO: rename to diagonal
def diag(self) -> torch.Tensor:
def diagonal(self, offset: int = 0, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
r"""
As :func:`torch.diag`, returns the diagonal of the matrix
As :func:`torch.diagonal`, returns the diagonal of the matrix
:math:`\mathbf A` this LinearOperator represents as a vector.
.. note::
This method is only implemented when :attr:`dim1` and :attr:`dim2` are equal
to -2 and -1, respectively, and :attr:`offset = 0`.
:return: The diagonal (or batch of diagonals) of :math:`\mathbf A`.
"""
if settings.debug.on():
if not self.is_square:
raise RuntimeError("Diag works on square matrices (or batches)")

row_col_iter = torch.arange(0, self.matrix_shape[-1], dtype=torch.long, device=self.device)
return self[..., row_col_iter, row_col_iter]
if not (offset == 0 and ((dim1 == -2 and dim2 == -1) or (dim1 == -1 and dim2 == -2))):
raise NotImplementedError(
"LinearOperator#diagonal is only implemented when :attr:`dim1` and :attr:`dim2` are equal "
"to -2 and -1, respectively, and :attr:`offset = 0`. "
f"Got: offset={offset}, dim1={dim1}, dim2={dim2}."
)
elif not self.is_square:
raise RuntimeError("LinearOperator#diagonal is only implemented for square operators.")
return self._diagonal()
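
A sketch of the intended behaviour of the guard above (assuming to_linear_operator is importable from the package root): only the main diagonal over the last two dimensions is supported, anything else raises.
>>> import torch
>>> from linear_operator import to_linear_operator  # assumed import path
>>> A = to_linear_operator(torch.eye(3))
>>> A.diagonal()                                     # supported: offset=0 over the last two dims
>>> A.diagonal(offset=1)                             # raises NotImplementedError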

@cached(name="diagonalization")
def diagonalization(self, method: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
14 changes: 7 additions & 7 deletions linear_operator/operators/added_diag_linear_operator.py
@@ -60,8 +60,8 @@ def __init__(self, *linear_ops, preconditioner_override=None):
def _matmul(self, rhs):
return torch.addcmul(self._linear_op._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)

def add_diag(self, added_diag):
return self.__class__(self._linear_op, self._diag_tensor.add_diag(added_diag))
def add_diagonal(self, added_diag):
return self.__class__(self._linear_op, self._diag_tensor.add_diagonal(added_diag))

def __add__(self, other):
from .diag_linear_operator import DiagLinearOperator
@@ -122,7 +122,7 @@ def precondition_closure(tensor):

def _init_cache(self):
*batch_shape, n, k = self._piv_chol_self.shape
self._noise = self._diag_tensor.diag().unsqueeze(-1)
self._noise = self._diag_tensor._diagonal().unsqueeze(-1)

# the check for constant diag needs to be done carefully for batches.
noise_first_element = self._noise[..., :1, :]
@@ -166,14 +166,14 @@ def _init_cache_for_non_constant_diag(self, eye, batch_shape, n):
def _svd(self) -> Tuple["LinearOperator", Tensor, "LinearOperator"]:
if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
U, S_, V = self._linear_op.svd()
S = S_ + self._diag_tensor.diag()
S = S_ + self._diag_tensor._diagonal()
return U, S, V
return super()._svd()

def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
evals_, evecs = self._linear_op.symeig(eigenvectors=eigenvectors)
evals = evals_ + self._diag_tensor.diag()
evals = evals_ + self._diag_tensor._diagonal()
return evals, evecs
return super()._symeig(eigenvectors=eigenvectors)
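
The ConstantDiagLinearOperator shortcuts above rely on the fact that adding a constant diagonal to a symmetric matrix shifts every eigenvalue by that constant (and, for a PSD matrix, every singular value as well); a quick standalone check, not library code:
>>> import torch
>>> B = torch.rand(4, 4)
>>> A = B @ B.transpose(-1, -2)                      # symmetric PSD matrix
>>> shift = 0.5
>>> evals = torch.linalg.eigvalsh(A)
>>> shifted = torch.linalg.eigvalsh(A + shift * torch.eye(4))
>>> assert torch.allclose(shifted, evals + shift, atol=1e-5)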

@@ -182,8 +182,8 @@ def evaluate_kernel(self):
Overriding this is currently necessary to allow for subclasses of AddedDiagLT to be created. For example,
consider the following:
>>> covar1 = covar_module(x).add_diag(torch.tensor(1.)).evaluate_kernel()
>>> covar2 = covar_module(x).evaluate_kernel().add_diag(torch.tensor(1.))
>>> covar1 = covar_module(x).add_diagonal(torch.tensor(1.)).evaluate_kernel()
>>> covar2 = covar_module(x).evaluate_kernel().add_diagonal(torch.tensor(1.))
Unless we override this method (or find a better solution), covar1 and covar2 might not be the same type.
In particular, covar1 would *always* be a standard AddedDiagLinearOperator, but covar2 might be a subtype.
8 changes: 4 additions & 4 deletions linear_operator/operators/block_diag_linear_operator.py
@@ -50,6 +50,10 @@ def _cholesky_solve(self, rhs, upper: bool = False):
res = self._remove_batch_dim(res)
return res

def _diagonal(self):
res = self.base_linear_op._diagonal().contiguous()
return res.view(*self.batch_shape, self.size(-1))
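
The reshape above works because the diagonal of a block-diagonal matrix is just the blocks' diagonals laid end to end; a standalone check with dense tensors:
>>> import torch
>>> blocks = torch.rand(4, 3, 3)                               # four 3x3 blocks
>>> dense = torch.block_diag(*blocks)                          # 12x12 block-diagonal matrix
>>> per_block = blocks.diagonal(dim1=-2, dim2=-1)              # shape (4, 3)
>>> assert torch.equal(dense.diagonal(), per_block.reshape(-1))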

def _get_indices(self, row_index, col_index, *batch_indices):
# Figure out what block the row/column indices belong to
row_index_block = torch.div(row_index, self.base_linear_op.size(-2), rounding_mode="floor")
@@ -94,10 +98,6 @@ def _solve(self, rhs, preconditioner, num_tridiag=0):
res = self._remove_batch_dim(res)
return res

def diag(self):
res = self.base_linear_op.diag().contiguous()
return res.view(*self.batch_shape, self.size(-1))

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
if inv_quad_rhs is not None:
inv_quad_rhs = self._add_batch_dim(inv_quad_rhs)
@@ -47,6 +47,10 @@ def _cholesky_solve(self, rhs, upper: bool = False):
res = self._remove_batch_dim(res)
return res

def _diagonal(self):
block_diag = self.base_linear_op._diagonal()
return block_diag.transpose(-1, -2).contiguous().view(*block_diag.shape[:-2], -1)
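
The transpose above interleaves the blocks' diagonals (first entry of every block, then second entry of every block, and so on) rather than concatenating them block by block; a small standalone illustration:
>>> import torch
>>> block_diags = torch.tensor([[10., 11., 12.], [20., 21., 22.]])  # diagonals of two 3x3 blocks
>>> interleaved = block_diags.transpose(-1, -2).contiguous().view(-1)
>>> assert torch.equal(interleaved, torch.tensor([10., 20., 11., 21., 12., 22.]))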

def _get_indices(self, row_index, col_index, *batch_indices):
# Figure out what block the row/column indices belong to
row_index_block = row_index.fmod(self.base_linear_op.size(-3))
@@ -92,10 +96,6 @@ def _solve(self, rhs, preconditioner, num_tridiag=0):
res = self._remove_batch_dim(res)
return res

def diag(self):
block_diag = self.base_linear_op.diag()
return block_diag.transpose(-1, -2).contiguous().view(*block_diag.shape[:-2], -1)

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
if inv_quad_rhs is not None:
inv_quad_rhs = self._add_batch_dim(inv_quad_rhs)
55 changes: 25 additions & 30 deletions linear_operator/operators/cat_linear_operator.py
@@ -2,7 +2,6 @@

import torch

from .. import settings
from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper
from ..utils.deprecation import bool_compat
from ..utils.getitem import _noop_index
@@ -125,6 +124,31 @@ def _split_slice(self, slice_idx):
[first_slice] + [_noop_index] * num_middle_tensors + [last_slice],
)

def _diagonal(self):
if self.cat_dim == -2:
res = []
curr_col = 0
for t in self.linear_ops:
n_rows, n_cols = t.shape[-2:]
rows = torch.arange(0, n_rows, dtype=torch.long, device=t.device)
cols = torch.arange(curr_col, curr_col + n_rows, dtype=torch.long, device=t.device)
res.append(t[..., rows, cols].to(self.device))
curr_col += n_rows
res = torch.cat(res, dim=-1)
elif self.cat_dim == -1:
res = []
curr_row = 0
for t in self.linear_ops:
n_rows, n_cols = t.shape[-2:]
rows = torch.arange(curr_row, curr_row + n_cols, dtype=torch.long, device=t.device)
cols = torch.arange(0, n_cols, dtype=torch.long, device=t.device)
curr_row += n_cols
res.append(t[..., rows, cols].to(self.device))
res = torch.cat(res, dim=-1)
else:
res = torch.cat([t._diagonal().to(self.device) for t in self.linear_ops], dim=self.cat_dim + 1)
return res
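
For the cat_dim == -2 branch, each piece contributes the diagonal entries that fall in its rows, offset by the rows already consumed; a standalone sketch with dense tensors:
>>> import torch
>>> top, bottom = torch.rand(2, 5), torch.rand(3, 5)           # stacked along the row dimension
>>> full = torch.cat([top, bottom], dim=-2)                    # 5x5 result
>>> part1 = top[torch.arange(2), torch.arange(0, 2)]
>>> part2 = bottom[torch.arange(3), torch.arange(2, 5)]
>>> assert torch.equal(full.diagonal(), torch.cat([part1, part2]))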

def _expand_batch(self, batch_shape):
batch_dim = self.cat_dim + 2
if batch_dim < 0:
@@ -317,35 +341,6 @@ def _unsqueeze_batch(self, dim):
)
return res

def diag(self):
if settings.debug.on():
if not self.is_square:
raise RuntimeError("Diag works on square matrices (or batches)")

if self.cat_dim == -2:
res = []
curr_col = 0
for t in self.linear_ops:
n_rows, n_cols = t.shape[-2:]
rows = torch.arange(0, n_rows, dtype=torch.long, device=t.device)
cols = torch.arange(curr_col, curr_col + n_rows, dtype=torch.long, device=t.device)
res.append(t[..., rows, cols].to(self.device))
curr_col += n_rows
res = torch.cat(res, dim=-1)
elif self.cat_dim == -1:
res = []
curr_row = 0
for t in self.linear_ops:
n_rows, n_cols = t.shape[-2:]
rows = torch.arange(curr_row, curr_row + n_cols, dtype=torch.long, device=t.device)
cols = torch.arange(0, n_cols, dtype=torch.long, device=t.device)
curr_row += n_cols
res.append(t[..., rows, cols].to(self.device))
res = torch.cat(res, dim=-1)
else:
res = torch.cat([t.diag().to(self.device) for t in self.linear_ops], dim=self.cat_dim + 1)
return res

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad)
return tuple(r.to(self.device) for r in res)
12 changes: 6 additions & 6 deletions linear_operator/operators/chol_linear_operator.py
@@ -29,7 +29,7 @@ def __init__(self, chol: _TriangularLinearOperatorBase, upper: bool = False):

@property
def _chol_diag(self):
return self.root.diag()
return self.root._diagonal()

@cached(name="cholesky")
def _cholesky(self, upper=False):
@@ -38,16 +38,16 @@ def _cholesky(self, upper=False):
else:
return self.root._transpose_nonbatch()

@cached
def _diagonal(self):
# TODO: Can we be smarter here?
return (self.root.to_dense() ** 2).sum(-1)
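
The sum of squares works because the i-th diagonal entry of L L^T is the squared norm of the i-th row of L; a quick standalone check with a lower-triangular root:
>>> import torch
>>> L = torch.rand(4, 4).tril()                                # lower-triangular root
>>> A = L @ L.transpose(-1, -2)
>>> assert torch.allclose(A.diagonal(), (L ** 2).sum(-1))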

def _solve(self, rhs, preconditioner, num_tridiag=0):
if num_tridiag:
return super()._solve(rhs, preconditioner, num_tridiag=num_tridiag)
return self.root._cholesky_solve(rhs, upper=self.upper)

@cached
def diag(self):
# TODO: Can we be smarter here?
return (self.root.to_dense() ** 2).sum(-1)

@cached
def to_dense(self):
root = self.root
12 changes: 6 additions & 6 deletions linear_operator/operators/constant_mul_linear_operator.py
@@ -65,8 +65,12 @@ def __init__(self, base_linear_op, constant):
self.base_linear_op = base_linear_op
self._constant = constant

def _approx_diag(self):
res = self.base_linear_op._approx_diag()
def _approx_diagonal(self):
res = self.base_linear_op._approx_diagonal()
return res * self._constant.unsqueeze(-1)

def _diagonal(self):
res = self.base_linear_op._diagonal()
return res * self._constant.unsqueeze(-1)
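
The unsqueeze is needed because the constant carries one value per batch element while the diagonal carries N values per batch element; a standalone sketch of the broadcasting:
>>> import torch
>>> A = torch.rand(2, 3, 3)                                    # batch of two matrices
>>> c = torch.tensor([2.0, 3.0])                               # one constant per batch element
>>> scaled = A.diagonal(dim1=-2, dim2=-1) * c.unsqueeze(-1)    # shape (2, 3)
>>> assert torch.allclose(scaled, (c.view(2, 1, 1) * A).diagonal(dim1=-2, dim2=-1))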

def _expand_batch(self, batch_shape):
@@ -147,10 +151,6 @@ def expanded_constant(self):

return constant

def diag(self):
res = self.base_linear_op.diag()
return res * self._constant.unsqueeze(-1)

@cached
def to_dense(self):
res = self.base_linear_op.to_dense()
10 changes: 3 additions & 7 deletions linear_operator/operators/dense_linear_operator.py
@@ -27,6 +27,9 @@ def __init__(self, tsr):
def _cholesky_solve(self, rhs, upper=False):
return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)

def _diagonal(self):
return self.tensor.diagonal(dim1=-1, dim2=-2)
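
torch.Tensor.diagonal already handles arbitrary batch dimensions, which is why the old manual indexing below could be dropped; for offset 0 the order of dim1 and dim2 does not change the result (standalone sketch):
>>> import torch
>>> t = torch.rand(2, 5, 5)
>>> assert torch.equal(t.diagonal(dim1=-1, dim2=-2), t.diagonal(dim1=-2, dim2=-1))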

def _expand_batch(self, batch_shape):
return self.__class__(self.tensor.expand(*batch_shape, *self.matrix_shape))

@@ -62,13 +65,6 @@ def _transpose_nonbatch(self):
def _t_matmul(self, rhs):
return torch.matmul(self.tensor.transpose(-1, -2), rhs)

def diag(self):
if self.tensor.ndimension() < 3:
return self.tensor.diag()
else:
row_col_iter = torch.arange(0, self.matrix_shape[-1], dtype=torch.long, device=self.device)
return self.tensor[..., row_col_iter, row_col_iter].view(*self.batch_shape, -1)

def to_dense(self):
return self.tensor
