Commit
Performant Scaling of BlockDiagLinearOperator by `DiagLinearOperator` (#14)

* added special cases to diag matmul

* Update linear_operator/operators/block_diag_linear_operator.py

Co-authored-by: Geoff Pleiss <gpleiss@gmail.com>
SebastianAment and gpleiss authored Sep 22, 2022
1 parent a58efe7 commit 3a37f0c
Showing 6 changed files with 35 additions and 39 deletions.
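
For context, a minimal sketch of the pattern this commit makes cheap; the shapes and import paths below are illustrative assumptions, not taken from the diff.

import torch
from linear_operator.operators import BlockDiagLinearOperator, DenseLinearOperator, DiagLinearOperator

# four 3x3 blocks -> a 12x12 block-diagonal operator
base = DenseLinearOperator(torch.randn(4, 3, 3))
block_diag = BlockDiagLinearOperator(base)

# a 12x12 diagonal operator to scale by
diag = DiagLinearOperator(torch.rand(12))

# with the special cases added below, this product is computed blockwise and
# the result keeps its block-diagonal structure instead of densifying
res = block_diag @ diag
assert res.shape == torch.Size([12, 12])
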
7 changes: 1 addition & 6 deletions linear_operator/operators/_linear_operator.py
@@ -559,12 +559,7 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOpe
if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self.to_dense() * other.to_dense())
else:
left_linear_op = self if self._root_decomposition_size() < other._root_decomposition_size() else other
right_linear_op = other if left_linear_op is self else self
return MulLinearOperator(
left_linear_op.root_decomposition(),
right_linear_op.root_decomposition(),
)
return MulLinearOperator(self, other)

def _preconditioner(self) -> Tuple[Callable, "LinearOperator", torch.Tensor]:
"""
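The operand-ordering and root-decomposition logic removed from `_mul_matrix` is not lost; it reappears in `MulLinearOperator.__init__` at the bottom of this commit. A hedged sketch of an elementwise product taking the simplified else-branch (RootLinearOperator operands are illustrative and assumed not to override `mul`):

import torch
from linear_operator.operators import RootLinearOperator

# two low-rank PSD operators; neither is dense, so _mul_matrix now simply
# returns MulLinearOperator(self, other) and defers the root decompositions
a = RootLinearOperator(torch.randn(5, 2))
b = RootLinearOperator(torch.randn(5, 2))
prod = a * b
assert torch.allclose(prod.to_dense(), a.to_dense() * b.to_dense(), atol=1e-4)
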
6 changes: 4 additions & 2 deletions linear_operator/operators/block_diag_linear_operator.py
@@ -153,8 +153,10 @@ def matmul(self, other):
return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
# special case if we have a DiagLinearOperator
if isinstance(other, DiagLinearOperator):
diag_reshape = other._diag.view(*self.base_linear_op.shape[:-2], 1, -1)
return BlockDiagLinearOperator(self.base_linear_op * diag_reshape)
# matmul is going to be cheap because of the special casing in DiagLinearOperator
diag_reshape = other._diag.view(*self.base_linear_op.shape[:-1])
diag = DiagLinearOperator(diag_reshape)
return BlockDiagLinearOperator(self.base_linear_op @ diag)
return super().matmul(other)

@cached(name="svd")
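A sketch of the reshaping in the new special case; the shapes are illustrative and chosen so the arithmetic is easy to follow.

import torch
from linear_operator.operators import BlockDiagLinearOperator, DenseLinearOperator, DiagLinearOperator

base = DenseLinearOperator(torch.randn(4, 3, 3))   # 4 blocks of size 3x3
block_diag = BlockDiagLinearOperator(base)         # 12x12 operator
other = DiagLinearOperator(torch.rand(12))         # 12x12 diagonal

# the new code views the length-12 diagonal as (4, 3): one length-3 diagonal
# per block, wrapped in a batched DiagLinearOperator and applied blockwise
diag_reshape = other._diag.view(*base.shape[:-1])        # shape (4, 3)
blockwise = base @ DiagLinearOperator(diag_reshape)      # shape (4, 3, 3)
result = BlockDiagLinearOperator(blockwise)

# agrees with the dense computation
assert torch.allclose(result.to_dense(), block_diag.to_dense() @ other.to_dense(), atol=1e-6)
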
3 changes: 3 additions & 0 deletions linear_operator/operators/cat_linear_operator.py
@@ -366,6 +366,9 @@ def _unsqueeze_batch(self, dim):
)
return res

def to_dense(self):
return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad)
return tuple(r.to(self.device) for r in res)
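A hedged sketch of the added to_dense; the CatLinearOperator constructor signature below is an assumption based on the rest of the library, not shown in this diff.

import torch
from linear_operator.operators import CatLinearOperator, DenseLinearOperator

# two blocks stacked along dim 0 -> a 6x3 operator
ops = [DenseLinearOperator(torch.randn(2, 3)), DenseLinearOperator(torch.randn(4, 3))]
cat_op = CatLinearOperator(*ops, dim=0)

# the new to_dense densifies each block and concatenates along the cat dimension
dense = cat_op.to_dense()
assert dense.shape == torch.Size([6, 3])
assert torch.allclose(dense, torch.cat([o.to_dense() for o in ops], dim=0))
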
10 changes: 2 additions & 8 deletions linear_operator/operators/dense_linear_operator.py
@@ -23,7 +23,7 @@ def __init__(self, tsr):
Args:
- tsr (Tensor: matrix) a Tensor
"""
super(DenseLinearOperator, self).__init__(tsr)
super().__init__(tsr)
self.tensor = tsr

def _cholesky_solve(self, rhs, upper=False):
@@ -76,13 +76,7 @@ def __add__(self, other):
elif isinstance(other, torch.Tensor):
return DenseLinearOperator(self.tensor + other)
else:
return super(DenseLinearOperator, self).__add__(other)

def mul(self, other):
if isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self.tensor * other.tensor)
else:
return super(DenseLinearOperator, self).mul(other)
return super().__add__(other)


def to_linear_operator(obj: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
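The deleted DenseLinearOperator.mul override was redundant: the generic _mul_matrix in the first hunk above already special-cases dense operands. A minimal check of the behaviour it relies on:

import torch
from linear_operator.operators import DenseLinearOperator

a = DenseLinearOperator(torch.randn(3, 3))
b = DenseLinearOperator(torch.randn(3, 3))

# elementwise product still comes back dense via the _mul_matrix fast path
prod = a * b
assert isinstance(prod, DenseLinearOperator)
assert torch.allclose(prod.to_dense(), a.tensor * b.tensor)
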
41 changes: 20 additions & 21 deletions linear_operator/operators/diag_linear_operator.py
@@ -65,16 +65,6 @@ def _get_indices(
res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
return res

def _matmul(self, rhs: Tensor) -> Tensor:
# to perform matrix multiplication with diagonal matrices we can just
# multiply element-wise with the diagonal (using proper broadcasting)
if rhs.ndimension() == 1:
return self._diag * rhs
# special case if we have a DenseLinearOperator
if isinstance(rhs, DenseLinearOperator):
return DenseLinearOperator(self._diag.unsqueeze(-1) * rhs.tensor)
return self._diag.unsqueeze(-1) * rhs

def _mul_constant(self, constant: Tensor) -> "DiagLinearOperator":
return self.__class__(self._diag * constant.unsqueeze(-1))

@@ -164,20 +154,29 @@ def log(self) -> "DiagLinearOperator":
return self.__class__(self._diag.log())

def matmul(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
# this is trivial if we multiply two DiagLinearOperators
if isinstance(other, DiagLinearOperator):
return DiagLinearOperator(self._diag * other._diag)
# special case if we have a DenseLinearOperator
if isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self._diag.unsqueeze(-1) * other.tensor)
# special case if we have a BlockDiagLinearOperator
if isinstance(other, BlockDiagLinearOperator):
diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1], 1)
return BlockDiagLinearOperator(diag_reshape * other.base_linear_op)
elif isinstance(other, BlockDiagLinearOperator):
diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1])
diag = DiagLinearOperator(diag_reshape)
# using matmul here avoids having to implement special case of elementwise multiplication
# with block diagonal operator, which itself has special cases for vectors and matrices
return BlockDiagLinearOperator(diag @ other.base_linear_op)
# special case if we have a TriangularLinearOperator
if isinstance(other, TriangularLinearOperator):
return TriangularLinearOperator(self._diag.unsqueeze(-1) * other._tensor, upper=other.upper)
return super().matmul(other)
elif isinstance(other, TriangularLinearOperator):
return TriangularLinearOperator(self @ other._tensor, upper=other.upper)
elif isinstance(other, DenseLinearOperator):
return DenseLinearOperator(self @ other.tensor)
else:
return super().matmul(other)

def _matmul(self, other: Tensor) -> Tensor:
# to perform matrix multiplication with diagonal matrices we can just
# multiply element-wise with the diagonal (using proper broadcasting)
diag = self._diag
if other.ndimension() > 1:
diag = diag.unsqueeze(-1)
return diag * other

def solve(self, right_tensor: Tensor, left_tensor: Optional[Tensor] = None) -> Tensor:
res = self.inverse()._matmul(right_tensor)
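A sketch of the relocated _matmul and of the structure-preserving matmul dispatch; the values are illustrative.

import torch
from linear_operator.operators import DiagLinearOperator

d = DiagLinearOperator(torch.tensor([1.0, 2.0, 3.0]))

# _matmul multiplies elementwise with the diagonal; the diagonal is only
# unsqueezed when the right-hand side has a column dimension to broadcast over
v = torch.ones(3)
m = torch.ones(3, 2)
assert torch.equal(d._matmul(v), torch.tensor([1.0, 2.0, 3.0]))
assert torch.equal(d._matmul(m), torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]))

# diag @ diag stays diagonal, per the first branch of matmul above
d2 = DiagLinearOperator(torch.full((3,), 2.0))
assert isinstance(d @ d2, DiagLinearOperator)
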
7 changes: 5 additions & 2 deletions linear_operator/operators/mul_linear_operator.py
@@ -22,11 +22,14 @@ def __init__(self, left_linear_op, right_linear_op):
Args:
- linear_ops (A list of LinearOperator) - A list of LinearOperator to multiplicate with.
"""
if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
left_linear_op, right_linear_op = right_linear_op, left_linear_op

if not isinstance(left_linear_op, RootLinearOperator):
left_linear_op = left_linear_op.root_decomposition()
if not isinstance(right_linear_op, RootLinearOperator):
right_linear_op = right_linear_op.root_decomposition()
super(MulLinearOperator, self).__init__(left_linear_op, right_linear_op)
super().__init__(left_linear_op, right_linear_op)
self.left_linear_op = left_linear_op
self.right_linear_op = right_linear_op

@@ -62,7 +65,7 @@ def _matmul(self, rhs):
left_res = left_res.view(*output_batch_shape, n, rank, m)
res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
# This is the case where we're not doing a root decomposition, because the matrix is too small
else:
else: # Dead?
res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)
res = res.squeeze(-1) if is_vector else res
return res
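With the swap and the root decompositions handled in __init__, MulLinearOperator can now be constructed directly from two PSD operators, which is what the simplified _mul_matrix at the top of this commit relies on. A hedged sketch; the operands are illustrative and assumed positive semi-definite:

import torch
from linear_operator.operators import DenseLinearOperator
from linear_operator.operators.mul_linear_operator import MulLinearOperator

# two small PSD operators built from random factors
a_root = torch.randn(4, 6)
b_root = torch.randn(4, 6)
a = DenseLinearOperator(a_root @ a_root.transpose(-1, -2))
b = DenseLinearOperator(b_root @ b_root.transpose(-1, -2))

# __init__ now takes the root decompositions itself, keeping the operand with
# the larger root-decomposition size on the left
mul_op = MulLinearOperator(a, b)
assert torch.allclose(mul_op.to_dense(), a.to_dense() * b.to_dense(), atol=1e-3)
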
