Skip to content

Commit

Permalink
Add abstractions for solving linear systems
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 4, 2023
1 parent 7644a2b commit b666902
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ def vertcat(*x) -> "CasadiLike":
y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x]
return CasadiLike(cs.vertcat(*y))

@staticmethod
def solve(A: "CasadiLike", b: "CasadiLike") -> "CasadiLike":
"""
Args:
A (CasadiLike): matrix
b (CasadiLike): vector
Returns:
CasadiLike: solution of A*x=b
"""
return CasadiLike(cs.solve(A.array, b.array))


if __name__ == "__main__":
math = SpatialMath()
Expand Down
12 changes: 12 additions & 0 deletions src/adam/core/spatial_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ def cos(x: npt.ArrayLike) -> npt.ArrayLike:
def skew(x):
pass

@abc.abstractmethod
def solve(A: npt.ArrayLike, b: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
A (npt.ArrayLike): matrix
b (npt.ArrayLike): vector
Returns:
npt.ArrayLike: solution of the linear system Ax=b
"""
pass

def R_from_axis_angle(self, axis: npt.ArrayLike, q: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
Expand Down
12 changes: 12 additions & 0 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,15 @@ def vertcat(*x) -> "JaxLike":
else:
v = jnp.vstack([x[i] for i in range(len(x))]).reshape(-1, 1)
return JaxLike(v)

@staticmethod
def solve(A: "JaxLike", b: "JaxLike") -> "JaxLike":
"""
Args:
A (JaxLike): Matrix
b (JaxLike): Vector
Returns:
JaxLike: Solution of Ax=b
"""
return JaxLike(jnp.linalg.solve(A.array, b.array))
12 changes: 12 additions & 0 deletions src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,15 @@ def skew(x: Union["NumpyLike", npt.ArrayLike]) -> "NumpyLike":
return -np.cross(np.array(x), np.eye(3), axisa=0, axisb=0)
x = x.array
return NumpyLike(-np.cross(np.array(x), np.eye(3), axisa=0, axisb=0))

@staticmethod
def solve(A: "NumpyLike", b: "NumpyLike") -> "NumpyLike":
"""
Args:
A (NumpyLike): matrix
b (NumpyLike): vector
Returns:
NumpyLike: solution of Ax=b
"""
return NumpyLike(np.linalg.solve(A.array, b.array))
12 changes: 12 additions & 0 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,15 @@ def vertcat(*x: ntp.ArrayLike) -> "TorchLike":
else:
v = torch.FloatTensor(x).reshape(-1, 1)
return TorchLike(v)

@staticmethod
def solve(A: "TorchLike", b: "TorchLike") -> "TorchLike":
"""
Args:
A (TorchLike): matrix
b (TorchLike): vector
Returns:
TorchLike: solution of Ax = b
"""
return TorchLike(torch.linalg.solve(A.array, b.array))

0 comments on commit b666902

Please sign in to comment.