diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py
index 3c1c1a24..969f29bb 100644
--- a/linear_operator/operators/_linear_operator.py
+++ b/linear_operator/operators/_linear_operator.py
@@ -558,12 +558,7 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOpe
         if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
             return DenseLinearOperator(self.to_dense() * other.to_dense())
         else:
-            left_linear_op = self if self._root_decomposition_size() < other._root_decomposition_size() else other
-            right_linear_op = other if left_linear_op is self else self
-            return MulLinearOperator(
-                left_linear_op.root_decomposition(),
-                right_linear_op.root_decomposition(),
-            )
+            return MulLinearOperator(self, other)
 
     def _preconditioner(self) -> Tuple[Callable, "LinearOperator", torch.Tensor]:
         """
diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py
index 2df9b3f5..ba5717a2 100644
--- a/linear_operator/operators/block_diag_linear_operator.py
+++ b/linear_operator/operators/block_diag_linear_operator.py
@@ -153,8 +153,9 @@ def matmul(self, other):
             return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
         # special case if we have a DiagLinearOperator
         if isinstance(other, DiagLinearOperator):
-            diag_reshape = other._diag.view(*self.base_linear_op.shape[:-2], 1, -1)
-            return BlockDiagLinearOperator(self.base_linear_op * diag_reshape)
+            diag_reshape = other._diag.view(*self.base_linear_op.shape[:-1])
+            diag = DiagLinearOperator(diag_reshape)
+            return BlockDiagLinearOperator(self.base_linear_op @ diag)
         return super().matmul(other)
 
     @cached(name="svd")
diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py
index 180c183f..f3ef8ec4 100644
--- a/linear_operator/operators/cat_linear_operator.py
+++ b/linear_operator/operators/cat_linear_operator.py
@@ -366,6 +366,9 @@ def _unsqueeze_batch(self, dim):
         )
         return res
 
+    def to_dense(self):
+        return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)
+
     def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
         res = super().inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad)
         return tuple(r.to(self.device) for r in res)
diff --git a/linear_operator/operators/dense_linear_operator.py b/linear_operator/operators/dense_linear_operator.py
index b884a804..5757b3d8 100644
--- a/linear_operator/operators/dense_linear_operator.py
+++ b/linear_operator/operators/dense_linear_operator.py
@@ -23,7 +23,7 @@ def __init__(self, tsr):
         Args:
         - tsr (Tensor: matrix) a Tensor
         """
-        super(DenseLinearOperator, self).__init__(tsr)
+        super().__init__(tsr)
         self.tensor = tsr
 
     def _cholesky_solve(self, rhs, upper=False):
@@ -76,13 +76,7 @@ def __add__(self, other):
         elif isinstance(other, torch.Tensor):
             return DenseLinearOperator(self.tensor + other)
         else:
-            return super(DenseLinearOperator, self).__add__(other)
-
-    def mul(self, other):
-        if isinstance(other, DenseLinearOperator):
-            return DenseLinearOperator(self.tensor * other.tensor)
-        else:
-            return super(DenseLinearOperator, self).mul(other)
+            return super().__add__(other)
 
 
 def to_linear_operator(obj: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
diff --git a/linear_operator/operators/diag_linear_operator.py b/linear_operator/operators/diag_linear_operator.py
index 607f3e7b..7496c082 100644
--- a/linear_operator/operators/diag_linear_operator.py
+++ b/linear_operator/operators/diag_linear_operator.py
@@ -7,7 +7,7 @@
 
 from .. import settings
 from ..utils.memoize import cached
-from ._linear_operator import LinearOperator
+from ._linear_operator import LinearOperator, to_dense
 from .block_diag_linear_operator import BlockDiagLinearOperator
 from .dense_linear_operator import DenseLinearOperator
 from .triangular_linear_operator import TriangularLinearOperator
@@ -65,16 +65,6 @@ def _get_indices(
         res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
         return res
 
-    def _matmul(self, rhs: Tensor) -> Tensor:
-        # to perform matrix multiplication with diagonal matrices we can just
-        # multiply element-wise with the diagonal (using proper broadcasting)
-        if rhs.ndimension() == 1:
-            return self._diag * rhs
-        # special case if we have a DenseLinearOperator
-        if isinstance(rhs, DenseLinearOperator):
-            return DenseLinearOperator(self._diag.unsqueeze(-1) * rhs.tensor)
-        return self._diag.unsqueeze(-1) * rhs
-
     def _mul_constant(self, constant: Tensor) -> "DiagLinearOperator":
         return self.__class__(self._diag * constant.unsqueeze(-1))
 
@@ -164,20 +154,29 @@ def log(self) -> "DiagLinearOperator":
         return self.__class__(self._diag.log())
 
     def matmul(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
-        # this is trivial if we multiply two DiagLinearOperators
         if isinstance(other, DiagLinearOperator):
             return DiagLinearOperator(self._diag * other._diag)
-        # special case if we have a DenseLinearOperator
-        if isinstance(other, DenseLinearOperator):
-            return DenseLinearOperator(self._diag.unsqueeze(-1) * other.tensor)
-        # special case if we have a BlockDiagLinearOperator
-        if isinstance(other, BlockDiagLinearOperator):
-            diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1], 1)
-            return BlockDiagLinearOperator(diag_reshape * other.base_linear_op)
+        elif isinstance(other, BlockDiagLinearOperator):
+            diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1])
+            diag = DiagLinearOperator(diag_reshape)
+            # using matmul here avoids having to implement special case of elementwise multiplication
+            # with block diagonal operator, which itself has special cases for vectors and matrices
+            return BlockDiagLinearOperator(diag @ other.base_linear_op)
         # special case if we have a TriangularLinearOperator
-        if isinstance(other, TriangularLinearOperator):
-            return TriangularLinearOperator(self._diag.unsqueeze(-1) * other._tensor, upper=other.upper)
-        return super().matmul(other)
+        elif isinstance(other, TriangularLinearOperator):
+            return TriangularLinearOperator(self @ other._tensor, upper=other.upper)
+        elif isinstance(other, DenseLinearOperator):
+            return DenseLinearOperator(self @ other.tensor)
+        else:
+            return super().matmul(other)
+
+    def _matmul(self, other: Tensor) -> Tensor:
+        # to perform matrix multiplication with diagonal matrices we can just
+        # multiply element-wise with the diagonal (using proper broadcasting)
+        diag = self._diag
+        if other.ndimension() > 1:
+            diag = diag.unsqueeze(-1)
+        return diag * other
 
     def solve(self, right_tensor: Tensor, left_tensor: Optional[Tensor] = None) -> Tensor:
         res = self.inverse()._matmul(right_tensor)
diff --git a/linear_operator/operators/mul_linear_operator.py b/linear_operator/operators/mul_linear_operator.py
index 1ef05d16..cffee668 100644
--- a/linear_operator/operators/mul_linear_operator.py
+++ b/linear_operator/operators/mul_linear_operator.py
@@ -22,11 +22,14 @@ def __init__(self, left_linear_op, right_linear_op):
         Args:
             - linear_ops (A list of LinearOperator) - A list of LinearOperator to multiplicate with.
         """
+        if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
+            left_linear_op, right_linear_op = right_linear_op, left_linear_op
+
         if not isinstance(left_linear_op, RootLinearOperator):
             left_linear_op = left_linear_op.root_decomposition()
         if not isinstance(right_linear_op, RootLinearOperator):
             right_linear_op = right_linear_op.root_decomposition()
-        super(MulLinearOperator, self).__init__(left_linear_op, right_linear_op)
+        super().__init__(left_linear_op, right_linear_op)
         self.left_linear_op = left_linear_op
         self.right_linear_op = right_linear_op
 
@@ -62,7 +65,8 @@ def _matmul(self, rhs):
             left_res = left_res.view(*output_batch_shape, n, rank, m)
             res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
         # This is the case where we're not doing a root decomposition, because the matrix is too small
-        else:
+        else:  # Dead?
+            print("ALIVE!")
             res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)
         res = res.squeeze(-1) if is_vector else res
         return res
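
A minimal sanity check of the new structured matmul dispatch, as a sketch: it assumes an install of linear_operator with this patch applied, and the shapes, variable names, and tolerances below are illustrative rather than part of the patch.

import torch

from linear_operator.operators import (
    BlockDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

# A 12 x 12 block-diagonal operator built from three 4 x 4 blocks,
# and a 12 x 12 diagonal operator.
base = DenseLinearOperator(torch.randn(3, 4, 4))
block = BlockDiagLinearOperator(base)
diag = DiagLinearOperator(torch.rand(12) + 1.0)

# Both dispatch directions should preserve block-diagonal structure:
# DiagLinearOperator.matmul handles diag @ block, and
# BlockDiagLinearOperator.matmul handles block @ diag.
left = diag @ block
right = block @ diag
assert isinstance(left, BlockDiagLinearOperator)
assert isinstance(right, BlockDiagLinearOperator)

# The structured results must agree with the dense computation.
assert torch.allclose(left.to_dense(), diag.to_dense() @ block.to_dense(), atol=1e-6)
assert torch.allclose(right.to_dense(), block.to_dense() @ diag.to_dense(), atol=1e-6)

This exercises the diag_linear_operator.py and block_diag_linear_operator.py hunks above without touching the root-decomposition path in mul_linear_operator.py.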