Skip to content

Commit

Permalink
Fix type interaction with other Rational classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
scoder committed Nov 28, 2024
1 parent 4a3e8a0 commit e965299
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ ChangeLog
* Using ``complex`` numbers in division shows better tracebacks.
https://github.com/python/cpython/pull/102842

* Mixed calculations with other ``Rational`` classes could return the wrong type.
https://github.com/python/cpython/issues/119189


1.18 (2024-04-03)
-----------------
Expand Down
2 changes: 1 addition & 1 deletion src/quicktions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ cdef forward(a, b, math_func monomorphic_operator, pyoperator, handle_complex=Tr
return monomorphic_operator(an, ad, (<Fraction>b)._numerator, (<Fraction>b)._denominator)
elif isinstance(b, int):
return monomorphic_operator(an, ad, b, 1)
elif isinstance(b, (Fraction, Rational)):
elif isinstance(b, Fraction):
return monomorphic_operator(an, ad, b.numerator, b.denominator)
elif isinstance(b, float):
return pyoperator(_as_float(an, ad), b)
Expand Down
295 changes: 293 additions & 2 deletions src/test_fractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __float__(self):
assert False, "__float__ should not be invoked"


class DummyFraction(fractions.Fraction):
class DummyFraction(quicktions.Fraction):
"""Dummy Fraction subclass for copy and deepcopy testing."""


Expand Down Expand Up @@ -200,6 +200,197 @@ def test_quicktions_limits(self):
def _components(r):
return (r.numerator, r.denominator)

def typed_approx_eq(a, b):
return type(a) == type(b) and (a == b or math.isclose(a, b))

class Symbolic:
"""Simple non-numeric class for testing mixed arithmetic.
It is not Integral, Rational, Real or Complex, and cannot be conveted
to int, float or complex. but it supports some arithmetic operations.
"""
def __init__(self, value):
self.value = value
def __mul__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(f'{self} * {other}')
def __rmul__(self, other):
return self.__class__(f'{other} * {self}')
def __truediv__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(f'{self} / {other}')
def __rtruediv__(self, other):
return self.__class__(f'{other} / {self}')
def __mod__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(f'{self} % {other}')
def __rmod__(self, other):
return self.__class__(f'{other} % {self}')
def __pow__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(f'{self} ** {other}')
def __rpow__(self, other):
return self.__class__(f'{other} ** {self}')
def __eq__(self, other):
if other.__class__ != self.__class__:
return NotImplemented
return self.value == other.value
def __str__(self):
return f'{self.value}'
def __repr__(self):
return f'{self.__class__.__name__}({self.value!r})'

class SymbolicReal(Symbolic):
pass
numbers.Real.register(SymbolicReal)

class SymbolicComplex(Symbolic):
pass
numbers.Complex.register(SymbolicComplex)

class Rat:
"""Simple Rational class for testing mixed arithmetic."""
def __init__(self, n, d):
self.numerator = n
self.denominator = d
def __mul__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.numerator * other.numerator,
self.denominator * other.denominator)
def __rmul__(self, other):
return self.__class__(other.numerator * self.numerator,
other.denominator * self.denominator)
def __truediv__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.numerator * other.denominator,
self.denominator * other.numerator)
def __rtruediv__(self, other):
return self.__class__(other.numerator * self.denominator,
other.denominator * self.numerator)
def __mod__(self, other):
if isinstance(other, F):
return NotImplemented
d = self.denominator * other.numerator
return self.__class__(self.numerator * other.denominator % d, d)
def __rmod__(self, other):
d = other.denominator * self.numerator
return self.__class__(other.numerator * self.denominator % d, d)

return self.__class__(other.numerator / self.numerator,
other.denominator / self.denominator)
def __pow__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.numerator ** other,
self.denominator ** other)
def __float__(self):
return self.numerator / self.denominator
def __eq__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return (typed_approx_eq(self.numerator, other.numerator) and
typed_approx_eq(self.denominator, other.denominator))
def __repr__(self):
return f'{self.__class__.__name__}({self.numerator!r}, {self.denominator!r})'
numbers.Rational.register(Rat)

