diff --git a/ivy/functional/frontends/torch/__init__.py b/ivy/functional/frontends/torch/__init__.py index 721fe7bd6214..eb05e3c7ce16 100644 --- a/ivy/functional/frontends/torch/__init__.py +++ b/ivy/functional/frontends/torch/__init__.py @@ -309,6 +309,7 @@ def promote_types_of_torch_inputs( from . import utilities from .utilities import * from . import linalg +from .linalg import lu from . import func from .func import * diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index ed3ca7978e59..34bbedd006e5 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -138,22 +138,34 @@ def inv_ex(A, *, check_errors=False, out=None): {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" ) def lu_factor(A, *, pivot=True, out=None): - return ivy.lu_factor(A, pivot=pivot, out=out) + LU, pivots = ivy.lu_factor(A, pivot=pivot, out=out) + lu_factor_tuple = namedtuple("linalg_lu_factor", ["LU", "pivots"]) + return lu_factor_tuple( + LU=LU, + pivots=pivots, + ) @to_ivy_arrays_and_back def lu_factor_ex(A, *, pivot=True, check_errors=False, out=None): try: - LU = ivy.lu_factor(A, pivot=pivot, out=out) + # Perform LU factorization and get the result as a named tuple + LU, pivots = ivy.lu_factor(A, pivot=pivot, out=out) info = ivy.zeros(A.shape[:-2], dtype=ivy.int32) - return LU, info except RuntimeError as e: if check_errors: raise RuntimeError(e) from e else: - matrix = A * math.nan + # If there's an error and check_errors is False, handle the error + LU = ivy.full_like(A, math.nan) + pivots = ivy.full_like(A.shape[:-1], math.nan) info = ivy.ones(A.shape[:-2], dtype=ivy.int32) - return matrix, info + + # Create a named tuple for the final result + lu_factor_ex_tuple = namedtuple("linalg_lu_factor_ex", ["LU", "pivots", "info"]) + + # Return the results + return lu_factor_ex_tuple(LU=LU, pivots=pivots, info=info) def lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None): @@ -436,3 +448,6 @@ def vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None): return ivy.vector_norm( input, axis=dim, keepdims=keepdim, ord=ord, out=out, dtype=dtype ) + + +lu = lu_factor