Skip to content

Commit

Permalink
Merge pull request #357 from mrava87/npsolve_v1
Browse files Browse the repository at this point in the history
feature: add option to choose dense solver to div
  • Loading branch information
cako authored Apr 2, 2022
2 parents c8e1155 + 04030db commit 0a4c365
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions pylops/LinearOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np
import scipy as sp
from scipy.linalg import eigvals, lstsq, solve
from numpy.linalg import solve as np_solve
from scipy.linalg import eigvals, lstsq
from scipy.linalg import solve as sp_solve
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import LinearOperator as spLinearOperator
from scipy.sparse.linalg import eigs as sp_eigs
Expand All @@ -19,6 +21,7 @@
from scipy.sparse.linalg.interface import _ProductLinearOperator
else:
from scipy.sparse.linalg._interface import _ProductLinearOperator

from pylops.optimization.solver import cgls
from pylops.utils.backend import get_array_module, get_module, get_sparse_eye
from pylops.utils.estimators import trace_hutchinson, trace_hutchpp, trace_nahutchpp
Expand Down Expand Up @@ -260,7 +263,7 @@ def dot(self, x):
else:
raise ValueError("expected 1-d or 2-d array or matrix, got %r" % x)

def div(self, y, niter=100):
def div(self, y, niter=100, densesolver="scipy"):
r"""Solve the linear problem :math:`\mathbf{y}=\mathbf{A}\mathbf{x}`.
Overloading of operator ``/`` to improve expressivity of `Pylops`
Expand All @@ -272,17 +275,19 @@ def div(self, y, niter=100):
Data
niter : :obj:`int`, optional
Number of iterations (to be used only when ``explicit=False``)
densesolver : :obj:`str`, optional
Use scipy (``scipy``) or numpy (``numpy``) dense solver
Returns
-------
xest : :obj:`np.ndarray`
Estimated model
"""
xest = self.__truediv__(y, niter=niter)
xest = self.__truediv__(y, niter=niter, densesolver=densesolver)
return xest

def __truediv__(self, y, niter=100):
def __truediv__(self, y, niter=100, densesolver="scipy"):
if self.explicit is True:
if sp.sparse.issparse(self.A):
# use scipy solver for sparse matrices
Expand All @@ -291,7 +296,10 @@ def __truediv__(self, y, niter=100):
# use scipy solvers for dense matrices (used for backward
# compatibility, could be switched to numpy equivalents)
if self.A.shape[0] == self.A.shape[1]:
xest = solve(self.A, y)
if densesolver == "scipy":
xest = sp_solve(self.A, y)
else:
xest = np_solve(self.A, y)
else:
xest = lstsq(self.A, y)[0]
else:
Expand Down

0 comments on commit 0a4c365

Please sign in to comment.