Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.37】为 Paddle 新增 householder_product API -part #58214

Merged
merged 21 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
eigh,
eigvals,
eigvalsh,
householder_product,
lstsq,
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
lu,
lu_unpack,
Expand Down Expand Up @@ -53,6 +54,7 @@
'matrix_rank',
'svd',
'qr',
'householder_product',
'pca_lowrank',
'lu',
'lu_unpack',
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
zeros,
zeros_like,
)

from .einsum import einsum # noqa: F401
from .linalg import ( # noqa: F401
bincount,
Expand All @@ -78,6 +79,7 @@
eigvals,
eigvalsh,
histogram,
householder_product,
lstsq,
lu,
lu_unpack,
Expand Down Expand Up @@ -435,6 +437,7 @@
'mv',
'matrix_power',
'qr',
'householder_product',
'pca_lowrank',
'eigvals',
'eigvalsh',
Expand Down
130 changes: 130 additions & 0 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3739,3 +3739,133 @@ def cdist(
return paddle.linalg.norm(
x[..., None, :] - y[..., None, :, :], p=p, axis=-1
)


def householder_product(x, tau, name=None):
r"""

Computes the first n columns of a product of Householder matrices.

This function can get the vector :math:`\omega_{i}` from matrix `x` (m x n), the :math:`i-1` elements are zeros, and the i-th is `1`, the rest of the elements are from i-th column of `x`.
And with the vector `tau` can calculate the first n columns of a product of Householder matrices.

:math:`H_i = I_m - \tau_i \omega_i \omega_i^H`

Args:
x (Tensor): A tensor with shape (*, m, n) where * is zero or more batch dimensions.
tau (Tensor): A tensor with shape (*, k) where * is zero or more batch dimensions.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Tensor, the dtype is same as input tensor, the Q in QR decomposition.

:math:`out = Q = H_1H_2H_3...H_k`

Examples:
.. code-block:: python

>>> import paddle
>>> x = paddle.to_tensor([[-1.1280, 0.9012, -0.0190],
... [ 0.3699, 2.2133, -1.4792],
... [ 0.0308, 0.3361, -3.1761],
... [-0.0726, 0.8245, -0.3812]])
>>> tau = paddle.to_tensor([1.7497, 1.1156, 1.7462])
>>> Q = paddle.linalg.householder_product(x, tau)
>>> print(Q)
Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[-0.74969995, -0.02181768, 0.31115776],
[-0.64721400, -0.12367040, -0.21738708],
[-0.05389076, -0.37562513, -0.84836429],
[ 0.12702821, -0.91822827, 0.36892807]])
"""

check_dtype(
x.dtype,
'x',
[
'float32',
'float64',
'complex64',
'complex128',
],
'householder_product',
)
check_dtype(
tau.dtype,
'tau',
[
'float32',
'float64',
'complex64',
'complex128',
],
'householder_product',
)
assert (
x.dtype == tau.dtype
), "The input x must have the same dtype with input tau.\n"
assert (
len(x.shape) >= 2
and len(tau.shape) >= 1
and len(x.shape) == len(tau.shape) + 1
), (
"The input x must have more than 2 dimensions, and input tau must have more than 1 dimension,"
"and the dimension of x is 1 larger than the dimension of tau\n"
)
assert (
x.shape[-2] >= x.shape[-1]
), "The rows of input x must be greater than or equal to the columns of input x.\n"
assert (
x.shape[-1] >= tau.shape[-1]
), "The last dim of x must be greater than tau.\n"
for idx, _ in enumerate(x.shape[:-2]):
assert (
x.shape[idx] == tau.shape[idx]
), "The input x must have the same batch dimensions with input tau.\n"

def _householder_product(x, tau):
m, n = x.shape[-2:]
k = tau.shape[-1]
Q = paddle.eye(m).astype(x.dtype)
for i in range(min(k, n)):
w = x[i:, i]
if in_dynamic_mode():
w[0] = 1
else:
w = paddle.static.setitem(w, 0, 1)
w = w.reshape([-1, 1])
if in_dynamic_mode():
if x.dtype in [paddle.complex128, paddle.complex64]:
Q[:, i:] = Q[:, i:] - (
Q[:, i:] @ w @ paddle.conj(w).T * tau[i]
)
else:
Q[:, i:] = Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i])
else:
Q = paddle.static.setitem(
Q,
(slice(None), slice(i, None)),
Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i])
if x.dtype in [paddle.complex128, paddle.complex64]
else Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]),
)
return Q[:, :n]

if len(x.shape) == 2:
return _householder_product(x, tau)
m, n = x.shape[-2:]
org_x_shape = x.shape
org_tau_shape = tau.shape
x = x.reshape((-1, org_x_shape[-2], org_x_shape[-1]))
tau = tau.reshape((-1, org_tau_shape[-1]))
n_batch = x.shape[0]
out = paddle.zeros([n_batch, m, n], dtype=x.dtype)
for i in range(n_batch):
if in_dynamic_mode():
out[i] = _householder_product(x[i], tau[i])
else:
out = paddle.static.setitem(
out, i, _householder_product(x[i], tau[i])
)
out = out.reshape(org_x_shape)
return out
Loading