Performant Scaling of BlockDiagLinearOperator by DiagLinearOperator #14

Merged
7 changes: 1 addition & 6 deletions linear_operator/operators/_linear_operator.py
@@ -559,12 +559,7 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOpe
if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
Collaborator (Author):

This line was the source of the performance regressions. The reasoning behind it appears to be that MulLinearOperator always computes a root decomposition, which is both inefficient and introduces dead code into its implementation (see below). I am sidestepping this by replacing the * with @ operations in the new special cases of the DiagLinearOperator and BlockDiagLinearOperator matmul methods, so that we end up with MatmulLinearOperators instead of MulLinearOperators.

However, this does not get rid of the more general issue. To fix that, I propose two steps in a future PR:

  1. Introduce logic in the constructor of MulLinearOperator that decides whether or not to build a root decomposition.
  2. Even if a root decomposition seems beneficial, delay its computation until the moment it is actually needed in matmul, and cache the result. This gives us ~0 overhead when the linear operator represents a posterior covariance matrix that is constructed via a posterior call but only the posterior mean is needed, as was the case in the notebook that exhibited the regression. (A rough sketch of this pattern follows below.)
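
To make step 2 concrete, here is a rough sketch of the deferred-and-cached pattern. The class, the size threshold, and the attribute names are hypothetical and not part of linear_operator; it assumes root_decomposition() returns a RootLinearOperator exposing its root, as the existing constructor code does.

```python
# Rough sketch only -- class, threshold heuristic, and attribute names are hypothetical.
from functools import cached_property


class LazyMulSketch:
    """Defer the root decomposition until a matmul actually needs it."""

    def __init__(self, left, right, size_threshold=128):
        # Step 1: decide up front whether a root decomposition looks worthwhile,
        # but compute nothing here -- construction stays cheap.
        self.left, self.right = left, right
        self._wants_root = left.shape[-1] > size_threshold

    @cached_property
    def _right_root(self):
        # Step 2: computed lazily, at most once, and only when matmul is called.
        # If only the posterior mean is needed, this never runs.
        return self.right.root_decomposition().root.to_dense()  # (n, k)

    def matmul(self, rhs):  # rhs: (n, m) tensor
        if not self._wants_root:
            # Same dense fallback the current else-branch of MulLinearOperator._matmul uses.
            return (self.left.to_dense() * self.right.to_dense()) @ rhs
        # (A * (R R^T)) @ v = sum_k r_k * (A @ (r_k * v)); only one factor's root is needed.
        r = self._right_root
        return sum(r[:, k : k + 1] * (self.left @ (r[:, k : k + 1] * rhs)) for k in range(r.shape[-1]))
```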

Member:

I think this change should be fine, since the MulLinearOperator constructor performs root decompositions on left_linear_op and right_linear_op:

if not isinstance(left_linear_op, RootLinearOperator):

> Even if a root decomposition seems beneficial, delay its computation until the moment it is actually needed in matmul, and cache the result.

Agreed.

return DenseLinearOperator(self.to_dense() * other.to_dense())
else:
left_linear_op = self if self._root_decomposition_size() < other._root_decomposition_size() else other
right_linear_op = other if left_linear_op is self else self
return MulLinearOperator(
left_linear_op.root_decomposition(),
right_linear_op.root_decomposition(),
)
return MulLinearOperator(self, other)

def _preconditioner(self) -> Tuple[Callable, "LinearOperator", torch.Tensor]:
"""
6 changes: 4 additions & 2 deletions linear_operator/operators/block_diag_linear_operator.py
@@ -153,8 +153,10 @@ def matmul(self, other):
return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
# special case if we have a DiagLinearOperator
if isinstance(other, DiagLinearOperator):
diag_reshape = other._diag.view(*self.base_linear_op.shape[:-2], 1, -1)
return BlockDiagLinearOperator(self.base_linear_op * diag_reshape)
# matmul is going to be cheap because of the special casing in DiagLinearOperator
diag_reshape = other._diag.view(*self.base_linear_op.shape[:-1])
diag = DiagLinearOperator(diag_reshape)
return BlockDiagLinearOperator(self.base_linear_op @ diag)
return super().matmul(other)

@cached(name="svd")
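For intuition, a small torch-only check of the reshape used in the new special case above (shapes are arbitrary; this snippet is illustration only and not part of the PR):

```python
import torch

# A block-diagonal operator with 3 blocks of size 4x4, right-multiplied by a
# 12x12 diagonal operator: the product stays block diagonal, so only the
# (3, 4, 4) base tensor is ever touched instead of a dense 12x12 matrix.
base = torch.randn(3, 4, 4)                        # stand-in for self.base_linear_op
full_diag = torch.randn(3 * 4)                     # stand-in for other._diag

diag_reshape = full_diag.view(*base.shape[:-1])    # (3, 4): one diagonal per block
blockwise = base @ torch.diag_embed(diag_reshape)  # what the returned BlockDiagLinearOperator wraps

dense = torch.block_diag(*base) @ torch.diag(full_diag)  # dense reference
assert torch.allclose(torch.block_diag(*blockwise), dense)
```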
3 changes: 3 additions & 0 deletions linear_operator/operators/cat_linear_operator.py
@@ -366,6 +366,9 @@ def _unsqueeze_batch(self, dim):
)
return res

def to_dense(self):
return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad)
return tuple(r.to(self.device) for r in res)
10 changes: 2 additions & 8 deletions linear_operator/operators/dense_linear_operator.py
@@ -23,7 +23,7 @@ def __init__(self, tsr):
Args:
- tsr (Tensor: matrix) a Tensor
"""
super(DenseLinearOperator, self).__init__(tsr)
super().__init__(tsr)
self.tensor = tsr

def _cholesky_solve(self, rhs, upper=False):
@@ -76,13 +76,7 @@ def __add__(self, other):
elif isinstance(other, torch.Tensor):
return DenseLinearOperator(self.tensor + other)
else:
return super(DenseLinearOperator, self).__add__(other)

def mul(self, other):
if isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self.tensor * other.tensor)
else:
return super(DenseLinearOperator, self).mul(other)
return super().__add__(other)


def to_linear_operator(obj: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
41 changes: 20 additions & 21 deletions linear_operator/operators/diag_linear_operator.py
@@ -65,16 +65,6 @@ def _get_indices(
res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
return res

def _matmul(self, rhs: Tensor) -> Tensor:
# to perform matrix multiplication with diagonal matrices we can just
# multiply element-wise with the diagonal (using proper broadcasting)
if rhs.ndimension() == 1:
return self._diag * rhs
# special case if we have a DenseLinearOperator
if isinstance(rhs, DenseLinearOperator):
return DenseLinearOperator(self._diag.unsqueeze(-1) * rhs.tensor)
return self._diag.unsqueeze(-1) * rhs

def _mul_constant(self, constant: Tensor) -> "DiagLinearOperator":
return self.__class__(self._diag * constant.unsqueeze(-1))

@@ -164,20 +154,29 @@ def log(self) -> "DiagLinearOperator":
return self.__class__(self._diag.log())

def matmul(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
# this is trivial if we multiply two DiagLinearOperators
if isinstance(other, DiagLinearOperator):
return DiagLinearOperator(self._diag * other._diag)
# special case if we have a DenseLinearOperator
if isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self._diag.unsqueeze(-1) * other.tensor)
# special case if we have a BlockDiagLinearOperator
if isinstance(other, BlockDiagLinearOperator):
diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1], 1)
return BlockDiagLinearOperator(diag_reshape * other.base_linear_op)
elif isinstance(other, BlockDiagLinearOperator):
diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1])
diag = DiagLinearOperator(diag_reshape)
# using matmul here avoids having to implement special case of elementwise multiplication
# with block diagonal operator, which itself has special cases for vectors and matrices
return BlockDiagLinearOperator(diag @ other.base_linear_op)
# special case if we have a TriangularLinearOperator
if isinstance(other, TriangularLinearOperator):
return TriangularLinearOperator(self._diag.unsqueeze(-1) * other._tensor, upper=other.upper)
return super().matmul(other)
elif isinstance(other, TriangularLinearOperator):
return TriangularLinearOperator(self @ other._tensor, upper=other.upper)
elif isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self @ other.tensor)
else:
return super().matmul(other)

def _matmul(self, other: Tensor) -> Tensor:
# to perform matrix multiplication with diagonal matrices we can just
# multiply element-wise with the diagonal (using proper broadcasting)
diag = self._diag
if other.ndimension() > 1:
diag = diag.unsqueeze(-1)
return diag * other

def solve(self, right_tensor: Tensor, left_tensor: Optional[Tensor] = None) -> Tensor:
res = self.inverse()._matmul(right_tensor)
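As a sanity check on the broadcasting in the reworked `_matmul` above (plain torch, not part of the diff): multiplying by a diagonal matrix is an elementwise scaling of the rows, and the extra trailing dimension on the diagonal is only needed when the right-hand side is a matrix.

```python
import torch

diag = torch.randn(5)
rhs_vec = torch.randn(5)
rhs_mat = torch.randn(5, 3)

# Vector case: plain elementwise product.
assert torch.allclose(diag * rhs_vec, torch.diag(diag) @ rhs_vec)
# Matrix case: unsqueeze so the diagonal broadcasts across columns.
assert torch.allclose(diag.unsqueeze(-1) * rhs_mat, torch.diag(diag) @ rhs_mat)
```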
7 changes: 5 additions & 2 deletions linear_operator/operators/mul_linear_operator.py
@@ -22,11 +22,14 @@ def __init__(self, left_linear_op, right_linear_op):
Args:
- linear_ops (A list of LinearOperator) - A list of LinearOperator to multiplicate with.
"""
if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
left_linear_op, right_linear_op = right_linear_op, left_linear_op

if not isinstance(left_linear_op, RootLinearOperator):
left_linear_op = left_linear_op.root_decomposition()
if not isinstance(right_linear_op, RootLinearOperator):
right_linear_op = right_linear_op.root_decomposition()
super(MulLinearOperator, self).__init__(left_linear_op, right_linear_op)
super().__init__(left_linear_op, right_linear_op)
self.left_linear_op = left_linear_op
self.right_linear_op = right_linear_op

@@ -62,7 +65,7 @@ def _matmul(self, rhs):
left_res = left_res.view(*output_batch_shape, n, rank, m)
res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
# This is the case where we're not doing a root decomposition, because the matrix is too small
else:
else: # Dead?
res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)
res = res.squeeze(-1) if is_vector else res
return res
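Side note on the root-decomposition branch of `_matmul`: the reason a root of one factor suffices for applying an elementwise product is, roughly, the identity below, checked here with plain torch (illustration only, not part of the PR):

```python
import torch

torch.manual_seed(0)
n, k = 6, 3
a = torch.randn(n, n)
a = a @ a.T                    # a symmetric PSD "left" factor A
r = torch.randn(n, k)          # root of the "right" factor: B = R @ R.T
b = r @ r.T
v = torch.randn(n)

# (A * B) @ v can be evaluated from A and the columns of R alone:
#   (A * (R R^T)) @ v = sum_k r_k * (A @ (r_k * v))
res = sum(r[:, j] * (a @ (r[:, j] * v)) for j in range(k))
assert torch.allclose(res, (a * b) @ v, atol=1e-5)
```

The cost of this route scales with the rank of the root, which is also why computing a root decomposition eagerly is pure overhead whenever no matmul is ever requested.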