
Commit

quad_form_derivative -> bilinear_derivative
Update bilinear_derivative docs.
gpleiss committed May 25, 2022
1 parent 270be14 commit 00077a2
Showing 23 changed files with 98 additions and 87 deletions.
6 changes: 3 additions & 3 deletions docs/source/custom_linear_operators.rst
@@ -17,9 +17,9 @@ the matrix that the LinearOperator represents)

In addition to these, the following methods should be implemented for maximum efficiency

-* :meth:`~linear_operator.operators.LinearOperator._quad_form_derivative`,
-  which computes the derivative of a quadratic form with the LinearOperator
-  (e.g. :math:`d (\mathbf b^T \mathbf A \mathbf c) / d \mathbf A`).
+* :meth:`~linear_operator.operators.LinearOperator._bilinear_derivative`,
+  which computes the derivative of a quadratic form with the LinearOperator's representation
+  (e.g. :math:`\partial (\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c) / \partial \boldsymbol \theta`).
* :meth:`~linear_operator.operators.LinearOperator._get_indices`, which returns
a :class:`torch.Tensor` containing elements that are given by various tensor indices.
* :meth:`~linear_operator.operators.LinearOperator._expand_batch`, which
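As a quick illustration of the quantity the updated docs describe (plain PyTorch, not part of this commit), take the made-up parameterization A(theta) = diag(theta): the derivative of b^T A(theta) c with respect to theta is then just the elementwise product b * c.

import torch

# A(theta) = diag(theta), so b^T A(theta) c = sum_i b_i * theta_i * c_i
# and d(b^T A(theta) c) / d(theta_i) = b_i * c_i.
theta = torch.randn(3, requires_grad=True)
b, c = torch.randn(3), torch.randn(3)

bilinear_form = b @ torch.diag(theta) @ c
(grad_theta,) = torch.autograd.grad(bilinear_form, theta)

print(torch.allclose(grad_theta, b * c))  # True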
6 changes: 3 additions & 3 deletions linear_operator/functions/_inv_matmul.py
@@ -97,10 +97,10 @@ def backward(ctx, grad_output):
left_solves = InvMatmul.apply(ctx.representation_tree, False, grad_output, *matrix_args)

if any(ctx.needs_input_grad[3:]):
-# We call _quad_form_derivative to compute dl/dK
+# We call _bilinear_derivative to compute dl/dK
# To ensure that this term is symmetric, we concatenate the left and right solves together,
# and divide the result by 1/2
-arg_grads = linear_op._quad_form_derivative(
+arg_grads = linear_op._bilinear_derivative(
torch.cat([left_solves, right_solves], -1), torch.cat([right_solves, left_solves], -1).mul(-0.5)
)
if ctx.needs_input_grad[2]:
@@ -117,7 +117,7 @@ def backward(ctx, grad_output):
left_grad = grad_output @ right_solves.transpose(-1, -2)
if any(ctx.needs_input_grad[4:]):
# We do this concatenation to ensure that the gradient of linear_op is symmetric
-arg_grads = linear_op._quad_form_derivative(
+arg_grads = linear_op._bilinear_derivative(
torch.cat([left_solves, right_solves], -1), torch.cat([right_solves, left_solves], -1).mul(-0.5)
)
if ctx.needs_input_grad[3]:
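To see why the concatenation above yields a symmetric gradient, here is a dense-tensor sketch (illustrative only, outside the library; the variable names mirror the diff, everything else is an assumption). For a dense operator, _bilinear_derivative(U, V) reduces to U @ V.T, so the concatenated factors produce the symmetrized version of -K^{-1} g r^T K^{-1}.

import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n)
K = (A @ A.T + n * torch.eye(n)).requires_grad_(True)  # symmetric positive definite
g = torch.randn(n, 1)  # incoming gradient w.r.t. the solve K^{-1} r
r = torch.randn(n, 1)  # right-hand side

# Scalar whose gradient w.r.t. K the backward pass must produce: g^T K^{-1} r
loss = (g * torch.linalg.solve(K, r)).sum()
(auto_grad,) = torch.autograd.grad(loss, K)

# The concatenation trick, evaluated with the dense rule U @ V^T
left_solves = torch.linalg.solve(K.detach(), g)   # K^{-1} g
right_solves = torch.linalg.solve(K.detach(), r)  # K^{-1} r
U = torch.cat([left_solves, right_solves], -1)
V = torch.cat([right_solves, left_solves], -1).mul(-0.5)

# autograd gives -K^{-1} g r^T K^{-1}; its symmetric part matches U @ V^T
print(torch.allclose(0.5 * (auto_grad + auto_grad.T), U @ V.T, atol=1e-5))  # True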
2 changes: 1 addition & 1 deletion linear_operator/functions/_inv_quad.py
@@ -78,7 +78,7 @@ def backward(ctx, inv_quad_grad_output):
if any(ctx.needs_input_grad[2:]):
left_factors = neg_inv_quad_solves_times_grad_out
right_factors = inv_quad_solves
-matrix_arg_grads = linear_op._quad_form_derivative(left_factors, right_factors)
+matrix_arg_grads = linear_op._bilinear_derivative(left_factors, right_factors)

# input_2 gradients
if ctx.needs_input_grad[1]:
4 changes: 2 additions & 2 deletions linear_operator/functions/_inv_quad_logdet.py
@@ -206,10 +206,10 @@ def backward(ctx, inv_quad_grad_output, logdet_grad_output):

left_factors = torch.cat(left_factors_list, -1)
right_factors = torch.cat(right_factors_list, -1)
-matrix_arg_grads = linear_op._quad_form_derivative(left_factors, right_factors)
+matrix_arg_grads = linear_op._bilinear_derivative(left_factors, right_factors)

# precond gradient
-precond_arg_grads = precond_lt._quad_form_derivative(
+precond_arg_grads = precond_lt._bilinear_derivative(
-precond_probe_vectors * coef, precond_probe_vectors * logdet_grad_output
)

2 changes: 1 addition & 1 deletion linear_operator/functions/_matmul.py
@@ -43,7 +43,7 @@ def backward(ctx, grad_output):
if any(ctx.needs_input_grad[2:]):
rhs = rhs.unsqueeze(-1) if (rhs.ndimension() == 1) else rhs
grad_output_matrix = grad_output.unsqueeze(-1) if grad_output.ndimension() == 1 else grad_output
-arg_grads = ctx.representation_tree(*matrix_args)._quad_form_derivative(grad_output_matrix, rhs)
+arg_grads = ctx.representation_tree(*matrix_args)._bilinear_derivative(grad_output_matrix, rhs)

# input_2 gradient
if ctx.needs_input_grad[1]:
2 changes: 1 addition & 1 deletion linear_operator/functions/_root_decomposition.py
@@ -165,7 +165,7 @@ def is_empty(tensor):
else:
left_factor = left_factor.contiguous()
right_factor = right_factor.contiguous()
-res = linear_op._quad_form_derivative(left_factor, right_factor)
+res = linear_op._bilinear_derivative(left_factor, right_factor)

return tuple([None] * 9 + list(res))
else:
4 changes: 2 additions & 2 deletions linear_operator/functions/_sqrt_inv_matmul.py
@@ -68,7 +68,7 @@ def backward(ctx, sqrt_inv_matmul_grad, inv_quad_grad):
# Compute matrix grads
terms1 = torch.cat([lhs_no_shift_solves.unsqueeze(0), lhs_solves], 0)
terms2 = torch.cat([neg_inv_quad_solves_mul_grad.unsqueeze(0), weighted_rhs_solves_mul_grad], 0)
-matrix_arg_grads = ctx.linear_op._quad_form_derivative(
+matrix_arg_grads = ctx.linear_op._bilinear_derivative(
torch.cat([terms1, terms2], -1), torch.cat([terms2, terms1], -1).mul_(0.5)
)

@@ -94,7 +94,7 @@ def backward(ctx, sqrt_inv_matmul_grad, inv_quad_grad):
# Compute matrix grads
terms1 = grad_solves_mul_weights
terms2 = rhs_solves
-matrix_arg_grads = ctx.linear_op._quad_form_derivative(
+matrix_arg_grads = ctx.linear_op._bilinear_derivative(
torch.cat([terms1, terms2], -1), torch.cat([terms2, terms1], -1).mul_(0.5)
)

93 changes: 52 additions & 41 deletions linear_operator/operators/_linear_operator.py
@@ -252,6 +252,58 @@ def _unsqueeze_batch(self, dim: int) -> LinearOperator:
####
# The following methods PROBABLY should be over-written by LinearOperator subclasses for efficiency
####
+    def _bilinear_derivative(self, left_vecs: torch.Tensor, right_vecs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        r"""
+        Given :math:`\mathbf U` (left_vecs) and :math:`\mathbf V` (right_vecs),
+        computes the derivative of the bilinear form :math:`\mathbf u^\top \mathbf K \mathbf v`
+        with respect to the representation of :math:`\mathbf K`.
+
+        Assume a :math:`\ldots \times M \times N` linear operator :math:`\mathbf K(\boldsymbol \theta)`,
+        represented by tensors/sub-operators :math:`\boldsymbol \theta`.
+        If :math:`\mathbf U \in \mathcal R^{\ldots \times M \times D}` and
+        :math:`\mathbf V \in \mathcal R^{\ldots \times N \times D}`, this function computes:
+
+        .. math::
+            \sum_{i=1}^D \frac{\partial \mathbf u_i^\top \mathbf K(\boldsymbol \theta) \mathbf v_i}
+            {\partial \boldsymbol \theta}
+
+        Note that the columns of :math:`\mathbf U` and :math:`\mathbf V` are summed over.
+
+        .. note::
+            This method is intended to be used only internally by various
+            Functions that support backpropagation. For example, this method
+            is used internally by :func:`~linear_operator.LinearOperator.inv_quad_logdet`.
+            It is not likely that users will need to call this method directly.
+
+        :param left_vecs: The vectors :math:`\mathbf U = [\mathbf u_1, \ldots, \mathbf u_D]`
+        :param right_vecs: The vectors :math:`\mathbf V = [\mathbf v_1, \ldots, \mathbf v_D]`
+        :return: Derivative with respect to the arguments (:math:`\boldsymbol \theta`) that
+            represent this LinearOperator.
+        """
+        from collections import deque
+
+        args = tuple(self.representation())
+        args_with_grads = tuple(arg for arg in args if arg.requires_grad)
+
+        # Easy case: if we don't require any gradients, then just return!
+        if not len(args_with_grads):
+            return tuple(None for _ in args)
+
+        # Normal case: we'll use the autograd to get us a derivative
+        with torch.autograd.enable_grad():
+            loss = (left_vecs * self._matmul(right_vecs)).sum()
+            loss.requires_grad_(True)
+            actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
+
+        # Now make sure that the object we return has one entry for every item in args
+        grads = []
+        for arg in args:
+            if arg.requires_grad:
+                grads.append(actual_grads.popleft())
+            else:
+                grads.append(None)
+
+        return tuple(grads)

def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
"""
Expands along batch dimensions. Return size will be *batch_shape x *matrix_shape.
@@ -318,47 +370,6 @@ def _get_indices(
)
return res

-    def _quad_form_derivative(self, left_vecs: torch.Tensor, right_vecs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        r"""
-        Given :math:`\mathbf u` (left_vecs) and :math:`\mathbf v` (right_vecs),
-        Computes the derivatives of (:math:`\mathbf u^\top \mathbf K \mathbf v`) w.r.t. :math:`\mathbf K`.
-        ..note::
-            This method is intended to be used only internally by various
-            Functions that support backpropagation. For example, this method
-            is used internally by :func:`~linear_operator.LinearOperator.inv_quad_logdet`.
-            It is not likely that users will need to call this method directly.
-        :param left_vecs: The vectors :math:`\mathbf u`
-        :param right_vecs: The vectors :math:`\mathbf v`
-        :return: Derivative with respect to the arguments that are actually
-            used to represent this this LinearOperator.
-        """
-        from collections import deque
-
-        args = tuple(self.representation())
-        args_with_grads = tuple(arg for arg in args if arg.requires_grad)
-
-        # Easy case: if we don't require any gradients, then just return!
-        if not len(args_with_grads):
-            return tuple(None for _ in args)
-
-        # Normal case: we'll use the autograd to get us a derivative
-        with torch.autograd.enable_grad():
-            loss = (left_vecs * self._matmul(right_vecs)).sum()
-            loss.requires_grad_(True)
-            actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
-
-        # Now make sure that the object we return has one entry for every item in args
-        grads = []
-        for arg in args:
-            if arg.requires_grad:
-                grads.append(actual_grads.popleft())
-            else:
-                grads.append(None)
-
-        return tuple(grads)

####
# Class definitions
####
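A standalone sketch of what the autograd fallback above computes (plain PyTorch, illustrative assumptions only): the scalar (left_vecs * K.matmul(right_vecs)).sum() equals sum_i u_i^T K v_i, and for a dense representation its gradient collapses to U V^T, which is exactly the closed form DenseLinearOperator returns below.

import torch

M, N, D = 4, 5, 3
K = torch.randn(M, N, requires_grad=True)  # stand-in for a dense representation
U = torch.randn(M, D)                      # left_vecs
V = torch.randn(N, D)                      # right_vecs

loss = (U * (K @ V)).sum()                 # equals sum_i u_i^T K v_i
(grad_K,) = torch.autograd.grad(loss, K)

print(torch.allclose(grad_K, U @ V.T))     # True: the dense case is U V^T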
6 changes: 3 additions & 3 deletions linear_operator/operators/batch_repeat_linear_operator.py
@@ -183,7 +183,7 @@ def _permute_batch(self, *dims):
res = self.__class__(self.base_linear_op._permute_batch(*dims), batch_repeat=new_batch_repeat)
return res

-def _quad_form_derivative(self, left_vectors, right_vectors):
+def _bilinear_derivative(self, left_vectors, right_vectors):
if self.is_square:
left_output_shape = _matmul_broadcast_shape(self.shape, left_vectors.shape)
if left_output_shape != left_vectors.shape:
@@ -196,9 +196,9 @@ def _quad_form_derivative(self, left_vectors, right_vectors):
left_vectors = self._move_repeat_batches_to_columns(left_vectors, left_output_shape)
right_vectors = self._move_repeat_batches_to_columns(right_vectors, right_output_shape)

-return self.base_linear_op._quad_form_derivative(left_vectors, right_vectors)
+return self.base_linear_op._bilinear_derivative(left_vectors, right_vectors)
else:
-return super()._quad_form_derivative(left_vectors, right_vectors)
+return super()._bilinear_derivative(left_vectors, right_vectors)

def _root_decomposition(self):
return self.base_linear_op._root_decomposition().repeat(*self.batch_repeat, 1, 1)
4 changes: 2 additions & 2 deletions linear_operator/operators/block_linear_operator.py
@@ -110,7 +110,7 @@ def _matmul(self, rhs):
res = res.squeeze(-1)
return res

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
if left_vecs.ndim == 1:
left_vecs = left_vecs.unsqueeze(-1)
right_vecs = right_vecs.unsqueeze(-1)
@@ -119,7 +119,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
left_vecs = left_vecs.unsqueeze(-1)
left_vecs = self._add_batch_dim(left_vecs)
right_vecs = self._add_batch_dim(right_vecs)
-res = self.base_linear_op._quad_form_derivative(left_vecs, right_vecs)
+res = self.base_linear_op._bilinear_derivative(left_vecs, right_vecs)
return res

def _permute_batch(self, *dims):
4 changes: 2 additions & 2 deletions linear_operator/operators/constant_mul_linear_operator.py
@@ -106,7 +106,7 @@ def _permute_batch(self, *dims):
self.base_linear_op._permute_batch(*dims), self._constant.expand(self.batch_shape).permute(*dims)
)

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
# Gradient with respect to the constant
constant_deriv = left_vecs * self.base_linear_op._matmul(right_vecs)
constant_deriv = constant_deriv.sum(-2).sum(-1)
@@ -118,7 +118,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):

# Get derivaties of everything else
left_vecs = left_vecs * self.expanded_constant
-res = self.base_linear_op._quad_form_derivative(left_vecs, right_vecs)
+res = self.base_linear_op._bilinear_derivative(left_vecs, right_vecs)

return tuple(res) + (constant_deriv,)

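A dense-tensor check of the two pieces of the constant-mul rule above (illustrative only; K = s * B is an assumed stand-in, not library code): the scalar receives the reduced bilinear form, while the base operator sees left vectors scaled by the constant.

import torch

# For K = s * B:  d(u^T K v)/ds = u^T B v  and  d(u^T K v)/dB = s * u v^T.
s = torch.randn((), requires_grad=True)
B = torch.randn(4, 4, requires_grad=True)
u = torch.randn(4, 1)
v = torch.randn(4, 1)

loss = (u * ((s * B) @ v)).sum()
dB, ds = torch.autograd.grad(loss, (B, s))

print(torch.allclose(ds, (u * (B @ v)).sum()))  # constant_deriv: sum over the last two dims
print(torch.allclose(dB, s * (u @ v.T)))        # base grad: left_vecs scaled by the constant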
2 changes: 1 addition & 1 deletion linear_operator/operators/dense_linear_operator.py
@@ -46,7 +46,7 @@ def _matmul(self, rhs):
def _prod_batch(self, dim):
return self.__class__(self.tensor.prod(dim))

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
res = left_vecs.matmul(right_vecs.transpose(-1, -2))
return (res,)

4 changes: 2 additions & 2 deletions linear_operator/operators/diag_linear_operator.py
@@ -69,7 +69,7 @@ def _mul_matrix(self, other):
def _prod_batch(self, dim):
return self.__class__(self._diag.prod(dim))

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
# TODO: Use proper batching for input vectors (prepand to shape rathern than append)
if not self._diag.requires_grad:
return (None,)
@@ -257,7 +257,7 @@ def _mul_matrix(self, other):
def _prod_batch(self, dim):
return self.__class__(self.diag_values.prod(dim), diag_shape=self.diag_shape)

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
# TODO: Use proper batching for input vectors (prepand to shape rathern than append)
if not self.diag_values.requires_grad:
return (None,)
4 changes: 2 additions & 2 deletions linear_operator/operators/interpolated_linear_operator.py
@@ -226,7 +226,7 @@ def _t_matmul(self, rhs):
res = res.squeeze(-1)
return res

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
# Get sparse tensor representations of left/right interp matrices
left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values)
right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values)
@@ -238,7 +238,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
# base_linear_op grad
left_res = sparse.bdsmm(left_interp_t, left_vecs)
right_res = sparse.bdsmm(right_interp_t, right_vecs)
-base_lv_grad = list(self.base_linear_op._quad_form_derivative(left_res, right_res))
+base_lv_grad = list(self.base_linear_op._bilinear_derivative(left_res, right_res))

# left_interp_values grad
n_vecs = right_res.size(-1)
4 changes: 2 additions & 2 deletions linear_operator/operators/keops_linear_operator.py
@@ -83,10 +83,10 @@ def _getitem(self, row_index, col_index, *batch_indices):
# Now construct a kernel with those indices
return self.__class__(x1, x2, covar_func=self.covar_func, **self.params)

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
"""
Use default behavior, but KeOps does not automatically make args contiguous like torch.matmul.
This is necessary for variational GP models.
"""
-return super()._quad_form_derivative(left_vecs.contiguous(), right_vecs.contiguous())
+return super()._bilinear_derivative(left_vecs.contiguous(), right_vecs.contiguous())
@@ -380,8 +380,8 @@ def _expand_batch(self, batch_shape):
def _mul_constant(self, constant):
return DiagLinearOperator(self._diag * constant.unsqueeze(-1))

-def _quad_form_derivative(self, left_vecs, right_vecs):
-return KroneckerProductTriangularLinearOperator._quad_form_derivative(self, left_vecs, right_vecs)
+def _bilinear_derivative(self, left_vecs, right_vecs):
+return KroneckerProductTriangularLinearOperator._bilinear_derivative(self, left_vecs, right_vecs)

def sqrt(self):
return self.__class__(*[lt.sqrt() for lt in self.linear_ops])
6 changes: 3 additions & 3 deletions linear_operator/operators/matmul_linear_operator.py
@@ -81,14 +81,14 @@ def _matmul(self, right_linear_op):
def _t_matmul(self, right_linear_op):
return self.right_linear_op._t_matmul(self.left_linear_op._t_matmul(right_linear_op))

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
if left_vecs.ndimension() == 1:
left_vecs = left_vecs.unsqueeze(1)
right_vecs = right_vecs.unsqueeze(1)
right_vecs_times_right_linear_op = self.right_linear_op._matmul(right_vecs)
left_vecs_times_left_linear_op_t = self.left_linear_op._t_matmul(left_vecs)
-left_grad = self.left_linear_op._quad_form_derivative(left_vecs, right_vecs_times_right_linear_op)
-right_grad = self.right_linear_op._quad_form_derivative(left_vecs_times_left_linear_op_t, right_vecs)
+left_grad = self.left_linear_op._bilinear_derivative(left_vecs, right_vecs_times_right_linear_op)
+right_grad = self.right_linear_op._bilinear_derivative(left_vecs_times_left_linear_op_t, right_vecs)

left_grad = (left_grad,) if not isinstance(left_grad, tuple) else left_grad
right_grad = (right_grad,) if not isinstance(right_grad, tuple) else right_grad
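A quick dense check of the product rule used above (plain PyTorch, illustrative only): for K = L R, the left factor's bilinear derivative is taken against R v and the right factor's against L^T u.

import torch

L = torch.randn(4, 5, requires_grad=True)
R = torch.randn(5, 6, requires_grad=True)
u = torch.randn(4, 1)
v = torch.randn(6, 1)

loss = (u * (L @ (R @ v))).sum()           # u^T (L R) v
dL, dR = torch.autograd.grad(loss, (L, R))

print(torch.allclose(dL, u @ (R @ v).T))    # pairs left_vecs with right_vecs_times_right_linear_op
print(torch.allclose(dR, (L.T @ u) @ v.T))  # pairs left_vecs_times_left_linear_op_t with right_vecs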
6 changes: 3 additions & 3 deletions linear_operator/operators/mul_linear_operator.py
@@ -72,7 +72,7 @@ def _mul_constant(self, constant):
res = super()._mul_constant(constant)
return res

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
if left_vecs.ndimension() == 1:
left_vecs = left_vecs.unsqueeze(1)
right_vecs = right_vecs.unsqueeze(1)
@@ -92,7 +92,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):

left_factor = left_factor.view(*batch_shape, n, num_vecs * right_rank)
right_factor = right_factor.view(*batch_shape, n, num_vecs * right_rank)
-left_deriv_args = self.left_linear_op._quad_form_derivative(left_factor, right_factor)
+left_deriv_args = self.left_linear_op._bilinear_derivative(left_factor, right_factor)

if isinstance(self.left_linear_op, RootLinearOperator):
left_root = self.left_linear_op.root.to_dense()
@@ -107,7 +107,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):

left_factor = left_factor.view(*batch_shape, n, num_vecs * left_rank)
right_factor = right_factor.view(*batch_shape, n, num_vecs * left_rank)
-right_deriv_args = self.right_linear_op._quad_form_derivative(left_factor, right_factor)
+right_deriv_args = self.right_linear_op._bilinear_derivative(left_factor, right_factor)

return tuple(list(left_deriv_args) + list(right_deriv_args))

4 changes: 2 additions & 2 deletions linear_operator/operators/sum_linear_operator.py
@@ -41,9 +41,9 @@ def _mul_constant(self, other):
# We're using a custom method here - the constant mul is applied to the base_linear_ops
return self.__class__(*[lt._mul_constant(other) for lt in self.linear_ops])

-def _quad_form_derivative(self, left_vecs, right_vecs):
+def _bilinear_derivative(self, left_vecs, right_vecs):
return tuple(
-var for linear_op in self.linear_ops for var in linear_op._quad_form_derivative(left_vecs, right_vecs)
+var for linear_op in self.linear_ops for var in linear_op._bilinear_derivative(left_vecs, right_vecs)
)

def _size(self):