Skip to content

Commit

Permalink
fix: Fix frontends for torch.lu_factor, torch.lu_factor_ex and torch.lu
Browse files Browse the repository at this point in the history
  • Loading branch information
hmahmood24 committed Aug 31, 2024
1 parent 759ef5d commit 6aca19c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions ivy/functional/frontends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def promote_types_of_torch_inputs(
from . import utilities
from .utilities import *
from . import linalg
from .linalg import lu
from . import func
from .func import *

Expand Down
25 changes: 20 additions & 5 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,34 @@ def inv_ex(A, *, check_errors=False, out=None):
{"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def lu_factor(A, *, pivot=True, out=None):
return ivy.lu_factor(A, pivot=pivot, out=out)
LU, pivots = ivy.lu_factor(A, pivot=pivot, out=out)
lu_factor_tuple = namedtuple("linalg_lu_factor", ["LU", "pivots"])
return lu_factor_tuple(
LU=LU,
pivots=pivots,
)


@to_ivy_arrays_and_back
def lu_factor_ex(A, *, pivot=True, check_errors=False, out=None):
try:
LU = ivy.lu_factor(A, pivot=pivot, out=out)
# Perform LU factorization and get the result as a named tuple
LU, pivots = ivy.lu_factor(A, pivot=pivot, out=out)
info = ivy.zeros(A.shape[:-2], dtype=ivy.int32)
return LU, info
except RuntimeError as e:
if check_errors:
raise RuntimeError(e) from e
else:
matrix = A * math.nan
# If there's an error and check_errors is False, handle the error
LU = ivy.full_like(A, math.nan)
pivots = ivy.full_like(A.shape[:-1], math.nan)
info = ivy.ones(A.shape[:-2], dtype=ivy.int32)
return matrix, info

# Create a named tuple for the final result
lu_factor_ex_tuple = namedtuple("linalg_lu_factor_ex", ["LU", "pivots", "info"])

# Return the results
return lu_factor_ex_tuple(LU=LU, pivots=pivots, info=info)


def lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None):
Expand Down Expand Up @@ -436,3 +448,6 @@ def vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None):
return ivy.vector_norm(
input, axis=dim, keepdims=keepdim, ord=ord, out=out, dtype=dtype
)


lu = lu_factor

0 comments on commit 6aca19c

Please sign in to comment.