class Root:
"""Simple Real class for testing mixed arithmetic."""
def __init__(self, v, n=F(2)):
self.base = v
self.degree = n
def __mul__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.base * other**self.degree, self.degree)
def __rmul__(self, other):
return self.__class__(other**self.degree * self.base, self.degree)
def __truediv__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.base / other**self.degree, self.degree)
def __rtruediv__(self, other):
return self.__class__(other**self.degree / self.base, self.degree)
def __pow__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.base, self.degree / other)
def __float__(self):
return float(self.base) ** (1 / float(self.degree))
def __eq__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return typed_approx_eq(self.base, other.base) and typed_approx_eq(self.degree, other.degree)
def __repr__(self):
return f'{self.__class__.__name__}({self.base!r}, {self.degree!r})'
numbers.Real.register(Root)

class Polar:
"""Simple Complex class for testing mixed arithmetic."""
def __init__(self, r, phi):
self.r = r
self.phi = phi
def __mul__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.r * other, self.phi)
def __rmul__(self, other):
return self.__class__(other * self.r, self.phi)
def __truediv__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.r / other, self.phi)
def __rtruediv__(self, other):
return self.__class__(other / self.r, -self.phi)
def __pow__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.r ** other, self.phi * other)
def __eq__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return typed_approx_eq(self.r, other.r) and typed_approx_eq(self.phi, other.phi)
def __repr__(self):
return f'{self.__class__.__name__}({self.r!r}, {self.phi!r})'
numbers.Complex.register(Polar)

class Rect:
"""Other simple Complex class for testing mixed arithmetic."""
def __init__(self, x, y):
self.x = x
self.y = y
def __mul__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.x * other, self.y * other)
def __rmul__(self, other):
return self.__class__(other * self.x, other * self.y)
def __truediv__(self, other):
if isinstance(other, F):
return NotImplemented
return self.__class__(self.x / other, self.y / other)
def __rtruediv__(self, other):
r = self.x * self.x + self.y * self.y
return self.__class__(other * (self.x / r), other * (self.y / r))
def __rpow__(self, other):
return Polar(other ** self.x, math.log(other) * self.y)
def __complex__(self):
return complex(self.x, self.y)
def __eq__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return typed_approx_eq(self.x, other.x) and typed_approx_eq(self.y, other.y)
def __repr__(self):
return f'{self.__class__.__name__}({self.x!r}, {self.y!r})'
numbers.Complex.register(Rect)

class RectComplex(Rect, complex):
pass

class FractionTest(unittest.TestCase):

Expand Down Expand Up @@ -795,20 +986,66 @@ def testMixedArithmetic(self):
self.assertTypedEquals(0.9, 1.0 - F(1, 10))
self.assertTypedEquals(0.9 + 0j, (1.0 + 0j) - F(1, 10))

def testMixedMultiplication(self):
self.assertTypedEquals(F(1, 10), F(1, 10) * 1)
self.assertTypedEquals(0.1, F(1, 10) * 1.0)
self.assertTypedEquals(0.1 + 0j, F(1, 10) * (1.0 + 0j))
self.assertTypedEquals(F(1, 10), 1 * F(1, 10))
self.assertTypedEquals(0.1, 1.0 * F(1, 10))
self.assertTypedEquals(0.1 + 0j, (1.0 + 0j) * F(1, 10))

self.assertTypedEquals(F(3, 2) * DummyFraction(5, 3), F(5, 2))
self.assertTypedEquals(DummyFraction(5, 3) * F(3, 2), F(5, 2))
self.assertTypedEquals(F(3, 2) * Rat(5, 3), Rat(15, 6))
self.assertTypedEquals(Rat(5, 3) * F(3, 2), F(5, 2))

self.assertTypedEquals(F(3, 2) * Root(4), Root(F(9, 1)))
self.assertTypedEquals(Root(4) * F(3, 2), 3.0)
self.assertEqual(F(3, 2) * SymbolicReal('X'), SymbolicReal('3/2 * X'))
self.assertRaises(TypeError, operator.mul, SymbolicReal('X'), F(3, 2))

self.assertTypedEquals(F(3, 2) * Polar(4, 2), Polar(F(6, 1), 2))
self.assertTypedEquals(F(3, 2) * Polar(4.0, 2), Polar(6.0, 2))
self.assertTypedEquals(F(3, 2) * Rect(4, 3), Rect(F(6, 1), F(9, 2)))
self.assertTypedEquals(F(3, 2) * RectComplex(4, 3), RectComplex(6.0+0j, 4.5+0j))
self.assertRaises(TypeError, operator.mul, Polar(4, 2), F(3, 2))
self.assertTypedEquals(Rect(4, 3) * F(3, 2), 6.0 + 4.5j)
self.assertEqual(F(3, 2) * SymbolicComplex('X'), SymbolicComplex('3/2 * X'))
self.assertRaises(TypeError, operator.mul, SymbolicComplex('X'), F(3, 2))

