diff --git a/clifford/_blademap.py b/clifford/_blademap.py index 72c1e4b1..dd863c72 100644 --- a/clifford/_blademap.py +++ b/clifford/_blademap.py @@ -25,8 +25,8 @@ def __init__(self, blades_map, map_scalars=True): if map_scalars: # make scalars in each algebra map - s1 = self.b1[0]._newMV(dtype=int)+1 - s2 = self.b2[0]._newMV(dtype=int)+1 + s1 = self.b1[0].of_zero(dtype=int)+1 + s2 = self.b2[0].of_zero(dtype=int)+1 self.blades_map = [(s1, s2)] + self.blades_map @property @@ -60,7 +60,7 @@ def __call__(self, A): raise ValueError('A doesnt belong to either Algebra in this Map') # create empty MV, and map values - B = to_b[0]._newMV(dtype=int) + B = to_b[0].of_zero(dtype=int) for from_obj, to_obj in zip(from_b, to_b): B += (sum(A.value*from_obj.value)*to_obj) return B diff --git a/clifford/_layout.py b/clifford/_layout.py index 7b0ce45c..2f1b4502 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -484,9 +484,7 @@ def __eq__(self, other): def parse_multivector(self, mv_string: str) -> MultiVector: """ Parses a multivector string into a MultiVector object """ - # guarded import in case the parse become heavier weight - from ._parser import parse_multivector - return parse_multivector(self, mv_string) + return MultiVector.from_string(self, mv_string) def _genTables(self): "Generate the multiplication tables." diff --git a/clifford/_multivector.py b/clifford/_multivector.py index 3cfb4d23..641839e2 100644 --- a/clifford/_multivector.py +++ b/clifford/_multivector.py @@ -1,5 +1,6 @@ import numbers import math +import types from typing import List, Set, Tuple import numpy as np @@ -9,25 +10,23 @@ from . import _settings -class MultiVector(object): - """An element of the algebra - - Parameters - ------------- - layout: instance of :class:`clifford.Layout` - The layout of the algebra - - value : sequence of length ``layout.gaDims`` - The coefficients of the base blades +class _layout_binding_classmethod(classmethod): + """ + Helper to pass the `layout` argument automatically to class methods when + called on instances. + """ + def __get__(self, instance, owner): + base = classmethod.__get__(self, instance, owner) + if instance is not None: + # methodtype binds the next unbound argument, + return types.MethodType(base, instance.layout) + else: + return base - dtype : numpy.dtype - The datatype to use for the multivector, if no - value was passed. - .. versionadded:: 1.1.0 +class MultiVector(object): + """An element of the algebra - Notes - ------ The following operators are overloaded: * ``A * B`` : geometric product @@ -40,22 +39,86 @@ class MultiVector(object): """ __array_priority__ = 100 - def __init__(self, layout, value=None, string=None, *, dtype: np.dtype = np.float64) -> None: - """Constructor.""" + @_layout_binding_classmethod + def of_zero(cls, layout, *, dtype=np.float64): + """ Construct a multivector initialized to zero with the specified dtype. + + Can be called either as ``MultiVector.of_zero(layout, ...)``, or + reusing the layout of an existing multivector as ``some_mv.of_zero(...)``. + .. versionadded:: 1.3.0 + + Parameters + ---------- + layout: instance of :class:`clifford.Layout` + The layout of the algebra + dtype : numpy.dtype + The datatype to use for the multivector, if no + value was passed. + """ + self = super().__new__(cls) self.layout = layout + self.value = np.zeros((self.layout.gaDims,), dtype=dtype) + return self - if value is None: - if string is None: - self.value = np.zeros((self.layout.gaDims,), dtype=dtype) - else: - self.value = layout.parse_multivector(string).value + @_layout_binding_classmethod + def from_string(cls, layout, string): + """ Create a multivector from a string representation. + + Can be called either as ``MultiVector.from_string(layout, ...)``, or + reusing the layout of an existing multivector as ``some_mv.from_string(...)``. + + .. versionadded:: 1.3.0 + + Parameters + ---------- + layout: instance of :class:`clifford.Layout` + The layout of the algebra + string : str + The datatype to use for the multivector, if no + value was passed. + """ + # guarded import in case the parse become heavier weight + from ._parser import parse_multivector + return parse_multivector(layout, string, cls) + + @_layout_binding_classmethod + def from_value(cls, layout, value): + """ Construct a multivector from an existing array, optionally copying. + + Can be called either as ``MultiVector.from_value(layout, ...)``, or + reusing the layout of an existing multivector as ``some_mv.from_value(...)``. + + .. versionadded:: 1.3.0 + + Parameters + ---------- + layout: instance of :class:`clifford.Layout` + The layout of the algebra. + value : sequence of length ``layout.gaDims`` + The coefficients of the base blades. + """ + if value.shape != (layout.gaDims,): + raise ValueError( + "value must be a sequence of length %s" % + layout.gaDims) + self = super().__new__(cls) + self.layout = layout + self.value = np.array(value) + return self + + + def __new__(cls, layout, *args, **kwargs) -> 'MultiVector': + """ + Shorthand for :meth:`from_value`, :meth:`from_string`, or :meth:`of_zero`. + + The appropriate function is chosen for the arguments given. """ + if args or 'value' in kwargs: + return cls.from_value(layout, *args, **kwargs) + elif 'string' in kwargs: + return cls.from_string(layout, *args, **kwargs) else: - self.value = np.array(value) - if self.value.shape != (self.layout.gaDims,): - raise ValueError( - "value must be a sequence of length %s" % - self.layout.gaDims) + return cls.of_zero(layout, *args, **kwargs) def __array__(self) -> 'cf.MVArray': # we are a scalar, and the only appropriate dtype is an object array @@ -76,7 +139,7 @@ def _checkOther(self, other, coerce=True) -> Tuple['MultiVector', bool]: elif isinstance(other, numbers.Number): if coerce: # numeric scalar - newOther = self._newMV(dtype=np.result_type(other)) + newOther = self.of_zero(dtype=np.result_type(other)) newOther[()] = other return newOther, True else: @@ -85,14 +148,6 @@ def _checkOther(self, other, coerce=True) -> Tuple['MultiVector', bool]: else: return other, False - def _newMV(self, newValue=None, *, dtype: np.dtype = None) -> 'MultiVector': - """Returns a new MultiVector (or derived class instance). - """ - if newValue is None and dtype is None: - raise TypeError("Must specify either a type or value") - - return self.__class__(self.layout, newValue, dtype=dtype) - # numeric special methods # binary @@ -136,7 +191,7 @@ def __mul__(self, other) -> 'MultiVector': newValue = other * self.value - return self._newMV(newValue) + return self.from_value(newValue) def __rmul__(self, other) -> 'MultiVector': """Right-hand geometric product, :math:`NM`""" @@ -151,7 +206,7 @@ def __rmul__(self, other) -> 'MultiVector': return other*obj newValue = other*self.value - return self._newMV(newValue) + return self.from_value(newValue) def __xor__(self, other) -> 'MultiVector': r""" Outer product, :math:`M \wedge N` """ @@ -166,7 +221,7 @@ def __xor__(self, other) -> 'MultiVector': return obj^other newValue = other*self.value - return self._newMV(newValue) + return self.from_value(newValue) def __rxor__(self, other) -> 'MultiVector': r"""Right-hand outer product, :math:`N \wedge M` """ @@ -181,7 +236,7 @@ def __rxor__(self, other) -> 'MultiVector': return other^obj newValue = other * self.value - return self._newMV(newValue) + return self.from_value(newValue) def __or__(self, other) -> 'MultiVector': r""" Inner product, :math:`M \cdot N` """ @@ -195,9 +250,9 @@ def __or__(self, other) -> 'MultiVector': obj = self.__array__() return obj|other # l * M = M * l = 0 for scalar l - return self._newMV(dtype=np.result_type(self.value.dtype, other)) + return self.of_zero(dtype=np.result_type(self.value.dtype, other)) - return self._newMV(newValue) + return self.from_value(newValue) __ror__ = __or__ @@ -214,7 +269,7 @@ def __add__(self, other) -> 'MultiVector': return obj + other newValue = self.value + other.value - return self._newMV(newValue) + return self.from_value(newValue) __radd__ = __add__ @@ -231,7 +286,7 @@ def __sub__(self, other) -> 'MultiVector': return obj - other newValue = self.value - other.value - return self._newMV(newValue) + return self.from_value(newValue) def __rsub__(self, other) -> 'MultiVector': """Right-hand subtraction @@ -246,7 +301,7 @@ def __rsub__(self, other) -> 'MultiVector': return other - obj newValue = other.value - self.value - return self._newMV(newValue) + return self.from_value(newValue) def right_complement(self) -> 'MultiVector': return self.layout.MultiVector(value=self.layout.right_complement_func(self.value)) @@ -266,7 +321,7 @@ def __truediv__(self, other) -> 'MultiVector': obj = self.__array__() return obj/other newValue = self.value / other - return self._newMV(newValue) + return self.from_value(newValue) def __rtruediv__(self, other) -> 'MultiVector': """Right-hand division, :math:`N M^{-1}`""" @@ -290,9 +345,9 @@ def __pow__(self, other) -> 'MultiVector': other = int(round(other)) if other == 0: - return self._newMV(dtype=self.value.dtype) + 1 + return self.of_zero(dtype=self.value.dtype) + 1 - newMV = self._newMV(np.array(self.value)) # copy + newMV = self.from_value(np.array(self.value)) # copy for i in range(1, other): newMV = newMV * self @@ -323,7 +378,7 @@ def __neg__(self) -> 'MultiVector': newValue = -self.value - return self._newMV(newValue) + return self.from_value(newValue) def as_array(self) -> np.ndarray: return self.value @@ -333,7 +388,7 @@ def __pos__(self) -> 'MultiVector': newValue = self.value + 0 # copy - return self._newMV(newValue) + return self.from_value(newValue) def mag2(self) -> numbers.Number: """Magnitude (modulus) squared, :math:`{|M|}^2` @@ -362,7 +417,7 @@ def adjoint(self) -> 'MultiVector': Note that ``~(N * M) == ~M * ~N``. """ # The multivector created by reversing all multiplications - return self._newMV(self.layout.adjoint_func(self.value)) + return self.from_value(self.layout.adjoint_func(self.value)) __invert__ = adjoint @@ -456,7 +511,7 @@ def __call__(self, other, *others) -> 'MultiVector': newValue = np.multiply(mask, self.value) - return self._newMV(newValue) + return self.from_value(newValue) # fundamental special methods def __str__(self) -> str: @@ -597,7 +652,7 @@ def lc(self, other) -> 'MultiVector': newValue = self.layout.lcmt_func(self.value, other.value) - return self._newMV(newValue) + return self.from_value(newValue) @property def pseudoScalar(self) -> 'MultiVector': @@ -706,7 +761,7 @@ def leftLaInv(self) -> 'MultiVector': """Return left-inverse using a computational linear algebra method proposed by Christian Perwass. """ - return self._newMV(self.layout.inv_func(self.value)) + return self.from_value(self.layout.inv_func(self.value)) def _pick_inv(self, fallback): """Internal helper to choose an appropriate inverse method. @@ -792,7 +847,7 @@ def gradeInvol(self) -> 'MultiVector': newValue = signs * self.value - return self._newMV(newValue) + return self.from_value(newValue) @property def even(self) -> 'MultiVector': @@ -852,7 +907,7 @@ def factorise(self) -> Tuple[List['MultiVector'], numbers.Number]: B_c = self/scale for ind in B_max_factors[1:]: # get the basis vector - ei = self._newMV(dtype=B_c.value.dtype) + ei = self.of_zero(dtype=B_c.value.dtype) ei[(ind,)] = 1 fi = (ei.lc(B_c)*B_c.normalInv(check=False)).normal() @@ -881,7 +936,7 @@ def basis(self) -> List['MultiVector']: if self.layout.gradeList[i] == 1: v = np.zeros((self.layout.gaDims,), dtype=float) v[i] = 1. - wholeBasis.append(self._newMV(v)) + wholeBasis.append(self.from_value(v)) thisBasis = [] # vector basis of this subspace @@ -1001,4 +1056,4 @@ def astype(self, *args, **kwargs): See `np.ndarray.astype` for argument descriptions. """ - return self._newMV(self.value.astype(*args, **kwargs)) + return self.from_value(self.value.astype(*args, **kwargs)) diff --git a/clifford/_parser.py b/clifford/_parser.py index 9064075b..6b425860 100644 --- a/clifford/_parser.py +++ b/clifford/_parser.py @@ -69,9 +69,9 @@ def _tokenize(layout: Layout, mv_string: str): ] -def parse_multivector(layout: Layout, mv_string: str) -> MultiVector: +def parse_multivector(layout: Layout, mv_string: str, mv_class=MultiVector) -> MultiVector: # Create a multivector - mv_out = MultiVector(layout) + mv_out = mv_class.of_zero(layout) # parser state sign = None diff --git a/clifford/test/test_clifford.py b/clifford/test/test_clifford.py index bd38ca50..6516eb05 100644 --- a/clifford/test/test_clifford.py +++ b/clifford/test/test_clifford.py @@ -477,6 +477,14 @@ def test_meet(self, g3): assert equivalent_up_to_scale((b^c).meet(a), 1) assert equivalent_up_to_scale((a).meet(a), a) + def test_multivector_constructor(self, g3): + # either is ok + g3.MultiVector(string='e1') + g3.MultiVector(value=np.zeros(8)) + # both together is illegal + with pytest.raises(TypeError): + g3.MultiVector(string='e1', value=np.zeros(8)) + class TestBasicConformal41: def test_metric(self, g4_1): diff --git a/clifford/tools/g3c/cuda.py b/clifford/tools/g3c/cuda.py index 1329767a..81df851b 100644 --- a/clifford/tools/g3c/cuda.py +++ b/clifford/tools/g3c/cuda.py @@ -45,7 +45,7 @@ def sequential_rotor_estimation_chunks_mvs(reference_model_list, query_model_lis query_model_array = np.array([l.value for l in query_model_list]) reference_model_array = np.array([l.value for l in reference_model_list]) output, cost_array = sequential_rotor_estimation_chunks(reference_model_array, query_model_array, n_samples, n_objects_per_sample, mutation_probability=mutation_probability) - output_mvs = [query_model_list[0]._newMV(output[i, :]) for i in range(output.shape[0])] + output_mvs = [query_model_list[0].from_value(output[i, :]) for i in range(output.shape[0])] return output_mvs, cost_array @@ -203,7 +203,7 @@ def sequential_rotor_estimation_cuda_mvs(reference_model_list, query_model_list, query_model_array = np.array([l.value for l in query_model_list]) reference_model_array = np.array([l.value for l in reference_model_list]) output, cost_array = sequential_rotor_estimation_cuda(reference_model_array, query_model_array, n_samples, n_objects_per_sample, mutation_probability=mutation_probability) - output_mvs = [query_model_list[0]._newMV(output[i, :]) for i in range(output.shape[0])] + output_mvs = [query_model_list[0].from_value(output[i, :]) for i in range(output.shape[0])] return output_mvs, cost_array