Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add option to choose dense solver to div #357

Merged
merged 1 commit into from
Apr 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions pylops/LinearOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np
import scipy as sp
from scipy.linalg import eigvals, lstsq, solve
from numpy.linalg import solve as np_solve
from scipy.linalg import eigvals, lstsq
from scipy.linalg import solve as sp_solve
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import LinearOperator as spLinearOperator
from scipy.sparse.linalg import eigs as sp_eigs
Expand All @@ -19,6 +21,7 @@
from scipy.sparse.linalg.interface import _ProductLinearOperator
else:
from scipy.sparse.linalg._interface import _ProductLinearOperator

from pylops.optimization.solver import cgls
from pylops.utils.backend import get_array_module, get_module, get_sparse_eye
from pylops.utils.estimators import trace_hutchinson, trace_hutchpp, trace_nahutchpp
Expand Down Expand Up @@ -260,7 +263,7 @@ def dot(self, x):
else:
raise ValueError("expected 1-d or 2-d array or matrix, got %r" % x)

def div(self, y, niter=100):
def div(self, y, niter=100, densesolver="scipy"):
r"""Solve the linear problem :math:`\mathbf{y}=\mathbf{A}\mathbf{x}`.

Overloading of operator ``/`` to improve expressivity of `Pylops`
Expand All @@ -272,17 +275,19 @@ def div(self, y, niter=100):
Data
niter : :obj:`int`, optional
Number of iterations (to be used only when ``explicit=False``)
densesolver : :obj:`str`, optional
Use scipy (``scipy``) or numpy (``numpy``) dense solver

Returns
-------
xest : :obj:`np.ndarray`
Estimated model

"""
xest = self.__truediv__(y, niter=niter)
xest = self.__truediv__(y, niter=niter, densesolver=densesolver)
return xest

def __truediv__(self, y, niter=100):
def __truediv__(self, y, niter=100, densesolver="scipy"):
if self.explicit is True:
if sp.sparse.issparse(self.A):
# use scipy solver for sparse matrices
Expand All @@ -291,7 +296,10 @@ def __truediv__(self, y, niter=100):
# use scipy solvers for dense matrices (used for backward
# compatibility, could be switched to numpy equivalents)
if self.A.shape[0] == self.A.shape[1]:
xest = solve(self.A, y)
if densesolver == "scipy":
xest = sp_solve(self.A, y)
else:
xest = np_solve(self.A, y)
else:
xest = lstsq(self.A, y)[0]
else:
Expand Down