self.assertEqual(F(3, 2) * Symbolic('X'), Symbolic('3/2 * X'))
self.assertRaises(TypeError, operator.mul, Symbolic('X'), F(3, 2))

def testMixedDivision(self):
self.assertTypedEquals(F(1, 10), F(1, 10) / 1)
self.assertTypedEquals(0.1, F(1, 10) / 1.0)
self.assertTypedEquals(0.1 + 0j, F(1, 10) / (1.0 + 0j))
self.assertTypedEquals(F(10, 1), 1 / F(1, 10))
self.assertTypedEquals(10.0, 1.0 / F(1, 10))
self.assertTypedEquals(10.0 + 0j, (1.0 + 0j) / F(1, 10))

self.assertTypedEquals(F(3, 2) / DummyFraction(3, 5), F(5, 2))
self.assertTypedEquals(DummyFraction(5, 3) / F(2, 3), F(5, 2))
self.assertTypedEquals(F(3, 2) / Rat(3, 5), Rat(15, 6))
self.assertTypedEquals(Rat(5, 3) / F(2, 3), F(5, 2))

self.assertTypedEquals(F(2, 3) / Root(4), Root(F(1, 9)))
self.assertTypedEquals(Root(4) / F(2, 3), 3.0)
self.assertEqual(F(3, 2) / SymbolicReal('X'), SymbolicReal('3/2 / X'))
self.assertRaises(TypeError, operator.truediv, SymbolicReal('X'), F(3, 2))

self.assertTypedEquals(F(3, 2) / Polar(4, 2), Polar(F(3, 8), -2))
self.assertTypedEquals(F(3, 2) / Polar(4.0, 2), Polar(0.375, -2))
self.assertTypedEquals(F(3, 2) / Rect(4, 3), Rect(0.24, 0.18))
self.assertRaises(TypeError, operator.truediv, Polar(4, 2), F(2, 3))
self.assertTypedEquals(Rect(4, 3) / F(2, 3), 6.0 + 4.5j)
self.assertEqual(F(3, 2) / SymbolicComplex('X'), SymbolicComplex('3/2 / X'))
self.assertRaises(TypeError, operator.truediv, SymbolicComplex('X'), F(3, 2))

self.assertEqual(F(3, 2) / Symbolic('X'), Symbolic('3/2 / X'))
self.assertRaises(TypeError, operator.truediv, Symbolic('X'), F(2, 3))

