From b6669024466dde520ea9bfc0f5f648f5779af3b9 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 4 Dec 2023 09:29:12 +0100 Subject: [PATCH] Add abstractions for solving linear systems --- src/adam/casadi/casadi_like.py | 12 ++++++++++++ src/adam/core/spatial_math.py | 12 ++++++++++++ src/adam/jax/jax_like.py | 12 ++++++++++++ src/adam/numpy/numpy_like.py | 12 ++++++++++++ src/adam/pytorch/torch_like.py | 12 ++++++++++++ 5 files changed, 60 insertions(+) diff --git a/src/adam/casadi/casadi_like.py b/src/adam/casadi/casadi_like.py index ebd9f3fe..ffaa24ac 100644 --- a/src/adam/casadi/casadi_like.py +++ b/src/adam/casadi/casadi_like.py @@ -196,6 +196,18 @@ def vertcat(*x) -> "CasadiLike": y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x] return CasadiLike(cs.vertcat(*y)) + @staticmethod + def solve(A: "CasadiLike", b: "CasadiLike") -> "CasadiLike": + """ + Args: + A (CasadiLike): matrix + b (CasadiLike): vector + + Returns: + CasadiLike: solution of A*x=b + """ + return CasadiLike(cs.solve(A.array, b.array)) + if __name__ == "__main__": math = SpatialMath() diff --git a/src/adam/core/spatial_math.py b/src/adam/core/spatial_math.py index d74bc6d0..72db78d9 100644 --- a/src/adam/core/spatial_math.py +++ b/src/adam/core/spatial_math.py @@ -143,6 +143,18 @@ def cos(x: npt.ArrayLike) -> npt.ArrayLike: def skew(x): pass + @abc.abstractmethod + def solve(A: npt.ArrayLike, b: npt.ArrayLike) -> npt.ArrayLike: + """ + Args: + A (npt.ArrayLike): matrix + b (npt.ArrayLike): vector + + Returns: + npt.ArrayLike: solution of the linear system Ax=b + """ + pass + def R_from_axis_angle(self, axis: npt.ArrayLike, q: npt.ArrayLike) -> npt.ArrayLike: """ Args: diff --git a/src/adam/jax/jax_like.py b/src/adam/jax/jax_like.py index 83df661f..ceb5a123 100644 --- a/src/adam/jax/jax_like.py +++ b/src/adam/jax/jax_like.py @@ -199,3 +199,15 @@ def vertcat(*x) -> "JaxLike": else: v = jnp.vstack([x[i] for i in range(len(x))]).reshape(-1, 1) return JaxLike(v) + + @staticmethod + def solve(A: "JaxLike", b: "JaxLike") -> "JaxLike": + """ + Args: + A (JaxLike): Matrix + b (JaxLike): Vector + + Returns: + JaxLike: Solution of Ax=b + """ + return JaxLike(jnp.linalg.solve(A.array, b.array)) diff --git a/src/adam/numpy/numpy_like.py b/src/adam/numpy/numpy_like.py index 85ee2f78..e070c0da 100644 --- a/src/adam/numpy/numpy_like.py +++ b/src/adam/numpy/numpy_like.py @@ -201,3 +201,15 @@ def skew(x: Union["NumpyLike", npt.ArrayLike]) -> "NumpyLike": return -np.cross(np.array(x), np.eye(3), axisa=0, axisb=0) x = x.array return NumpyLike(-np.cross(np.array(x), np.eye(3), axisa=0, axisb=0)) + + @staticmethod + def solve(A: "NumpyLike", b: "NumpyLike") -> "NumpyLike": + """ + Args: + A (NumpyLike): matrix + b (NumpyLike): vector + + Returns: + NumpyLike: solution of Ax=b + """ + return NumpyLike(np.linalg.solve(A.array, b.array)) diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index a1afde00..ca5ccb86 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -213,3 +213,15 @@ def vertcat(*x: ntp.ArrayLike) -> "TorchLike": else: v = torch.FloatTensor(x).reshape(-1, 1) return TorchLike(v) + + @staticmethod + def solve(A: "TorchLike", b: "TorchLike") -> "TorchLike": + """ + Args: + A (TorchLike): matrix + b (TorchLike): vector + + Returns: + TorchLike: solution of Ax = b + """ + return TorchLike(torch.linalg.solve(A.array, b.array))