Skip to content

Commit

Permalink
Rename *_gmt_matrix to *_mt_matrix for cases not specific to the …
Browse files Browse the repository at this point in the history
…geometric product (#419)

Co-authored-by: Hugo Hadfield <hadfield.hugo@gmail.com>
  • Loading branch information
eric-wieser and hugohadfield authored Jan 20, 2022
1 parent 8b1428c commit 1102eba
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def grades_present(objin: 'MultiVector', threshold=0.0000001) -> Set[int]:

# todo: work out how to let numba use the COO objects directly
@_numba_utils.njit
def _numba_val_get_left_gmt_matrix(x, k_list, l_list, m_list, mult_table_vals, ndims):
def _numba_val_get_left_mt_matrix(x, k_list, l_list, m_list, mult_table_vals, ndims):
# TODO: consider `dtype=result_type(x.dtype, mult_table_vals.dtype)`
intermed = np.zeros((ndims, ndims), dtype=x.dtype)
test_ind = 0
Expand All @@ -267,24 +267,24 @@ def _numba_val_get_left_gmt_matrix(x, k_list, l_list, m_list, mult_table_vals, n
return intermed


def val_get_left_gmt_matrix(mt: sparse.COO, x):
def val_get_left_mt_matrix(mt: sparse.COO, x):
"""
This produces the matrix X that performs left multiplication with x
eg. X@b == (x*b).value
"""
dims = mt.shape[1]
k_list, l_list, m_list = mt.coords
return _numba_val_get_left_gmt_matrix(
return _numba_val_get_left_mt_matrix(
x, k_list, l_list, m_list, mt.data, dims
)


def val_get_right_gmt_matrix(mt: sparse.COO, x):
def val_get_right_mt_matrix(mt: sparse.COO, x):
"""
This produces the matrix X that performs right multiplication with x
eg. X@b == (b*x).value
"""
return val_get_left_gmt_matrix(mt.T, x)
return val_get_left_mt_matrix(mt.T, x)


# TODO: Move this to the top once we remove circular imports
Expand Down
12 changes: 6 additions & 6 deletions clifford/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import clifford as cf
from . import (
get_mult_function,
val_get_left_gmt_matrix,
val_get_right_gmt_matrix,
_numba_val_get_left_gmt_matrix,
val_get_left_mt_matrix,
val_get_right_mt_matrix,
_numba_val_get_left_mt_matrix,
NUMBA_PARALLEL
)
from . import _numba_utils
Expand Down Expand Up @@ -684,7 +684,7 @@ def inv_func(self):

@_numba_utils.njit
def leftLaInvJIT(value):
intermed = _numba_val_get_left_gmt_matrix(value, k_list, l_list, m_list, mult_table_vals, n_dims)
intermed = _numba_val_get_left_mt_matrix(value, k_list, l_list, m_list, mult_table_vals, n_dims)
if abs(np.linalg.det(intermed)) < _settings._eps:
raise ValueError("multivector has no left-inverse")
sol = np.linalg.solve(intermed, identity.astype(intermed.dtype))
Expand All @@ -697,14 +697,14 @@ def get_left_gmt_matrix(self, x):
This produces the matrix X that performs left multiplication with x
eg. ``X@b == (x*b).value``
"""
return val_get_left_gmt_matrix(self.gmt, x.value)
return val_get_left_mt_matrix(self.gmt, x.value)

def get_right_gmt_matrix(self, x):
"""
This produces the matrix X that performs right multiplication with x
eg. ``X@b == (b*x).value``
"""
return val_get_right_gmt_matrix(self.gmt, x.value)
return val_get_right_mt_matrix(self.gmt, x.value)

def load_ga_file(self, filename: str) -> 'cf.MVArray':
"""
Expand Down
4 changes: 2 additions & 2 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,11 @@ def left_gmt_generator(mt=layout.gmt):
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data
gaDims = mt.shape[1]
val_get_left_gmt_matrix = cf._numba_val_get_left_gmt_matrix
val_get_left_mt_matrix = cf._numba_val_get_left_mt_matrix

@numba.njit
def get_left_gmt(x_val):
return val_get_left_gmt_matrix(
return val_get_left_mt_matrix(
x_val, k_list, l_list, m_list, mult_table_vals, gaDims)
return get_left_gmt

Expand Down

0 comments on commit 1102eba

Please sign in to comment.