def testMixedIntegerDivision(self):
self.assertTypedEquals(0, F(1, 10) // 1)
self.assertTypedEquals(0.0, F(1, 10) // 1.0)
self.assertTypedEquals(10, 1 // F(1, 10))
Expand All @@ -835,6 +1072,26 @@ def testMixedArithmetic(self):
self.assertTypedTupleEquals(divmod(-0.1, float('inf')), divmod(F(-1, 10), float('inf')))
self.assertTypedTupleEquals(divmod(-0.1, float('-inf')), divmod(F(-1, 10), float('-inf')))

self.assertTypedEquals(F(3, 2) % DummyFraction(3, 5), F(3, 10))
self.assertTypedEquals(DummyFraction(5, 3) % F(2, 3), F(1, 3))
self.assertTypedEquals(F(3, 2) % Rat(3, 5), Rat(3, 6))
self.assertTypedEquals(Rat(5, 3) % F(2, 3), F(1, 3))

self.assertRaises(TypeError, operator.mod, F(2, 3), Root(4))
self.assertTypedEquals(Root(4) % F(3, 2), 0.5)
self.assertEqual(F(3, 2) % SymbolicReal('X'), SymbolicReal('3/2 % X'))
self.assertRaises(TypeError, operator.mod, SymbolicReal('X'), F(3, 2))

self.assertRaises(TypeError, operator.mod, F(3, 2), Polar(4, 2))
self.assertRaises(TypeError, operator.mod, F(3, 2), RectComplex(4, 3))
self.assertRaises(TypeError, operator.mod, Rect(4, 3), F(2, 3))
self.assertEqual(F(3, 2) % SymbolicComplex('X'), SymbolicComplex('3/2 % X'))
self.assertRaises(TypeError, operator.mod, SymbolicComplex('X'), F(3, 2))

self.assertEqual(F(3, 2) % Symbolic('X'), Symbolic('3/2 % X'))
self.assertRaises(TypeError, operator.mod, Symbolic('X'), F(2, 3))

def testMixedPower(self):
# ** has more interesting conversion rules.
self.assertTypedEquals(F(100, 1), F(1, 10) ** -2)
self.assertTypedEquals(F(100, 1), F(10, 1) ** 2)
Expand All @@ -855,6 +1112,40 @@ def testMixedArithmetic(self):
self.assertRaises(ZeroDivisionError, operator.pow,
F(0, 1), -2)

self.assertTypedEquals(F(3, 2) ** Rat(3, 1), F(27, 8))
self.assertTypedEquals(F(3, 2) ** Rat(-3, 1), F(8, 27))
self.assertTypedEquals(F(-3, 2) ** Rat(-3, 1), F(-8, 27))
self.assertTypedEquals(F(9, 4) ** Rat(3, 2), 3.375)
self.assertIsInstance(F(4, 9) ** Rat(-3, 2), float)
self.assertAlmostEqual(F(4, 9) ** Rat(-3, 2), 3.375)
self.assertAlmostEqual(F(-4, 9) ** Rat(-3, 2), 3.375j)
self.assertTypedEquals(Rat(9, 4) ** F(3, 2), 3.375)
self.assertTypedEquals(Rat(3, 2) ** F(3, 1), Rat(27, 8))
self.assertTypedEquals(Rat(3, 2) ** F(-3, 1), F(8, 27))
self.assertIsInstance(Rat(4, 9) ** F(-3, 2), float)
self.assertAlmostEqual(Rat(4, 9) ** F(-3, 2), 3.375)

self.assertTypedEquals(Root(4) ** F(2, 3), Root(4, 3.0))
self.assertTypedEquals(Root(4) ** F(2, 1), Root(4, F(1)))
self.assertTypedEquals(Root(4) ** F(-2, 1), Root(4, -F(1)))
self.assertTypedEquals(Root(4) ** F(-2, 3), Root(4, -3.0))
self.assertEqual(F(3, 2) ** SymbolicReal('X'), SymbolicReal('1.5 ** X'))
self.assertEqual(SymbolicReal('X') ** F(3, 2), SymbolicReal('X ** 1.5'))

self.assertTypedEquals(F(3, 2) ** Rect(2, 0), Polar(2.25, 0.0))
self.assertTypedEquals(F(1, 1) ** Rect(2, 3), Polar(1.0, 0.0))
self.assertTypedEquals(F(3, 2) ** RectComplex(2, 0), Polar(2.25, 0.0))
self.assertTypedEquals(F(1, 1) ** RectComplex(2, 3), Polar(1.0, 0.0))
self.assertTypedEquals(Polar(4, 2) ** F(3, 2), Polar(8.0, 3.0))
self.assertTypedEquals(Polar(4, 2) ** F(3, 1), Polar(64, 6))
self.assertTypedEquals(Polar(4, 2) ** F(-3, 1), Polar(0.015625, -6))
self.assertTypedEquals(Polar(4, 2) ** F(-3, 2), Polar(0.125, -3.0))
self.assertEqual(F(3, 2) ** SymbolicComplex('X'), SymbolicComplex('1.5 ** X'))
self.assertEqual(SymbolicComplex('X') ** F(3, 2), SymbolicComplex('X ** 1.5'))

self.assertEqual(F(3, 2) ** Symbolic('X'), Symbolic('1.5 ** X'))
self.assertEqual(Symbolic('X') ** F(3, 2), Symbolic('X ** 1.5'))

def testMixingWithDecimal(self):
# Decimal refuses mixed arithmetic (but not mixed comparisons)
self.assertRaises(TypeError, operator.add,
Expand Down Expand Up @@ -1087,7 +1378,7 @@ def numerator(self):
def denominator(self):
return type(self)(1)

f = fractions.Fraction(myint(1 * 3), myint(2 * 3))
f = F(myint(1 * 3), myint(2 * 3))
self.assertEqual(f.numerator, 1)
self.assertEqual(f.denominator, 2)
self.assertEqual(type(f.numerator), myint)
Expand Down

0 comments on commit e965299

Please sign in to comment.