diff --git a/.coveragerc b/.coveragerc index a72fd709..3c42fcfb 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,6 @@ [run] plugins = coverage_plugin + +[report] +exclude_also = + assert False diff --git a/.github/workflows/buildwheel.yml b/.github/workflows/buildwheel.yml index cde98caf..f94df605 100644 --- a/.github/workflows/buildwheel.yml +++ b/.github/workflows/buildwheel.yml @@ -193,7 +193,7 @@ jobs: # Test that we can make a coverage build and report coverage test_coverage_build: - name: Test coverage setuptools build + name: Test coverage build runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -202,7 +202,7 @@ jobs: python-version: '3.12' - run: sudo apt-get update - run: sudo apt-get install libflint-dev - # Need Cython's master branch until 3.1 is released because of: + # Need Cython's master branch until 3.1 is released because we need: # https://github.com/cython/cython/pull/6341 - run: pip install git+https://github.com/cython/cython.git@master - run: pip install -r requirements-dev.txt diff --git a/coverage_plugin.py b/coverage_plugin.py index 8382dc26..69c9aeeb 100644 --- a/coverage_plugin.py +++ b/coverage_plugin.py @@ -62,7 +62,7 @@ def get_cython_build_rules(): @cache -def parse_all_cfile_lines(): +def parse_all_cfile_lines(exclude_patterns=None): """Parse all generated C files from the build directory.""" # # Each .c file can include code generated from multiple Cython files (e.g. @@ -80,7 +80,7 @@ def parse_all_cfile_lines(): for c_file, _ in get_cython_build_rules(): - cfile_lines = parse_cfile_lines(c_file) + cfile_lines = parse_cfile_lines(c_file, exclude_patterns=exclude_patterns) for cython_file, line_map in cfile_lines.items(): if cython_file == '(tree fragment)': @@ -94,15 +94,22 @@ def parse_all_cfile_lines(): return all_code_lines -def parse_cfile_lines(c_file): +def parse_cfile_lines(c_file, exclude_patterns=None): """Use Cython's coverage plugin to parse the C code.""" from Cython.Coverage import Plugin - return Plugin()._parse_cfile_lines(c_file) + p = Plugin() + p._excluded_line_patterns = list(exclude_patterns) + return p._parse_cfile_lines(c_file) class Plugin(CoveragePlugin): """A coverage plugin for a spin/meson project with Cython code.""" + def configure(self, config): + # Entry point for coverage "configurer". + # Read the regular expressions from the coverage config that match lines to be excluded from coverage. + self.exclude_patterns = tuple(config.get_option("report:exclude_lines")) + def file_tracer(self, filename): """Find a tracer for filename to handle trace events.""" path = Path(filename) @@ -121,7 +128,7 @@ def file_tracer(self, filename): def file_reporter(self, filename): """Return a file reporter for filename.""" srcfile = Path(filename).relative_to(src_dir) - return CyFileReporter(srcfile) + return CyFileReporter(srcfile, exclude_patterns=self.exclude_patterns) class CyFileTracer(FileTracer): @@ -157,7 +164,7 @@ def get_source_filename(filename): class CyFileReporter(FileReporter): """File reporter for Cython or Python files (.pyx,.pxd,.py).""" - def __init__(self, srcpath): + def __init__(self, srcpath, exclude_patterns): abspath = (src_dir / srcpath) assert abspath.exists() @@ -165,6 +172,7 @@ def __init__(self, srcpath): super().__init__(str(abspath)) self.srcpath = srcpath + self.exclude_patterns = exclude_patterns def relative_filename(self): """Path displayed in the coverage reports.""" @@ -173,7 +181,7 @@ def relative_filename(self): def lines(self): """Set of line numbers for possibly traceable lines.""" srcpath = str(self.srcpath) - all_line_maps = parse_all_cfile_lines() + all_line_maps = parse_all_cfile_lines(exclude_patterns=self.exclude_patterns) line_map = all_line_maps[srcpath] return set(line_map) diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index 9494302f..a6cd6834 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -21,6 +21,11 @@ def raises(f, exception): return False +def test_raises(): + assert raises(lambda: 1/0, ZeroDivisionError) is True + assert raises(lambda: 1/1, ZeroDivisionError) is False + + _default_ctx_string = """\ pretty = True # pretty-print repr() output unicode = False # use unicode characters in output @@ -153,8 +158,9 @@ def test_fmpz(): # https://github.com/flintlib/python-flint/issues/74 if not PYPY: assert pow(a, flint.fmpz(b), c) == ab_mod_c - assert pow(a, b, flint.fmpz(c)) == ab_mod_c assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c + assert pow(a, b, flint.fmpz(c)) == ab_mod_c + assert raises(lambda: pow([], flint.fmpz(2), 2), TypeError) assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError) # XXX: Handle negative modulus like int? @@ -602,7 +608,8 @@ def set_bad(i,j): assert raises(lambda: M([[1,1],[1,1]]).solve(b), ZeroDivisionError) assert raises(lambda: M([[1,2],[3,4],[5,6]]).solve(b), ValueError) assert M([[1,0],[1,2]]).solve(b) == flint.fmpq_mat([[3],[2]]) - assert raises(lambda: M([[1,0],[1,2]]).solve(b, integer=True), ValueError) + assert raises(lambda: M([[1,0],[1,0]]).solve(b, integer=True), ZeroDivisionError) + assert raises(lambda: M([[1,0],[1,2]]).solve(b, integer=True), DomainError) assert raises(lambda: M([[1,2,3],[4,5,6]]).inv(), ValueError) assert raises(lambda: M([[1,1],[1,1]]).inv(), ZeroDivisionError) assert raises(lambda: M([[1,0],[1,2]]).inv(integer=True), ValueError) @@ -632,6 +639,7 @@ def set_bad(i,j): for gram in "approx", "exact": assert M4.lll(rep=rep, gram=gram) == L4 assert M4.lll(rep=rep, gram=gram, transform=True) == (L4, T4) + assert raises(lambda: M4.lll(rep="gram"), AssertionError) assert raises(lambda: M4.lll(rep="bad"), ValueError) assert raises(lambda: M4.lll(gram="bad"), ValueError) M5 = M([[1,2,3],[4,5,6]]) @@ -764,6 +772,7 @@ def test_fmpq(): assert raises(lambda: Q([]), TypeError) assert raises(lambda: Q(1, []), TypeError) assert raises(lambda: Q([], 1), TypeError) + assert raises(lambda: Q(1, 1, 1), TypeError) assert bool(Q(0)) == False assert bool(Q(1)) == True assert Q(1,3) + Q(2,3) == 1 @@ -925,6 +934,7 @@ def test_fmpq_poly(): assert raises(lambda: Q([1,[]]), TypeError) assert raises(lambda: Q({}), TypeError) assert raises(lambda: Q([1], []), TypeError) + assert raises(lambda: Q(1, 1, 1), TypeError) assert raises(lambda: Q([1], 0), ZeroDivisionError) assert bool(Q()) == False assert bool(Q([1])) == True @@ -1103,17 +1113,27 @@ def set_bad(i): raises(lambda: Q(1,2,[3,4]) * Q(1,3,[5,6,7]), ValueError) raises(lambda: Q(1,2,[3,4]) * Z(1,3,[5,6,7]), ValueError) raises(lambda: Z(1,2,[3,4]) * Q(1,3,[5,6,7]), ValueError) + A = Q([[3,4],[5,7]]) / 11 X = Q([[1,2],[3,4]]) B = A*X assert A.solve(B) == X for algorithm in None, "fflu", "dixon": assert A.solve(B, algorithm=algorithm) == X + for _ in range(2): + A = Q(flint.fmpz_mat.randtest(30, 30, 10)) + if A.det() == 0: + continue # pragma: no cover + B = Q(flint.fmpz_mat.randtest(30, 1, 10)) + X = A.solve(B) + assert A*X == B + assert raises(lambda: A.solve(B, algorithm="invalid"), ValueError) assert raises(lambda: A.solve(None), TypeError) assert raises(lambda: A.solve([1,2]), TypeError) assert raises(lambda: A.solve(Q([[1,2]])), ValueError) assert raises(lambda: Q([[1,2],[2,4]]).solve(Q([[1],[2]])), ZeroDivisionError) + M = Q([[1,2,3],[flint.fmpq(1,2),5,6]]) Mcopy = Q(M) Mrref = Q([[1,0,flint.fmpq(3,4)],[0,1,flint.fmpq(9,8)]]) @@ -1357,6 +1377,10 @@ def test_nmod(): assert str(G(3,5)) == "3" assert G(3,5).repr() == "nmod(3, 5)" + G = flint.nmod_ctx.new(7) + assert G(0) == G(7) == G(-7) + + def test_nmod_poly(): N = flint.nmod P = flint.nmod_poly @@ -1452,9 +1476,15 @@ def set_bad2(): assert raises(set_bad2, TypeError) assert bool(P([], 5)) is False assert bool(P([1], 5)) is True + assert P([1,2,1],3).gcd(P([1,1],3)) == P([1,1],3) - raises(lambda: P([1,2],3).gcd([]), TypeError) - raises(lambda: P([1,2],3).gcd(P([1,2],5)), ValueError) + assert raises(lambda: P([1,2],3).gcd([]), TypeError) + assert raises(lambda: P([1,2],3).gcd(P([1,2],5)), ValueError) + assert P([1,2,1],3).xgcd(P([1,1],3)) == (P([1, 1], 3), P([0], 3), P([1], 3)) + assert raises(lambda: P([1,2],3).xgcd([]), TypeError) + assert raises(lambda: P([1,2],3).xgcd(P([1,2],5)), ValueError) + assert raises(lambda: P([1,2],6).xgcd(P([1,2],6)), DomainError) + p3 = P([1,2,3,4,5,6],7) f3 = (N(6,7), [(P([6, 1],7), 5)]) assert p3.factor() == f3 @@ -1462,6 +1492,8 @@ def set_bad2(): for alg in [None, 'berlekamp', 'cantor-zassenhaus']: assert p3.factor(alg) == f3 assert p3.factor(algorithm=alg) == f3 + assert raises(lambda: p3.factor(algorithm="invalid"), ValueError) + assert P([1], 11).roots() == [] assert P([1, 2, 3], 11).roots() == [(8, 1), (6, 1)] assert P([1, 6, 1, 8], 11).roots() == [(5, 3)] @@ -1516,6 +1548,8 @@ def test_nmod_mat(): assert raises(lambda: M([1], 5), TypeError) assert raises(lambda: M([[1],[2,3]], 5), ValueError) assert raises(lambda: M([[1],[2]], 0), ValueError) + assert raises(lambda: M([[1, 2], [3, 4.0]], 5), TypeError) + assert raises(lambda: M(2, 2, [1, 2, 3, 4.0], 5), TypeError) assert raises(lambda: M(None), TypeError) assert raises(lambda: M(None,17), TypeError) assert M(2,3,17) == M(2,3,[0,0,0,0,0,0],17) @@ -1594,6 +1628,40 @@ def test_nmod_series(): # XXX: currently no code in nmod_series.pyx pass + +def test_nmod_contexts(): + # XXX: Generalise this test to cover fmpz_mod, fq_default, etc. + CS = flint.nmod_ctx + CP = flint.nmod_poly_ctx + CM = flint.nmod_mat_ctx + S = flint.nmod + P = flint.nmod_poly + M = flint.nmod_mat + + for c, name in [(CS, 'nmod'), (CP, 'nmod_poly'), (CM, 'nmod_mat')]: + ctx = c.new(17) + assert ctx.modulus() == 17 + assert str(ctx) == f"Context for {name} with modulus: 17" + assert repr(ctx) == f"{name}_ctx(17)" + assert raises(lambda: c(3), TypeError) + assert raises(lambda: ctx.new(3.0), TypeError) + + ctx = CS.new(17) + assert ctx(3) == S(3,17) == S(3, ctx) + assert raises(lambda: ctx(3.0), TypeError) + assert raises(lambda: S(3, []), TypeError) + + ctx_poly = CP.new(17) + assert ctx_poly([1,2,3]) == P([1,2,3],17) == P([1,2,3], ctx_poly) + assert raises(lambda: ctx_poly([1,2.0,3]), TypeError) + assert raises(lambda: P([1,2,3], []), TypeError) + + ctx_mat = CM.new(17) + assert ctx_mat([[1,2],[3,4]]) == M([[1,2],[3,4]],17) == M([[1,2],[3,4]], ctx_mat) + assert raises(lambda: ctx_mat([[1,2.0],[3,4]]), TypeError) + assert raises(lambda: M([[1,2],[3,4]], []), TypeError) + + def test_arb(): A = flint.arb assert A(3) > A(2.5) @@ -2104,7 +2172,7 @@ def test_fmpz_mod_poly(): assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g) assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g) - # Check other typechecks for pow_mod + # Check other typechecks for pow_mod assert raises(lambda: pow(f, -2, g), ValueError) assert raises(lambda: pow(f, 1, "A"), TypeError) assert raises(lambda: pow(f, "A", g), TypeError) @@ -2197,7 +2265,7 @@ def test_fmpz_mod_poly(): f_inv = f.inverse_series_trunc(2) assert (f * f_inv) % R_test([0,0,1]) == 1 - assert raises(lambda: R_cmp([0,0,1]).inverse_series_trunc(2), ValueError) + assert raises(lambda: R_cmp([0,0,1]).inverse_series_trunc(2), ZeroDivisionError) # Resultant f1 = R_test([-3, 1]) @@ -2485,67 +2553,70 @@ def test_division_matrix(): def _all_polys(): - return [ - # (poly_type, scalar_type, is_field, characteristic) + # (poly_type, scalar_type, is_field, characteristic) + FMPZ = (flint.fmpz_poly, flint.fmpz, False, flint.fmpz(0)) + FMPQ = (flint.fmpq_poly, flint.fmpq, True, flint.fmpz(0)) + + def NMOD(n): + return ( + lambda *a: flint.nmod_poly(*a, n), + lambda x: flint.nmod(x, n), + flint.fmpz(n).is_prime(), + flint.fmpz(n) + ) + + def FMPZ_MOD(n): + return ( + lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(n)), + lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(n)), + flint.fmpz(n).is_prime(), + flint.fmpz(n) + ) + + def FQ_DEFAULT(n, k): + return ( + lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(n, k)), + lambda x: flint.fq_default(x, flint.fq_default_ctx(n, k)), + True, + flint.fmpz(n) + ) + + ALL_POLYS = [ # Z and Q - (flint.fmpz_poly, flint.fmpz, False, flint.fmpz(0)), - (flint.fmpq_poly, flint.fmpq, True, flint.fmpz(0)), + FMPZ, + FMPQ, # Z/pZ (p prime) - (lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True, flint.fmpz(17)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)), - True, flint.fmpz(163)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)), - True, flint.fmpz(2**255 - 19)), + NMOD(17), + FMPZ_MOD(163), + FMPZ_MOD(2**127 - 1), + FMPZ_MOD(2**255 - 19), # GF(p^k) (p prime) - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1, 2)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1, 2)), - True, flint.fmpz(2**127 - 1)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(65537)), - True, flint.fmpz(65537)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537, 5)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(65537, 5)), - True, flint.fmpz(65537)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(11)), - True, flint.fmpz(11)), - (lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11, 5)), - lambda x: flint.fq_default(x, flint.fq_default_ctx(11, 5)), - True, flint.fmpz(11)), + FQ_DEFAULT(2**127 - 1, 1), + FQ_DEFAULT(2**127 - 1, 2), + FQ_DEFAULT(65537, 1), + FQ_DEFAULT(65537, 5), + FQ_DEFAULT(11, 1), + FQ_DEFAULT(11, 5), # Z/nZ (n composite) - (lambda *a: flint.nmod_poly(*a, 16), lambda x: flint.nmod(x, 16), False, flint.fmpz(16)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(164)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(164)), - False, flint.fmpz(164)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127)), - False, flint.fmpz(2**127)), - (lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255)), - lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255)), - False, flint.fmpz(2**255)), + NMOD(9), + NMOD(16), + FMPZ_MOD(164), + FMPZ_MOD(9), + FMPZ_MOD(2**127), + FMPZ_MOD(2**255), ] + return ALL_POLYS + def test_polys(): for P, S, is_field, characteristic in _all_polys(): composite_characteristic = characteristic != 0 and not characteristic.is_prime() - # nmod_poly crashes for many operations with non-prime modulus - # https://github.com/flintlib/python-flint/issues/124 - # so we can't even test it... - nmod_poly_will_crash = type(P(1)) is flint.nmod_poly and composite_characteristic assert P([S(1)]) == P([1]) == P(P([1])) == P(1) @@ -2690,29 +2761,56 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 3]) * None, TypeError) assert raises(lambda: None * P([1, 2, 3]), TypeError) - assert P([1, 2, 1]) // P([1, 1]) == P([1, 1]) - assert P([1, 2, 1]) % P([1, 1]) == P([0]) - assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0])) + if composite_characteristic and type(P(1)) is flint.nmod_poly: + # Z/nZ for n not prime + # + # fmpz_mod_poly and nmod_poly can sometimes compute division with + # composite characteristic, but it is not guaranteed to work. For + # fmpz_mod_poly, we can detect the failure and raise an exception. + # For nmod_poly, we cannot detect the failure and calling e.g. + # nmod_poly_divrem would crash the process so for nmod_poly we + # raise an exception in all cases if the modulus is not prime. + assert raises(lambda: P([1, 2, 1]) // P([1, 1]), DomainError) + assert raises(lambda: P([1, 2, 1]) % P([1, 1]), DomainError) + assert raises(lambda: divmod(P([1, 2, 1]), P([1, 1])), DomainError) + + assert raises(lambda: 1 // P([1, 1]), DomainError) + assert raises(lambda: 1 % P([1, 1]), DomainError) + assert raises(lambda: divmod(1, P([1, 1])), DomainError) + else: + assert P([1, 2, 1]) // P([1, 1]) == P([1, 1]) + assert P([1, 2, 1]) % P([1, 1]) == P([0]) + assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0])) + + assert 1 // P([1, 1]) == P([0]) + assert 1 % P([1, 1]) == P([1]) + assert divmod(1, P([1, 1])) == (P([0]), P([1])) + + assert P([1, 2, 1]) / P([1, 1]) == P([1, 1]) + + assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) + + assert raises(lambda: 1 / P([1, 1]), DomainError) + assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) + assert raises(lambda: [] / P([1, 1]), TypeError) + assert raises(lambda: P([1, 1]) / [], TypeError) if is_field: assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2]) assert P([1, 1]) % 2 == P([0]) + assert P([2, 2]) / 2 == P([1, 1]) + assert P([1, 2]) / 2 == P([S(1)/2, 1]) elif characteristic == 0: assert P([1, 1]) // 2 == P([0, 0]) assert P([1, 1]) % 2 == P([1, 1]) - elif nmod_poly_will_crash: - pass - else: + assert P([2, 2]) / 2 == P([1, 1]) + assert raises(lambda: P([1, 2]) / 2, DomainError) + elif characteristic.gcd(2) != 1 or type(P(1)) is flint.nmod_poly: # Z/nZ for n not prime - if characteristic % 2 == 0: - assert raises(lambda: P([1, 1]) // 2, DomainError) - assert raises(lambda: P([1, 1]) % 2, DomainError) - else: - 1/0 - - assert 1 // P([1, 1]) == P([0]) - assert 1 % P([1, 1]) == P([1]) - assert divmod(1, P([1, 1])) == (P([0]), P([1])) + assert raises(lambda: P([1, 1]) // 2, DomainError) + assert raises(lambda: P([1, 1]) % 2, DomainError) + assert raises(lambda: P([2, 2]) / 2, DomainError) + assert raises(lambda: P([1, 2]) / 2, DomainError) assert raises(lambda: P([1, 2, 1]) // None, TypeError) assert raises(lambda: P([1, 2, 1]) % None, TypeError) @@ -2730,50 +2828,43 @@ def setbad(obj, i, val): assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError) assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError) - # Exact/field scalar division - if is_field: - assert P([2, 2]) / 2 == P([1, 1]) - assert P([1, 2]) / 2 == P([S(1)/2, 1]) - elif characteristic == 0: - assert P([2, 2]) / 2 == P([1, 1]) - assert raises(lambda: P([1, 2]) / 2, DomainError) - elif nmod_poly_will_crash: - pass - else: - # Z/nZ for n not prime - assert raises(lambda: P([2, 2]) / 2, DomainError) - assert raises(lambda: P([1, 2]) / 2, DomainError) - - assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError) - - if not nmod_poly_will_crash: - assert P([1, 2, 1]) / P([1, 1]) == P([1, 1]) - assert raises(lambda: 1 / P([1, 1]), DomainError) - assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError) - assert P([1, 1]) ** 0 == P([1]) assert P([1, 1]) ** 1 == P([1, 1]) assert P([1, 1]) ** 2 == P([1, 2, 1]) assert raises(lambda: P([1, 1]) ** -1, ValueError) assert raises(lambda: P([1, 1]) ** None, TypeError) - - # XXX: Not sure what this should do in general: + + # 3-arg pow: (x^2 + 1)**3 mod x-1 + + pow3_types = [ + # flint.fmpq_poly, XXX + flint.nmod_poly, + flint.fmpz_mod_poly, + flint.fq_default_poly + ] + p = P([1, 1]) mod = P([1, 1]) - if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly, flint.fq_default_poly]: + + if type(p) not in pow3_types: assert raises(lambda: pow(p, 2, mod), NotImplementedError) + assert p * p % mod == 0 + elif composite_characteristic and type(p) == flint.nmod_poly: + # nmod_poly does not support % with composite characteristic + assert pow(p, 2, mod) == 0 + assert raises(lambda: p * p % mod, DomainError) else: + # Should be for any is_field including fmpq_poly. Works also in + # some cases for fmpz_mod_poly with non-prime modulus. assert p * p % mod == pow(p, 2, mod) if not composite_characteristic: assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1]) - assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) - elif nmod_poly_will_crash: - pass else: # Z/nZ for n not prime assert raises(lambda: P([1, 2, 1]).gcd(P([1, 1])), DomainError) - assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) + + assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError) if is_field: p1 = P([1, 0, 1]) @@ -2784,21 +2875,18 @@ def setbad(obj, i, val): if not composite_characteristic: assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)]) - elif nmod_poly_will_crash: - pass else: assert raises(lambda: P([1, 2, 1]).factor(), DomainError) if not composite_characteristic: assert P([1, 2, 1]).sqrt() == P([1, 1]) - assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError) - elif nmod_poly_will_crash: - pass else: assert raises(lambda: P([1, 2, 1]).sqrt(), DomainError) + assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError) + if P == flint.fmpq_poly: - assert raises(lambda: P([1, 2, 1], 3).sqrt(), ValueError) + assert raises(lambda: P([1, 2, 1], 3).sqrt(), DomainError) assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2) assert P([]).deflation() == (P([]), 1) @@ -2813,6 +2901,50 @@ def setbad(obj, i, val): if type(p) == flint.fq_default_poly: assert raises(lambda: p.integral(), NotImplementedError) + if characteristic == 0: + assert not hasattr(P(0), "inverse_series_trunc") + elif composite_characteristic: + x = P([0, 1]) + if type(x) is flint.fmpz_mod_poly: + assert (1 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + if characteristic.gcd(3) != 1: + assert raises(lambda: (3 + x).inverse_series_trunc(4), DomainError) + else: + assert (3 + x).inverse_series_trunc(4)\ + == S(1)/3 - S(1)/9*x + S(1)/27*x**2 - S(1)/81*x**3 + elif type(x) is flint.nmod_poly: + assert raises(lambda: (1 + x).inverse_series_trunc(4), DomainError) + else: + assert False + else: + x = P([0, 1]) + assert (1 + x).inverse_series_trunc(4) == 1 - x + x**2 - x**3 + assert (3 + x).inverse_series_trunc(4)\ + == S(1)/3 - S(1)/9*x + S(1)/27*x**2 - S(1)/81*x**3 + assert raises(lambda: (1 + x).inverse_series_trunc(-1), ValueError) + assert raises(lambda: x.inverse_series_trunc(4), ZeroDivisionError) + + if characteristic == 0: + assert not hasattr(P(0), "pow_mod") + elif composite_characteristic: + pass + else: + x = P([0, 1]) + assert (1 + x).pow_mod(4, x**2 + 1) == -4 + assert (3 + x).pow_mod(4, x**2 + 1) == 96*x + 28 + assert x.pow_mod(4, x**2 + 1) == 1 + + assert x.pow_mod(2**127, x - 1) == 1 + assert (1 + x).pow_mod(2**127, x - 1) == pow(2, 2**127, int(characteristic)) + + if type(x) is not flint.fq_default_poly: + assert (1 + x).pow_mod(2**127, x - 1, S(1)/2) == pow(2, 2**127, int(characteristic)) + assert raises(lambda: (1 + x).pow_mod(2**127, x - 1, []), TypeError) + + assert raises(lambda: (1 + x).pow_mod(4, []), TypeError) + assert raises(lambda: (1 + x).pow_mod([], x), TypeError) + assert raises(lambda: (1 + x).pow_mod(-1, x), ValueError) + def _all_mpolys(): return [ @@ -3424,13 +3556,6 @@ def factor_sqf(p): for P, S, [x, y], is_field, characteristic in _all_polys_mpolys(): if characteristic != 0 and not characteristic.is_prime(): - # nmod_poly crashes for many operations with non-prime modulus - # https://github.com/flintlib/python-flint/issues/124 - # so we can't even test it... - nmod_poly_will_crash = type(x) is flint.nmod_poly - if nmod_poly_will_crash: - continue - try: S(4).sqrt() ** 2 == S(4) except DomainError: @@ -3449,6 +3574,18 @@ def factor_sqf(p): assert S(1).sqrt() == S(1) assert S(4).sqrt()**2 == S(4) + if is_field: + for n in range(1, 10): + try: + sqrtn = S(n).sqrt() + except DomainError: + sqrtn = None + if sqrtn is None: + assert raises(lambda: ((x + 1)**2/n).sqrt(), DomainError) + else: + assert ((x + 1)**2/n).sqrt() ** 2 == (x + 1)**2/n + assert raises(lambda: ((x**2 + 1)/n).sqrt(), DomainError) + for i in range(-100, 100): try: assert S(i).sqrt() ** 2 == S(i) @@ -3595,17 +3732,34 @@ def factor_sqf(p): def _all_matrices(): """Return a list of matrix types and scalar types.""" + # Prime modulus R163 = flint.fmpz_mod_ctx(163) R127 = flint.fmpz_mod_ctx(2**127 - 1) R255 = flint.fmpz_mod_ctx(2**255 - 19) + + # Composite modulus + R164_C = flint.fmpz_mod_ctx(164) + R127_C = flint.fmpz_mod_ctx(2**127) + R255_C = flint.fmpz_mod_ctx(2**255) + return [ - # (matrix_type, scalar_type, is_field) - (flint.fmpz_mat, flint.fmpz, False), - (flint.fmpq_mat, flint.fmpq, True), - (lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True), - (lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True), - (lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True), - (lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True), + # (matrix_type, scalar_type, is_field, characteristic) + + # Z and Q + (flint.fmpz_mat, flint.fmpz, False, 0), + (flint.fmpq_mat, flint.fmpq, True, 0), + + # Z/pZ + (lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True, 17), + (lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True, 163), + (lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True, 2**127 - 1), + (lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True, 2**255 - 19), + + # Z/nZ (n composite) + (lambda *a: flint.nmod_mat(*a, 16), lambda x: flint.nmod(x, 16), False, 16), + (lambda *a: flint.fmpz_mod_mat(*a, R164_C), lambda x: flint.fmpz_mod(x, R164_C), False, 164), + (lambda *a: flint.fmpz_mod_mat(*a, R127_C), lambda x: flint.fmpz_mod(x, R127_C), False, 2**127), + (lambda *a: flint.fmpz_mod_mat(*a, R255_C), lambda x: flint.fmpz_mod(x, R255_C), False, 2**255), ] @@ -3726,7 +3880,7 @@ def _poly_type_from_matrix_type(mat_type): def test_matrices_eq(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): A1 = M([[1, 2], [3, 4]]) A2 = M([[1, 2], [3, 4]]) B = M([[5, 6], [7, 8]]) @@ -3751,7 +3905,7 @@ def test_matrices_eq(): def test_matrices_constructor(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): assert raises(lambda: M(), TypeError) # Empty matrices @@ -3823,7 +3977,7 @@ def _matrix_repr(M): def test_matrices_strrepr(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): A = M([[1, 2], [3, 4]]) A_str = "[1, 2]\n[3, 4]" A_repr = _matrix_repr(A) @@ -3846,7 +4000,7 @@ def test_matrices_strrepr(): def test_matrices_getitem(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234[0, 0] == S(1) assert M1234[0, 1] == S(2) @@ -3862,7 +4016,7 @@ def test_matrices_getitem(): def test_matrices_setitem(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234[0, 0] == S(1) @@ -3888,7 +4042,7 @@ def setbad(obj, key, val): def test_matrices_bool(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): assert bool(M([])) is False assert bool(M([[0]])) is False assert bool(M([[1]])) is True @@ -3899,14 +4053,14 @@ def test_matrices_bool(): def test_matrices_pos_neg(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert +M1234 == M1234 assert -M1234 == M([[-1, -2], [-3, -4]]) def test_matrices_add(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 + M5678 == M([[6, 8], [10, 12]]) @@ -3926,7 +4080,7 @@ def test_matrices_add(): def test_matrices_sub(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 - M5678 == M([[-4, -4], [-4, -4]]) @@ -3946,7 +4100,7 @@ def test_matrices_sub(): def test_matrices_mul(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) M5678 = M([[5, 6], [7, 8]]) assert M1234 * M5678 == M([[19, 22], [43, 50]]) @@ -3972,18 +4126,26 @@ def test_matrices_mul(): def test_matrices_pow(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) + assert M1234**0 == M([[1, 0], [0, 1]]) assert M1234**1 == M1234 assert M1234**2 == M([[7, 10], [15, 22]]) assert M1234**3 == M([[37, 54], [81, 118]]) + if is_field: assert M1234**-1 == M([[-4, 2], [3, -1]]) / 2 assert M1234**-2 == M([[22, -10], [-15, 7]]) / 4 assert M1234**-3 == M([[-118, 54], [81, -37]]) / 8 Ms = M([[1, 2], [3, 6]]) assert raises(lambda: Ms**-1, ZeroDivisionError) + else: + # XXX: Allow unimodular matrices? + assert raises(lambda: M1234**-1, DomainError) + + assert raises(lambda: pow(M1234, 2, 3), NotImplementedError) + Mr = M([[1, 2, 3], [4, 5, 6]]) assert raises(lambda: Mr**0, ValueError) assert raises(lambda: Mr**1, ValueError) @@ -3993,31 +4155,49 @@ def test_matrices_pow(): def test_matrices_div(): - for M, S, is_field in _all_matrices(): + + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) + if is_field: assert M1234 / 2 == M([[S(1)/2, S(1)], [S(3)/2, 2]]) assert M1234 / S(2) == M([[S(1)/2, S(1)], [S(3)/2, 2]]) assert raises(lambda: M1234 / 0, ZeroDivisionError) assert raises(lambda: M1234 / S(0), ZeroDivisionError) + else: + assert raises(lambda: M1234 / 2, DomainError) + if characteristic == 0: + assert (2*M1234) / 2 == M1234 + else: + assert raises(lambda: (2*M1234) / 2, DomainError) + raises(lambda: M1234 / None, TypeError) raises(lambda: None / M1234, TypeError) def test_matrices_inv(): - for M, S, is_field in _all_matrices(): - if is_field: - M1234 = M([[1, 2], [3, 4]]) + + for M, S, is_field, characteristic in _all_matrices(): + + M1234 = M([[1, 2], [3, 4]]) + M1236 = M([[1, 2], [3, 6]]) + Mr = M([[1, 2, 3], [4, 5, 6]]) + + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).inv(), DomainError) + elif is_field: assert M1234.inv() == M([[-2, 1], [S(3)/2, -S(1)/2]]) - M1236 = M([[1, 2], [3, 6]]) assert raises(lambda: M1236.inv(), ZeroDivisionError) - Mr = M([[1, 2, 3], [4, 5, 6]]) assert raises(lambda: Mr.inv(), ValueError) - # XXX: Test non-field matrices. unimodular? + else: + # assert M1234.inv() == (M([[-4, 2], [3, -1]]), 2) + # assert M1236.inv() == (M([[-6, 2], [3, -1]]), 3) + # XXX: fmpz_mat.inv() return fmpq_mat... + assert M1234.inv() * M1234.det() == M([[4, -2], [-3, 1]]) def test_matrices_det(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2], [3, 4]]) assert M1234.det() == S(-2) M9 = M([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) @@ -4027,7 +4207,7 @@ def test_matrices_det(): def test_matrices_charpoly(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): P = _poly_type_from_matrix_type(M) M1234 = M([[1, 2], [3, 4]]) assert M1234.charpoly() == P([-2, -5, 1]) @@ -4038,18 +4218,21 @@ def test_matrices_charpoly(): def test_matrices_minpoly(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).minpoly(), DomainError) + continue P = _poly_type_from_matrix_type(M) - M1234 = M([[1, 2], [3, 4]]) - assert M1234.minpoly() == P([-2, -5, 1]) - M9 = M([[2, 1, 0], [0, 2, 0], [0, 0, 2]]) - assert M9.minpoly() == P([4, -4, 1]) - Mr = M([[1, 2, 3], [4, 5, 6]]) - assert raises(lambda: Mr.minpoly(), ValueError) + assert M([[1, 2], [3, 4]]).minpoly() == P([-2, -5, 1]) + assert M([[2, 1, 0], [0, 2, 0], [0, 0, 2]]).minpoly() == P([4, -4, 1]) + assert raises(lambda: M([[1, 2, 3], [4, 5, 6]]).minpoly(), ValueError) def test_matrices_rank(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): + if characteristic > 0 and not is_field: + assert raises(lambda: M([[1, 2], [3, 4]]).rank(), DomainError) + continue M1234 = M([[1, 2], [3, 4]]) assert M1234.rank() == 2 Mr = M([[1, 2, 3], [4, 5, 6]]) @@ -4061,37 +4244,57 @@ def test_matrices_rank(): def test_matrices_rref(): - for M, S, is_field in _all_matrices(): - if is_field: - Mr = M([[1, 2, 3], [4, 5, 6]]) - Mr_rref = M([[1, 0, -1], [0, 1, 2]]) + for M, S, is_field, characteristic in _all_matrices(): + + Mr = M([[1, 2, 3], [4, 5, 6]]) + Mr_rref = M([[1, 0, -1], [0, 1, 2]]) + + if characteristic > 0 and not is_field: + # Z/nZ (n composite) raises + assert raises(lambda: Mr.rref(), DomainError) + elif is_field: + # Q, Z/pZ and GF(p^d) return usual RREF assert Mr.rref() == (Mr_rref, 2) assert Mr == M([[1, 2, 3], [4, 5, 6]]) assert Mr.rref(inplace=True) == (Mr_rref, 2) assert Mr == Mr_rref + else: + # Z returns RREF with divisor -3 + d = -3 + assert Mr.rref() == (d*Mr_rref, d, 2) + assert Mr == M([[1, 2, 3], [4, 5, 6]]) + assert Mr.rref(inplace=True) == (d*Mr_rref, d, 2) + assert Mr == d*Mr_rref def test_matrices_solve(): - for M, S, is_field in _all_matrices(): - if is_field: - A = M([[1, 2], [3, 4]]) - x = M([[1], [2]]) - b = M([[5], [11]]) - assert A*x == b + for M, S, is_field, characteristic in _all_matrices(): + + A = M([[1, 2], [3, 4]]) + x = M([[1], [2]]) + b = M([[5], [11]]) + assert A*x == b + + A2 = M([[1, 2], [2, 4]]) + + if characteristic > 0 and not is_field: + assert raises(lambda: A.solve(b), DomainError) + assert raises(lambda: A2.solve(b), DomainError) + else: assert A.solve(b) == x - A22 = M([[1, 2], [3, 4]]) - A23 = M([[1, 2, 3], [4, 5, 6]]) - b2 = M([[5], [11]]) - b3 = M([[5], [11], [17]]) - assert raises(lambda: A22.solve(b3), ValueError) - assert raises(lambda: A23.solve(b2), ValueError) - assert raises(lambda: A.solve(None), TypeError) - A = M([[1, 2], [2, 4]]) - assert raises(lambda: A.solve(b), ZeroDivisionError) + assert raises(lambda: A2.solve(b), ZeroDivisionError) + + A22 = M([[1, 2], [3, 4]]) + A23 = M([[1, 2, 3], [4, 5, 6]]) + b2 = M([[5], [11]]) + b3 = M([[5], [11], [17]]) + assert raises(lambda: A22.solve(b3), ValueError) + assert raises(lambda: A23.solve(b2), ValueError) + assert raises(lambda: A.solve(None), TypeError) def test_matrices_transpose(): - for M, S, is_field in _all_matrices(): + for M, S, is_field, characteristic in _all_matrices(): M1234 = M([[1, 2, 3], [4, 5, 6]]) assert M1234.transpose() == M([[1, 4], [2, 5], [3, 6]]) @@ -4126,7 +4329,7 @@ def test_fq_default(): # p must be prime assert raises(lambda: flint.fq_default_ctx(10), ValueError) - + # degree must be positive assert raises(lambda: flint.fq_default_ctx(11, -1), ValueError) @@ -4483,6 +4686,7 @@ def test_all_tests(): all_tests = [ + test_raises, test_pyflint, test_showgood, @@ -4505,6 +4709,8 @@ def test_all_tests(): test_nmod_mat, test_nmod_series, + test_nmod_contexts, + test_fmpz_mod, test_fmpz_mod_dlog, test_fmpz_mod_poly, diff --git a/src/flint/types/arb.pyx b/src/flint/types/arb.pyx index 2553a92f..977d1f0e 100644 --- a/src/flint/types/arb.pyx +++ b/src/flint/types/arb.pyx @@ -2268,7 +2268,7 @@ cdef class arb(flint_scalar): >>> from flint import showgood >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5, abc=True), dps=25) 1.447530478120770807945697 - >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) + >>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) # doctest: +SKIP Traceback (most recent call last): ... ValueError: no convergence (maxprec=960, try higher maxprec) diff --git a/src/flint/types/fmpq.pyx b/src/flint/types/fmpq.pyx index 4481f20f..308f1907 100644 --- a/src/flint/types/fmpq.pyx +++ b/src/flint/types/fmpq.pyx @@ -107,9 +107,6 @@ cdef class fmpq(flint_scalar): def __richcmp__(s, t, int op): cdef bint res - s = any_as_fmpq(s) - if s is NotImplemented: - return s t = any_as_fmpq(t) if t is NotImplemented: return t @@ -119,17 +116,17 @@ cdef class fmpq(flint_scalar): res = not res return res else: - # todo: use fmpq_cmp when available + res = fmpq_cmp(s.val, (t).val) if op == 0: - res = (s-t).p < 0 + res = res < 0 elif op == 1: - res = (s-t).p <= 0 + res = res <= 0 elif op == 4: - res = (s-t).p > 0 + res = res > 0 elif op == 5: - res = (s-t).p >= 0 + res = res >= 0 else: - raise ValueError + assert False return res def numer(self): @@ -442,9 +439,9 @@ cdef class fmpq(flint_scalar): import sys from fractions import Fraction if sys.version_info < (3, 12): - return hash(Fraction(int(self.p), int(self.q), _normalize=False)) + return hash(Fraction(int(self.p), int(self.q), _normalize=False)) # pragma: no cover else: - return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) + return hash(Fraction._from_coprime_ints(int(self.p), int(self.q))) # pragma: no cover def height_bits(self, bint signed=False): """ diff --git a/src/flint/types/fmpq_mat.pyx b/src/flint/types/fmpq_mat.pyx index 364e6726..daddaa19 100644 --- a/src/flint/types/fmpq_mat.pyx +++ b/src/flint/types/fmpq_mat.pyx @@ -100,9 +100,6 @@ cdef class fmpq_mat(flint_mat): cdef bint r if op != 2 and op != 3: raise TypeError("matrices cannot be ordered") - s = any_as_fmpq_mat(s) - if t is NotImplemented: - return s t = any_as_fmpq_mat(t) if t is NotImplemented: return t @@ -269,9 +266,6 @@ cdef class fmpq_mat(flint_mat): def __truediv__(s, t): return fmpq_mat._div_(s, t) - def __div__(s, t): - return fmpq_mat._div_(s, t) - def inv(self): """ Returns the inverse matrix of *self*. @@ -489,7 +483,7 @@ cdef class fmpq_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("matrix must be square") if z is not None: - raise TypeError("fmpq_mat does not support modular exponentiation") + raise NotImplementedError("fmpq_mat does not support modular exponentiation") n = int(n) if n == 0: diff --git a/src/flint/types/fmpq_poly.pyx b/src/flint/types/fmpq_poly.pyx index ff4879d4..87ca48fa 100644 --- a/src/flint/types/fmpq_poly.pyx +++ b/src/flint/types/fmpq_poly.pyx @@ -120,9 +120,6 @@ cdef class fmpq_poly(flint_poly): cdef bint r if op != 2 and op != 3: raise TypeError("polynomials cannot be ordered") - self = any_as_fmpq_poly(self) - if self is NotImplemented: - return self other = any_as_fmpq_poly(other) if other is NotImplemented: return other @@ -476,15 +473,8 @@ cdef class fmpq_poly(flint_poly): 1/2*x + 1/2 """ - d = self.denom() - n = self.numer() - d, r = d.sqrtrem() - if r != 0: - raise ValueError(f"Cannot compute square root of {self}") - n = n.sqrt() - if n is None: - raise ValueError(f"Cannot compute square root of {self}") - return fmpq_poly(n, d) + d = self.denom().sqrt() + return fmpq_poly(self.numer().sqrt(), d) def deflation(self): num, n = self.numer().deflation() diff --git a/src/flint/types/fmpz.pyx b/src/flint/types/fmpz.pyx index 9ae1dec8..616fb4c5 100644 --- a/src/flint/types/fmpz.pyx +++ b/src/flint/types/fmpz.pyx @@ -396,18 +396,15 @@ cdef class fmpz(flint_scalar): return u def __pow__(s, t, m): - cdef fmpz_struct sval[1] - cdef fmpz_struct tval[1] - cdef fmpz_struct mval[1] - cdef int stype = FMPZ_UNKNOWN + cdef fmpz_t tval + cdef fmpz_t mval cdef int ttype = FMPZ_UNKNOWN cdef int mtype = FMPZ_UNKNOWN cdef int success u = NotImplemented try: - stype = fmpz_set_any_ref(sval, s) - if stype == FMPZ_UNKNOWN: + if not typecheck(s, fmpz): return NotImplemented ttype = fmpz_set_any_ref(tval, t) if ttype == FMPZ_UNKNOWN: @@ -441,12 +438,10 @@ cdef class fmpz(flint_scalar): raise ValueError("pow(): negative modulus not supported") u = fmpz.__new__(fmpz) - fmpz_powm((u).val, sval, tval, mval) + fmpz_powm((u).val, s.val, tval, mval) return u finally: - if stype == FMPZ_TMP: - fmpz_clear(sval) if ttype == FMPZ_TMP: fmpz_clear(tval) if mtype == FMPZ_TMP: diff --git a/src/flint/types/fmpz_mat.pyx b/src/flint/types/fmpz_mat.pyx index 097a90ae..e80cba66 100644 --- a/src/flint/types/fmpz_mat.pyx +++ b/src/flint/types/fmpz_mat.pyx @@ -305,14 +305,16 @@ cdef class fmpz_mat(flint_mat): def __pow__(self, e, m): cdef fmpz_mat t cdef ulong ee - if not typecheck(self, fmpz_mat): - return NotImplemented - if not fmpz_mat_is_square((self).val): + if not fmpz_mat_is_square(self.val): raise ValueError("matrix must be square") if m is not None: raise NotImplementedError("modular matrix exponentiation") + if e < 0: + # Allow unimodular? + raise DomainError("negative power of integer matrix: M**%i" % e) ee = e - t = fmpz_mat(self) # XXX + t = fmpz_mat.__new__(fmpz_mat) + fmpz_mat_init_set(t.val, self.val) fmpz_mat_pow(t.val, t.val, ee) return t @@ -520,7 +522,7 @@ cdef class fmpz_mat(flint_mat): >>> A.solve(B, integer=True) Traceback (most recent call last): ... - ValueError: matrix is not invertible over the integers + flint.utils.flint_exceptions.DomainError: matrix is not invertible over the integers >>> fmpz_mat([[1,2], [3,5]]).solve(B, integer=True) [ 6, 3, 0] [-3, -1, 1] @@ -556,11 +558,11 @@ cdef class fmpz_mat(flint_mat): fmpz_mat_ncols((t).val)) d = fmpz.__new__(fmpz) result = fmpz_mat_solve(u.val, d.val, self.val, (t).val) - if not fmpz_is_pm1(d.val): - raise ValueError("matrix is not invertible over the integers") - u *= d if not result: raise ZeroDivisionError("singular matrix in solve()") + if not fmpz_is_pm1(d.val): + raise DomainError("matrix is not invertible over the integers") + u *= d return u def rref(self, inplace=False): @@ -642,7 +644,10 @@ cdef class fmpz_mat(flint_mat): if rep == "zbasis": rt = rep_type.Z_BASIS elif rep == "gram": + # XXX: This consumes all memory and crashes. Maybe the parameters + # need to be different or something? rt = rep_type.GRAM + assert False, "rep = gram does not currently work." else: raise ValueError("rep must be 'zbasis' or 'gram'") if gram == "approx": diff --git a/src/flint/types/fmpz_mod_mat.pyx b/src/flint/types/fmpz_mod_mat.pyx index 91d7839f..c971861b 100644 --- a/src/flint/types/fmpz_mod_mat.pyx +++ b/src/flint/types/fmpz_mod_mat.pyx @@ -53,6 +53,10 @@ from flint.types.nmod_mat cimport ( nmod_mat, ) +from flint.utils.flint_exceptions import ( + DomainError, +) + cdef any_as_fmpz_mod_mat(x): if typecheck(x, fmpz_mod_mat): @@ -401,7 +405,15 @@ cdef class fmpz_mod_mat(flint_mat): def _div(self, fmpz_mod other): """Divide an ``fmpz_mod_mat`` matrix by an ``fmpz_mod`` scalar.""" - return self._scalarmul(other.inverse()) + try: + inv = other.inverse() + except ZeroDivisionError: + # XXX: Maybe fmpz_mod should raise DomainError? + if other == 0: + raise ZeroDivisionError("fmpz_mod_mat div: division by zero") + else: + raise DomainError("fmpz_mod_mat div: division by non-invertible element") + return self._scalarmul(inv) def __add__(self, other): """``M + N``: Add two matrices.""" @@ -455,8 +467,10 @@ cdef class fmpz_mod_mat(flint_mat): return self._scalarmul(e) return NotImplemented - def __pow__(self, other): + def __pow__(self, other, m=None): """``M ** n``: Raise a matrix to an integer power.""" + if m is not None: + raise NotImplementedError("fmpz_mod_mat pow: modulo not supported") if not isinstance(other, int): return NotImplemented return self._pow(other) @@ -483,8 +497,12 @@ cdef class fmpz_mod_mat(flint_mat): Assumes that the modulus is prime. """ cdef fmpz_mod_mat res + if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat inv: matrix must be square") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat inv: modulus must be prime") + res = self._newlike() r = compat_fmpz_mod_mat_inv(res.val, self.val, self.ctx.val) if r == 0: @@ -546,6 +564,8 @@ cdef class fmpz_mod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat minpoly: matrix must be square") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat minpoly: modulus must be prime") pctx = fmpz_mod_poly_ctx(self.ctx) res = fmpz_mod_poly(0, pctx) @@ -596,6 +616,8 @@ cdef class fmpz_mod_mat(flint_mat): raise ValueError("fmpz_mod_mat solve: matrix must be square") if self.nrows() != rhs.nrows(): raise ValueError("fmpz_mod_mat solve: shape mismatch") + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat solve: modulus must be prime") res = self._new(rhs.nrows(), rhs.ncols(), self.ctx) success = compat_fmpz_mod_mat_solve(res.val, self.val, ( rhs).val, self.ctx.val) @@ -616,6 +638,8 @@ cdef class fmpz_mod_mat(flint_mat): Assumes that the modulus is prime. """ + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat rank: modulus must be prime") return self.rref()[1] def rref(self, inplace=False): @@ -637,6 +661,10 @@ cdef class fmpz_mod_mat(flint_mat): """ cdef fmpz_mod_mat res cdef slong r + + if not self.ctx.is_prime(): + raise DomainError("fmpz_mod_mat rref: modulus must be prime") + if inplace: res = self else: diff --git a/src/flint/types/fmpz_mod_poly.pyx b/src/flint/types/fmpz_mod_poly.pyx index a28e56be..90acb83f 100644 --- a/src/flint/types/fmpz_mod_poly.pyx +++ b/src/flint/types/fmpz_mod_poly.pyx @@ -433,10 +433,6 @@ cdef class fmpz_mod_poly(flint_poly): def _div_(self, other): cdef fmpz_mod_poly res - other = self.ctx.mod.any_as_fmpz_mod(other) - if other is NotImplemented: - return NotImplemented - if other == 0: raise ZeroDivisionError("Cannot divide by zero") elif not other.is_unit(): @@ -1438,17 +1434,27 @@ cdef class fmpz_mod_poly(flint_poly): """ cdef fmpz_t f cdef fmpz_mod_poly res + cdef bint is_one + + if n < 1: + raise ValueError(f"n = {n} must be positive") + + if self.constant_coefficient() == 0: + raise ZeroDivisionError("fmpz_mod_poly inverse_series_trunc: zero constant term") res = self.ctx.new_ctype_poly() fmpz_init(f) fmpz_mod_poly_inv_series_f( f, res.val, self.val, n, res.ctx.mod.val ) - if not fmpz_is_one(f): - fmpz_clear(f) - raise ValueError( + is_one = fmpz_is_one(f) + fmpz_clear(f) + + if not is_one: + raise DomainError( f"Cannot compute inverse series of {self} modulo x^{n}" ) + return res def resultant(self, other): diff --git a/src/flint/types/fmpz_poly.pyx b/src/flint/types/fmpz_poly.pyx index 7a1c33c3..393fc4ce 100644 --- a/src/flint/types/fmpz_poly.pyx +++ b/src/flint/types/fmpz_poly.pyx @@ -107,9 +107,6 @@ cdef class fmpz_poly(flint_poly): cdef bint r if op != 2 and op != 3: raise TypeError("polynomials cannot be ordered") - self = any_as_fmpz_poly(self) - if self is NotImplemented: - return self other = any_as_fmpz_poly(other) if other is NotImplemented: return other @@ -456,7 +453,7 @@ cdef class fmpz_poly(flint_poly): return [] flags = 0 if verbose: - flags = 1 + flags = 1 # pragma: no cover roots = [] fmpz_poly_factor_init(fac) fmpz_poly_factor_squarefree(fac, self.val) @@ -551,8 +548,8 @@ cdef class fmpz_poly(flint_poly): arb_poly_init(t) arb_poly_swinnerton_dyer_ui(t, n, 0) if not arb_poly_get_unique_fmpz_poly((u).val, t): - arb_poly_clear(t) - raise ValueError("insufficient precision") + arb_poly_clear(t) # pragma: no cover + raise ValueError("insufficient precision") # pragma: no cover arb_poly_clear(t) else: fmpz_poly_swinnerton_dyer((u).val, n) diff --git a/src/flint/types/fq_default_poly.pyx b/src/flint/types/fq_default_poly.pyx index ec2dfe83..4582f0e5 100644 --- a/src/flint/types/fq_default_poly.pyx +++ b/src/flint/types/fq_default_poly.pyx @@ -1240,6 +1240,9 @@ cdef class fq_default_poly(flint_poly): """ cdef fq_default_poly res + if n < 1: + raise ValueError(f"n = {n} must be positive") + if self.constant_coefficient().is_zero(): raise ZeroDivisionError("constant coefficient must be invertible") diff --git a/src/flint/types/nmod.pxd b/src/flint/types/nmod.pxd index 44625f59..9259e2a6 100644 --- a/src/flint/types/nmod.pxd +++ b/src/flint/types/nmod.pxd @@ -1,9 +1,95 @@ +cimport cython + +from flint.flintlib.types.flint cimport mp_limb_t, ulong +from flint.flintlib.types.fmpz cimport fmpz_t +from flint.flintlib.types.nmod cimport nmod_t + +from flint.flintlib.functions.nmod cimport nmod_init +from flint.flintlib.functions.ulong_extras cimport n_is_prime +from flint.flintlib.functions.fmpq cimport fmpq_mod_fmpz +from flint.flintlib.functions.fmpz cimport ( + fmpz_t, + fmpz_fdiv_ui, + fmpz_init, + fmpz_clear, + fmpz_set_ui, + fmpz_get_ui, +) + from flint.flint_base.flint_base cimport flint_scalar -from flint.flintlib.types.flint cimport mp_limb_t -from flint.flintlib.functions.nmod cimport nmod_t +from flint.utils.typecheck cimport typecheck + +from flint.types.fmpz cimport fmpz, any_as_fmpz +from flint.types.fmpq cimport fmpq, any_as_fmpq -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1 +cdef dict _nmod_ctx_cache + + +@cython.no_gc +cdef class nmod_ctx: + cdef nmod_t mod + cdef bint _is_prime + + @staticmethod + cdef inline nmod_ctx any_as_nmod_ctx(obj): + """Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``.""" + if typecheck(obj, nmod_ctx): + return obj + if typecheck(obj, int): + return nmod_ctx._get_ctx(obj) + raise TypeError("Invalid context/modulus for nmod: %s" % obj) + + @staticmethod + cdef inline nmod_ctx _get_ctx(int mod): + """Retrieve an nmod context from the cache or create a new one.""" + ctx = _nmod_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_ctx_cache.setdefault(mod, nmod_ctx._new_ctx(mod)) + return ctx + + @staticmethod + cdef inline nmod_ctx _new_ctx(ulong mod): + """Create a new nmod context.""" + cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx) + nmod_init(&ctx.mod, mod) + ctx._is_prime = n_is_prime(mod) + return ctx + + @cython.final + cdef inline int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1: + """Convert an object to an nmod element.""" + cdef int success + cdef fmpz_t t + if typecheck(obj, nmod): + if (obj).ctx.mod.n != ctx.mod.n: + raise ValueError("cannot coerce integers mod n with different n") + val[0] = (obj).val + return 1 + z = any_as_fmpz(obj) + if z is not NotImplemented: + val[0] = fmpz_fdiv_ui((z).val, ctx.mod.n) + return 1 + q = any_as_fmpq(obj) + if q is not NotImplemented: + fmpz_init(t) + fmpz_set_ui(t, ctx.mod.n) + success = fmpq_mod_fmpz(t, (q).val, t) + val[0] = fmpz_get_ui(t) + fmpz_clear(t) + if not success: + raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n)) + return 1 + return 0 + + @cython.final + cdef inline nmod new_nmod(self): + cdef nmod r = nmod.__new__(nmod) + r.ctx = self + return r + + +@cython.no_gc cdef class nmod(flint_scalar): cdef mp_limb_t val - cdef nmod_t mod + cdef nmod_ctx ctx diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 0f4d7c3f..97b57590 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -1,51 +1,113 @@ +cimport cython + from flint.flint_base.flint_base cimport flint_scalar from flint.utils.typecheck cimport typecheck -from flint.types.fmpq cimport any_as_fmpq from flint.types.fmpz cimport any_as_fmpz from flint.types.fmpz cimport fmpz -from flint.types.fmpq cimport fmpq -from flint.flintlib.types.flint cimport ulong, nmod_t, mp_limb_t -from flint.flintlib.functions.fmpz cimport fmpz_t +from flint.flintlib.types.flint cimport ulong, mp_limb_t from flint.flintlib.functions.nmod cimport ( - nmod_init, nmod_pow_fmpz, nmod_neg, nmod_add, nmod_sub, nmod_mul, ) -from flint.flintlib.functions.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear -from flint.flintlib.functions.fmpz cimport fmpz_set_ui, fmpz_get_ui -from flint.flintlib.functions.fmpq cimport fmpq_mod_fmpz from flint.flintlib.functions.ulong_extras cimport n_gcdinv, n_sqrtmod from flint.utils.flint_exceptions import DomainError -cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1: - cdef int success - cdef fmpz_t t - if typecheck(obj, nmod): - if (obj).mod.n != mod.n: - raise ValueError("cannot coerce integers mod n with different n") - val[0] = (obj).val - return 1 - z = any_as_fmpz(obj) - if z is not NotImplemented: - val[0] = fmpz_fdiv_ui((z).val, mod.n) - return 1 - q = any_as_fmpq(obj) - if q is not NotImplemented: - fmpz_init(t) - fmpz_set_ui(t, mod.n) - success = fmpq_mod_fmpz(t, (q).val, t) - val[0] = fmpz_get_ui(t) - fmpz_clear(t) - if not success: - raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n)) - return 1 - return 0 +_nmod_ctx_cache = {} + + +@cython.no_gc +cdef class nmod_ctx: + """ + Context object for creating :class:`~.nmod` initalised + with modulus :math:`N`. + + >>> ctx = nmod_ctx.new(17) + >>> ctx + nmod_ctx(17) + >>> ctx.modulus() + 17 + >>> e = ctx(10) + >>> e + 10 + >>> e + 10 + 3 + + """ + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.new()") + + @staticmethod + def new(mod): + """Get an nmod context with modulus ``mod``.""" + return nmod_ctx.any_as_nmod_ctx(mod) + + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_ctx.new(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod.n) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_ctx.new(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + + def zero(self): + """Return the zero element of the context. + + >>> ctx = nmod_ctx.new(17) + >>> ctx.zero() + 0 + + """ + return self(0) + + def one(self): + """Return the one element of the context. + + >>> ctx = nmod_ctx.new(17) + >>> ctx.one() + 1 + + """ + return self(1) + + def __str__(self): + return f"Context for nmod with modulus: {self.modulus()}" + + def __repr__(self): + return f"nmod_ctx({self.modulus()})" + + def __call__(self, val): + """Create an nmod element from an integer. + + >>> ctx = nmod_ctx.new(17) + >>> ctx(10) + 10 + + """ + r = self.new_nmod() + if not self.any_as_nmod(&r.val, val): + raise TypeError("cannot create nmod from object of type %s" % type(val)) + return r + + +@cython.no_gc cdef class nmod(flint_scalar): """ The nmod type represents elements of Z/nZ for word-size n. @@ -54,16 +116,14 @@ cdef class nmod(flint_scalar): 3 """ - def __init__(self, val, mod): - cdef mp_limb_t m - m = mod - nmod_init(&self.mod, m) - if not any_as_nmod(&self.val, val, self.mod): + cdef nmod_ctx ctx = nmod_ctx.any_as_nmod_ctx(mod) + if not ctx.any_as_nmod(&self.val, val): raise TypeError("cannot create nmod from object of type %s" % type(val)) + self.ctx = ctx def repr(self): - return "nmod(%s, %s)" % (self.val, self.mod.n) + return "nmod(%s, %s)" % (self.val, self.ctx.mod.n) def str(self): return str(int(self.val)) @@ -72,7 +132,7 @@ cdef class nmod(flint_scalar): return int(self.val) def modulus(self): - return self.mod.n + return self.ctx.mod.n def __richcmp__(s, t, int op): cdef bint res @@ -80,13 +140,13 @@ cdef class nmod(flint_scalar): raise TypeError("nmods cannot be ordered") if typecheck(s, nmod) and typecheck(t, nmod): res = ((s).val == (t).val) and \ - ((s).mod.n == (t).mod.n) + ((s).ctx.mod.n == (t).ctx.mod.n) if op == 2: return res else: return not res elif typecheck(s, nmod) and typecheck(t, int): - res = s.val == (t % s.mod.n) + res = s.val == (t % s.ctx.mod.n) if op == 2: return res else: @@ -103,101 +163,101 @@ cdef class nmod(flint_scalar): return self def __neg__(self): - cdef nmod r = nmod.__new__(nmod) - r.mod = self.mod - r.val = nmod_neg(self.val, self.mod) + r = self.ctx.new_nmod() + r.val = nmod_neg(self.val, self.ctx.mod) return r def __add__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_add(val, (s).val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_add(val, s2.val, s2.ctx.mod) return r return NotImplemented def __radd__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_add((s).val, val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_add(s2.val, val, s2.ctx.mod) return r return NotImplemented def __sub__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_sub((s).val, val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_sub(s2.val, val, s2.ctx.mod) return r return NotImplemented def __rsub__(s, t): cdef nmod r cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_sub(val, (s).val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_sub(val, s2.val, s2.ctx.mod) return r return NotImplemented def __mul__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_mul(val, (s).val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_mul(val, s2.val, s2.ctx.mod) return r return NotImplemented def __rmul__(s, t): - cdef nmod r + cdef nmod r, s2 cdef mp_limb_t val - if any_as_nmod(&val, t, (s).mod): - r = nmod.__new__(nmod) - r.mod = (s).mod - r.val = nmod_mul((s).val, val, r.mod) + s2 = s + if s2.ctx.any_as_nmod(&val, t): + r = s2.ctx.new_nmod() + r.val = nmod_mul(s2.val, val, s2.ctx.mod) return r return NotImplemented @staticmethod def _div_(s, t): - cdef nmod r + cdef nmod r, s2, t2 cdef mp_limb_t sval, tval - cdef nmod_t mod + cdef nmod_ctx ctx cdef ulong tinvval if typecheck(s, nmod): - mod = (s).mod - sval = (s).val - if not any_as_nmod(&tval, t, mod): + s2 = s + ctx = s2.ctx + sval = s2.val + if not ctx.any_as_nmod(&tval, t): return NotImplemented else: - mod = (t).mod - tval = (t).val - if not any_as_nmod(&sval, s, mod): + t2 = t + ctx = t2.ctx + tval = t2.val + if not ctx.any_as_nmod(&sval, s): return NotImplemented if tval == 0: - raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n)) if not s: return s - g = n_gcdinv(&tinvval, tval, mod.n) + g = n_gcdinv(&tinvval, tval, ctx.mod.n) if g != 1: - raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n)) + raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n)) - r = nmod.__new__(nmod) - r.mod = mod - r.val = nmod_mul(sval, tinvval, mod) + r = ctx.new_nmod() + r.val = nmod_mul(sval, tinvval, ctx.mod) return r def __truediv__(s, t): @@ -207,19 +267,21 @@ cdef class nmod(flint_scalar): return nmod._div_(t, s) def __invert__(self): - cdef nmod r + cdef nmod r, s + cdef nmod_ctx ctx cdef ulong g, inv, sval - sval = (self).val - g = n_gcdinv(&inv, sval, self.mod.n) + s = self + ctx = s.ctx + sval = s.val + g = n_gcdinv(&inv, sval, ctx.mod.n) if g != 1: - raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n)) - r = nmod.__new__(nmod) - r.mod = self.mod + raise ZeroDivisionError("%s is not invertible mod %s" % (sval, ctx.mod.n)) + r = ctx.new_nmod() r.val = inv return r def __pow__(self, exp, modulus=None): - cdef nmod r + cdef nmod r, s cdef mp_limb_t rval, mod cdef ulong g, rinv @@ -230,8 +292,10 @@ cdef class nmod(flint_scalar): if e is NotImplemented: return NotImplemented - rval = (self).val - mod = (self).mod.n + s = self + ctx = s.ctx + rval = s.val + mod = ctx.mod.n # XXX: It is not clear that it is necessary to special case negative # exponents here. The nmod_pow_fmpz function seems to handle this fine @@ -243,9 +307,8 @@ cdef class nmod(flint_scalar): rval = rinv e = -e - r = nmod.__new__(nmod) - r.mod = self.mod - r.val = nmod_pow_fmpz(rval, (e).val, self.mod) + r = ctx.new_nmod() + r.val = nmod_pow_fmpz(rval, (e).val, ctx.mod) return r def sqrt(self): @@ -267,15 +330,14 @@ cdef class nmod(flint_scalar): """ cdef nmod r cdef mp_limb_t val - r = nmod.__new__(nmod) - r.mod = self.mod + r = self.ctx.new_nmod() if self.val == 0: return r - val = n_sqrtmod(self.val, self.mod.n) + val = n_sqrtmod(self.val, self.ctx.mod.n) if val == 0: - raise DomainError("no square root exists for %s mod %s" % (self.val, self.mod.n)) + raise DomainError("no square root exists for %s mod %s" % (self.val, self.ctx.mod.n)) r.val = val return r diff --git a/src/flint/types/nmod_mat.pxd b/src/flint/types/nmod_mat.pxd index b6fe89b6..37a920b8 100644 --- a/src/flint/types/nmod_mat.pxd +++ b/src/flint/types/nmod_mat.pxd @@ -1,10 +1,125 @@ +cimport cython + +from flint.flintlib.types.flint cimport mp_limb_t, ulong +from flint.flintlib.types.nmod cimport nmod_t, nmod_mat_t + +from flint.flintlib.functions.fmpz_mat cimport ( + fmpz_mat_nrows, + fmpz_mat_ncols, + fmpz_mat_get_nmod_mat, +) +from flint.flintlib.functions.nmod_mat cimport ( + nmod_mat_init, + nmod_mat_init_set, +) + +from flint.utils.typecheck cimport typecheck from flint.flint_base.flint_base cimport flint_mat -from flint.flintlib.functions.nmod_mat cimport nmod_mat_t -from flint.flintlib.types.flint cimport mp_limb_t +from flint.types.fmpz cimport fmpz +from flint.types.fmpz_mat cimport fmpz_mat, any_as_fmpz_mat +from flint.types.nmod cimport nmod_ctx, nmod +from flint.types.nmod_poly cimport nmod_poly_ctx, nmod_poly + + +cdef dict _nmod_mat_ctx_cache + + +@cython.no_gc +cdef class nmod_mat_ctx: + cdef nmod_t mod + cdef bint _is_prime + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx poly_ctx + + @staticmethod + cdef inline nmod_mat_ctx any_as_nmod_mat_ctx(obj): + """Convert an ``nmod_mat_ctx`` or ``int`` to an ``nmod_mat_ctx``.""" + if typecheck(obj, nmod_mat_ctx): + return obj + if typecheck(obj, int): + return nmod_mat_ctx._get_ctx(obj) + elif typecheck(obj, fmpz): + return nmod_mat_ctx._get_ctx(int(obj)) + raise TypeError("nmod_mat: expected last argument to be an nmod_mat_ctx or an integer") + + @staticmethod + cdef inline nmod_mat_ctx _get_ctx(int mod): + """Retrieve an nmod_mat context from the cache or create a new one.""" + ctx = _nmod_mat_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_mat_ctx_cache.setdefault(mod, nmod_mat_ctx._new_ctx(mod)) + return ctx + + @staticmethod + cdef inline nmod_mat_ctx _new_ctx(ulong mod): + """Create a new nmod_mat context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx poly_ctx + cdef nmod_mat_ctx ctx + + poly_ctx = nmod_poly_ctx.new(mod) + scalar_ctx = poly_ctx.scalar_ctx + + ctx = nmod_mat_ctx.__new__(nmod_mat_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + ctx.poly_ctx = poly_ctx + return ctx + + @cython.final + cdef inline int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.scalar_ctx.any_as_nmod(val, obj) + + @cython.final + cdef inline any_as_nmod_mat(self, obj): + """Convert obj to nmod_mat or return NotImplemented.""" + cdef nmod_mat r + + if typecheck(obj, nmod_mat): + return obj + + x = any_as_fmpz_mat(obj) + if x is not NotImplemented: + r = self.new_nmod_mat(fmpz_mat_nrows((x).val), + fmpz_mat_ncols((x).val)) + fmpz_mat_get_nmod_mat(r.val, (x).val) + return r + + return NotImplemented + + @cython.final + cdef inline nmod new_nmod(self): + return self.scalar_ctx.new_nmod() + + @cython.final + cdef inline nmod_poly new_nmod_poly(self): + return self.poly_ctx.new_nmod_poly() + + @cython.final + cdef inline nmod_mat new_nmod_mat(self, ulong m, ulong n): + """New initialized nmod_mat of size m x n with context ctx.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init(r.val, m, n, self.mod.n) + r.ctx = self + return r + + @cython.final + cdef inline nmod_mat new_nmod_mat_copy(self, nmod_mat other): + """New copy of nmod_mat other.""" + cdef nmod_mat r = nmod_mat.__new__(nmod_mat) + nmod_mat_init_set(r.val, other.val) + r.ctx = other.ctx + return r + + +@cython.no_gc cdef class nmod_mat(flint_mat): cdef nmod_mat_t val + cdef nmod_mat_ctx ctx + cpdef long nrows(self) cpdef long ncols(self) cpdef mp_limb_t modulus(self) diff --git a/src/flint/types/nmod_mat.pyx b/src/flint/types/nmod_mat.pyx index 60f12f5c..e51fa779 100644 --- a/src/flint/types/nmod_mat.pyx +++ b/src/flint/types/nmod_mat.pyx @@ -1,14 +1,18 @@ cimport cython -from flint.flintlib.types.flint cimport ulong, mp_limb_t -from flint.flintlib.functions.nmod cimport nmod_t +from flint.utils.typecheck cimport typecheck +from flint.pyflint cimport global_random_state -from flint.flintlib.functions.nmod_poly cimport ( - nmod_poly_init, +from flint.flint_base.flint_context cimport thectx +from flint.flint_base.flint_base cimport flint_mat + +from flint.flintlib.functions.fmpz_mat cimport ( + fmpz_mat_nrows, + fmpz_mat_ncols, + fmpz_mat_get_nmod_mat, ) -from flint.flintlib.functions.fmpz_mat cimport fmpz_mat_nrows, fmpz_mat_ncols -from flint.flintlib.functions.fmpz_mat cimport fmpz_mat_get_nmod_mat +from flint.flintlib.types.flint cimport ulong, mp_limb_t from flint.flintlib.types.nmod cimport nmod_mat_struct @@ -44,36 +48,85 @@ from flint.flintlib.functions.nmod_mat cimport ( nmod_mat_randtest, ) -from flint.utils.typecheck cimport typecheck -from flint.types.fmpz_mat cimport any_as_fmpz_mat +from flint.types.fmpz cimport fmpz from flint.types.fmpz_mat cimport fmpz_mat from flint.types.nmod cimport nmod -from flint.types.nmod cimport any_as_nmod from flint.types.nmod_poly cimport nmod_poly -from flint.pyflint cimport global_random_state -from flint.flint_base.flint_context cimport thectx -from flint.flint_base.flint_base cimport flint_mat +from flint.utils.flint_exceptions import DomainError ctx = thectx -cdef any_as_nmod_mat(obj, nmod_t mod): - cdef nmod_mat r - if typecheck(obj, nmod_mat): - return obj - x = any_as_fmpz_mat(obj) - if x is not NotImplemented: - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, - fmpz_mat_nrows((x).val), - fmpz_mat_ncols((x).val), mod.n) - fmpz_mat_get_nmod_mat(r.val, (x).val) - return r - return NotImplemented +_nmod_mat_ctx_cache = {} + + +@cython.no_gc +cdef class nmod_mat_ctx: + """ + Context object for creating :class:`~.nmod_mat` initalised + with modulus :math:`N`. + + >>> ctx = nmod_mat_ctx.new(17) + >>> M = ctx([[1,2],[3,4]]) + >>> M + [1, 2] + [3, 4] + + """ + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_poly_ctx directly: use nmod_poly_ctx.new()") + + @staticmethod + def new(mod): + """Get an ``nmod_poly`` context with modulus ``mod``.""" + return nmod_mat_ctx.any_as_nmod_mat_ctx(mod) + + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_mat_ctx.new(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod.n) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_mat_ctx.new(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + def __str__(self): + return f"Context for nmod_mat with modulus: {self.mod.n}" + def __repr__(self): + return f"nmod_mat_ctx({self.mod.n})" + + def __call__(self, *args): + """Create an ``nmod_mat``. + + >>> mat5 = nmod_mat_ctx.new(5) + >>> M = mat5([[1,2],[3,4]]) + >>> M + [1, 2] + [3, 4] + >>> M2 = mat5(2, 3, [1,2,3,4,5,6]) + >>> M2 + [1, 2, 3] + [4, 0, 1] + + """ + return nmod_mat(*args, self) + + +@cython.no_gc cdef class nmod_mat(flint_mat): """ The nmod_mat type represents dense matrices over Z/nZ for word-size n (see @@ -81,8 +134,8 @@ cdef class nmod_mat(flint_mat): Some operations may assume that n is a prime. """ - # cdef nmod_mat_t val +# cdef nmod_mat_ctx ctx def __dealloc__(self): nmod_mat_clear(self.val) @@ -90,21 +143,28 @@ cdef class nmod_mat(flint_mat): @cython.embedsignature(False) def __init__(self, *args): cdef long m, n, i, j - cdef mp_limb_t mod + cdef nmod_mat_ctx ctx + if len(args) == 1: val = args[0] if typecheck(val, nmod_mat): nmod_mat_init_set(self.val, (val).val) + self.ctx = (val).ctx return + mod = args[-1] args = args[:-1] + + self.ctx = ctx = nmod_mat_ctx.any_as_nmod_mat_ctx(mod) + if mod == 0: raise ValueError("modulus must be nonzero") + if len(args) == 1: val = args[0] if typecheck(val, fmpz_mat): nmod_mat_init(self.val, fmpz_mat_nrows((val).val), - fmpz_mat_ncols((val).val), mod) + fmpz_mat_ncols((val).val), ctx.mod.n) fmpz_mat_get_nmod_mat(self.val, (val).val) elif isinstance(val, (list, tuple)): m = len(val) @@ -116,27 +176,27 @@ cdef class nmod_mat(flint_mat): for i from 1 <= i < m: if len(val[i]) != n: raise ValueError("input rows have different lengths") - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) for i from 0 <= i < m: row = val[i] for j from 0 <= j < n: - x = nmod(row[j], mod) - self.val.rows[i][j] = (x).val + if not ctx.any_as_nmod(&self.val.rows[i][j], row[j]): + raise TypeError("cannot create nmod from input of type %s" % type(row[j])) else: raise TypeError("cannot create nmod_mat from input of type %s" % type(val)) elif len(args) == 2: m, n = args - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) elif len(args) == 3: m, n, entries = args - nmod_mat_init(self.val, m, n, mod) + nmod_mat_init(self.val, m, n, ctx.mod.n) entries = list(entries) if len(entries) != m*n: raise ValueError("list of entries has the wrong length") for i from 0 <= i < m: for j from 0 <= j < n: - x = nmod(entries[i*n + j], mod) # XXX: slow - self.val.rows[i][j] = (x).val + if not ctx.any_as_nmod(&self.val.rows[i][j], entries[i*n + j]): + raise TypeError("cannot create nmod from input of type %s" % type(entries[i*n + j])) else: raise TypeError("nmod_mat: expected 1-3 arguments plus modulus") @@ -202,7 +262,8 @@ cdef class nmod_mat(flint_mat): i, j = index if i < 0 or i >= self.nrows() or j < 0 or j >= self.ncols(): raise IndexError("index %i,%i exceeds matrix dimensions" % (i, j)) - x = nmod(nmod_mat_entry(self.val, i, j), self.modulus()) # XXX: slow + x = self.ctx.new_nmod() + x.val = nmod_mat_entry(self.val, i, j) return x def __setitem__(self, index, value): @@ -211,7 +272,7 @@ cdef class nmod_mat(flint_mat): i, j = index if i < 0 or i >= self.nrows() or j < 0 or j >= self.ncols(): raise IndexError("index %i,%i exceeds matrix dimensions" % (i, j)) - if any_as_nmod(&v, value, self.val.mod): + if self.ctx.any_as_nmod(&v, value): nmod_mat_set_entry(self.val, i, j, v) else: raise TypeError("cannot set item of type %s" % type(value)) @@ -229,7 +290,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -237,8 +298,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot add nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -247,16 +307,15 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot add nmod_mats with different moduli") + raise ValueError("cannot add nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix addition") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_add(r.val, sv, tv) return r @@ -265,7 +324,7 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] @@ -273,8 +332,7 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot subtract nmod_mats with different moduli") if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_sub(r.val, sv, tv) return r @@ -283,22 +341,21 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv sv = &(s).val[0] - t = any_as_nmod_mat(t, sv.mod) + t = s.ctx.any_as_nmod_mat(t) if t is NotImplemented: return t tv = &(t).val[0] if sv.mod.n != tv.mod.n: - raise ValueError("cannot subtract nmod_mats with different moduli") + raise ValueError("cannot subtract nmod_mats with different moduli") # pragma: no cover if sv.r != tv.r or sv.c != tv.c: raise ValueError("incompatible shapes for matrix subtraction") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n) + r = s.ctx.new_nmod_mat(sv.r, sv.c) nmod_mat_sub(r.val, tv, sv) return r cdef __mul_nmod(self, mp_limb_t c): - cdef nmod_mat r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, self.val.r, self.val.c, self.val.mod.n) + cdef nmod_mat r + r = self.ctx.new_nmod_mat(self.val.r, self.val.c) nmod_mat_scalar_mul(r.val, self.val, c) return r @@ -307,10 +364,10 @@ cdef class nmod_mat(flint_mat): cdef nmod_mat_struct *sv cdef nmod_mat_struct *tv cdef mp_limb_t c - sv = &(s).val[0] - u = any_as_nmod_mat(t, sv.mod) + sv = &s.val[0] + u = s.ctx.any_as_nmod_mat(t) if u is NotImplemented: - if any_as_nmod(&c, t, sv.mod): + if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) return NotImplemented tv = &(u).val[0] @@ -318,18 +375,15 @@ cdef class nmod_mat(flint_mat): raise ValueError("cannot multiply nmod_mats with different moduli") if sv.c != tv.r: raise ValueError("incompatible shapes for matrix multiplication") - r = nmod_mat.__new__(nmod_mat) - nmod_mat_init(r.val, sv.r, tv.c, sv.mod.n) + r = s.ctx.new_nmod_mat(sv.r, tv.c) nmod_mat_mul(r.val, sv, tv) return r def __rmul__(s, t): - cdef nmod_mat_struct *sv cdef mp_limb_t c - sv = &(s).val[0] - if any_as_nmod(&c, t, sv.mod): + if s.ctx.any_as_nmod(&c, t): return (s).__mul_nmod(c) - u = any_as_nmod_mat(t, sv.mod) + u = s.ctx.any_as_nmod_mat(t) if u is NotImplemented: return u return u * s @@ -342,27 +396,34 @@ cdef class nmod_mat(flint_mat): if m is not None: raise NotImplementedError("modular matrix exponentiation") if e < 0: + if not self.ctx._is_prime: + raise DomainError("negative matrix power needs prime modulus") self = self.inv() e = -e ee = e - t = nmod_mat(self) # XXX + t = self.ctx.new_nmod_mat_copy(self) nmod_mat_pow(t.val, t.val, ee) return t @staticmethod def _div_(nmod_mat s, t): cdef mp_limb_t v - if not any_as_nmod(&v, t, s.val.mod): + if not s.ctx.any_as_nmod(&v, t): return NotImplemented t = nmod(v, s.val.mod.n) - return s * (~t) + try: + tinv = ~t + except ZeroDivisionError: + # XXX: Maybe nmod.__invert__ should raise DomainError instead? + if t == 0: + raise ZeroDivisionError("division by zero") + else: + raise DomainError("nmod_mat division: modulus must be prime") + return s * tinv def __truediv__(s, t): return nmod_mat._div_(s, t) - def __div__(s, t): - return nmod_mat._div_(s, t) - def det(self): """ Returns the determinant of self as an nmod. @@ -371,9 +432,13 @@ cdef class nmod_mat(flint_mat): 15 """ + cdef nmod r if not nmod_mat_is_square(self.val): raise ValueError("matrix must be square") - return nmod(nmod_mat_det(self.val), self.modulus()) + + r = self.ctx.new_nmod() + r.val = nmod_mat_det(self.val) + return r def inv(self): """ @@ -387,11 +452,14 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat u + if not nmod_mat_is_square(self.val): raise ValueError("matrix must be square") - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_nrows(self.val), - nmod_mat_ncols(self.val), self.val.mod.n) + if not self.ctx._is_prime: + raise DomainError("nmod_mat inv: modulus must be prime") + + u = self.ctx.new_nmod_mat(nmod_mat_nrows(self.val), + nmod_mat_ncols(self.val)) if not nmod_mat_inv(u.val, self.val): raise ZeroDivisionError("matrix is singular") return u @@ -406,9 +474,8 @@ cdef class nmod_mat(flint_mat): [2, 5] """ cdef nmod_mat u - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_ncols(self.val), - nmod_mat_nrows(self.val), self.val.mod.n) + u = self.ctx.new_nmod_mat(nmod_mat_ncols(self.val), + nmod_mat_nrows(self.val)) nmod_mat_transpose(u.val, self.val) return u @@ -437,15 +504,18 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat u cdef int result - t = any_as_nmod_mat(other, self.val.mod) + t = self.ctx.any_as_nmod_mat(other) if t is NotImplemented: raise TypeError("cannot convert input to nmod_mat") if (nmod_mat_nrows(self.val) != nmod_mat_ncols(self.val) or nmod_mat_nrows(self.val) != nmod_mat_nrows((t).val)): raise ValueError("need a square system and compatible right hand side") - u = nmod_mat.__new__(nmod_mat) - nmod_mat_init(u.val, nmod_mat_nrows((t).val), - nmod_mat_ncols((t).val), self.val.mod.n) + # XXX: Should check for same modulus. + if not self.ctx._is_prime: + raise DomainError("nmod_mat solve: modulus must be prime") + + u = self.ctx.new_nmod_mat(nmod_mat_nrows((t).val), + nmod_mat_ncols((t).val)) result = nmod_mat_solve(u.val, self.val, (t).val) if not result: raise ZeroDivisionError("singular matrix in solve()") @@ -472,11 +542,13 @@ cdef class nmod_mat(flint_mat): [0, 0, 0] """ + if not self.ctx._is_prime: + raise DomainError("rref only works for prime moduli") + if inplace: res = self else: - res = nmod_mat.__new__(nmod_mat) - nmod_mat_init_set((res).val, self.val) + res = self.ctx.new_nmod_mat_copy(self) rank = nmod_mat_rref((res).val) return res, rank @@ -488,6 +560,8 @@ cdef class nmod_mat(flint_mat): >>> M.rank() 2 """ + if not self.ctx._is_prime: + raise DomainError("rank only works for prime moduli") return nmod_mat_rank(self.val) def nullspace(self): @@ -514,8 +588,8 @@ cdef class nmod_mat(flint_mat): """ cdef nmod_mat res - res = nmod_mat.__new__(nmod_mat) - nmod_mat_init(res.val, nmod_mat_ncols(self.val), nmod_mat_ncols(self.val), self.val.mod.n) + res = self.ctx.new_nmod_mat(nmod_mat_ncols(self.val), + nmod_mat_ncols(self.val)) nullity = nmod_mat_nullspace(res.val, self.val) return res, nullity @@ -533,8 +607,7 @@ cdef class nmod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat charpoly: matrix must be square") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res.val, self.val.mod.n) + res = self.ctx.new_nmod_poly() nmod_mat_charpoly(res.val, self.val) return res @@ -553,8 +626,9 @@ cdef class nmod_mat(flint_mat): if self.nrows() != self.ncols(): raise ValueError("fmpz_mod_mat minpoly: matrix must be square") + if not self.ctx._is_prime: + raise DomainError("minpoly only works for prime moduli") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res.val, self.val.mod.n) + res = self.ctx.new_nmod_poly() nmod_mat_minpoly(res.val, self.val) return res diff --git a/src/flint/types/nmod_poly.pxd b/src/flint/types/nmod_poly.pxd index fe872bde..dd47c5e4 100644 --- a/src/flint/types/nmod_poly.pxd +++ b/src/flint/types/nmod_poly.pxd @@ -1,10 +1,116 @@ +cimport cython + +from cpython.list cimport PyList_GET_SIZE + +from flint.flintlib.types.flint cimport mp_limb_t, ulong +from flint.flintlib.types.nmod cimport nmod_t, nmod_poly_t + +from flint.flintlib.functions.fmpz_poly cimport fmpz_poly_get_nmod_poly +from flint.flintlib.functions.nmod_poly cimport ( + nmod_poly_t, + nmod_poly_init_preinv, + nmod_poly_fit_length, + nmod_poly_set_coeff_ui, +) + +from flint.utils.typecheck cimport typecheck from flint.flint_base.flint_base cimport flint_poly -from flint.flintlib.functions.nmod_poly cimport nmod_poly_t -from flint.flintlib.types.flint cimport mp_limb_t +from flint.types.fmpz_poly cimport fmpz_poly, any_as_fmpz_poly +from flint.types.nmod cimport nmod_ctx, nmod + + +cdef dict _nmod_poly_ctx_cache = {} + + +@cython.no_gc +cdef class nmod_poly_ctx: + cdef nmod_t mod + cdef bint _is_prime + cdef nmod_ctx scalar_ctx + + @staticmethod + cdef inline nmod_poly_ctx any_as_nmod_poly_ctx(obj): + """Convert an ``nmod_poly_ctx`` or ``int`` to an ``nmod_poly_ctx``.""" + if typecheck(obj, nmod_poly_ctx): + return obj + if typecheck(obj, int): + return nmod_poly_ctx._get_ctx(obj) + raise TypeError("Invalid context/modulus for nmod_poly: %s" % obj) + + @staticmethod + cdef inline nmod_poly_ctx _get_ctx(int mod): + """Retrieve an nmod_poly context from the cache or create a new one.""" + ctx = _nmod_poly_ctx_cache.get(mod) + if ctx is None: + ctx = _nmod_poly_ctx_cache.setdefault(mod, nmod_poly_ctx._new_ctx(mod)) + return ctx + @staticmethod + cdef inline nmod_poly_ctx _new_ctx(ulong mod): + """Create a new nmod_poly context.""" + cdef nmod_ctx scalar_ctx + cdef nmod_poly_ctx ctx + scalar_ctx = nmod_ctx.new(mod) + + ctx = nmod_poly_ctx.__new__(nmod_poly_ctx) + ctx.mod = scalar_ctx.mod + ctx._is_prime = scalar_ctx._is_prime + ctx.scalar_ctx = scalar_ctx + + return ctx + + @cython.final + cdef inline int any_as_nmod(self, mp_limb_t * val, obj) except -1: + return self.scalar_ctx.any_as_nmod(val, obj) + + @cython.final + cdef inline any_as_nmod_poly(self, obj): + cdef nmod_poly r + cdef mp_limb_t v + # XXX: should check that modulus is the same here, and not all over the place + if typecheck(obj, nmod_poly): + return obj + if self.any_as_nmod(&v, obj): + r = self.new_nmod_poly() + nmod_poly_set_coeff_ui(r.val, 0, v) + return r + x = any_as_fmpz_poly(obj) + if x is not NotImplemented: + r = self.new_nmod_poly() + fmpz_poly_get_nmod_poly(r.val, (x).val) + return r + return NotImplemented + + @cython.final + cdef inline nmod new_nmod(self): + return self.scalar_ctx.new_nmod() + + @cython.final + cdef inline nmod_poly new_nmod_poly(self): + cdef nmod_poly p = nmod_poly.__new__(nmod_poly) + nmod_poly_init_preinv(p.val, self.mod.n, self.mod.ninv) + p.ctx = self + return p + + @cython.final + cdef inline nmod_poly_set_list(self, nmod_poly_t poly, list val): + cdef long i, n + cdef mp_limb_t v + n = PyList_GET_SIZE(val) + nmod_poly_fit_length(poly, n) + for i from 0 <= i < n: + if self.any_as_nmod(&v, val[i]): + nmod_poly_set_coeff_ui(poly, i, v) + else: + raise TypeError("unsupported coefficient in list") + + +@cython.no_gc cdef class nmod_poly(flint_poly): cdef nmod_poly_t val + cdef nmod_poly_ctx ctx + cpdef long length(self) cpdef long degree(self) cpdef mp_limb_t modulus(self) diff --git a/src/flint/types/nmod_poly.pyx b/src/flint/types/nmod_poly.pyx index 8727d601..486ad3f4 100644 --- a/src/flint/types/nmod_poly.pyx +++ b/src/flint/types/nmod_poly.pyx @@ -1,52 +1,150 @@ -from cpython.list cimport PyList_GET_SIZE +cimport cython + +from flint.flintlib.types.flint cimport mp_limb_t, ulong, slong +from flint.flintlib.functions.fmpz_poly cimport fmpz_poly_get_nmod_poly +from flint.flintlib.functions.nmod_poly cimport ( + nmod_poly_init, + nmod_poly_clear, + nmod_poly_set, + nmod_poly_fit_length, + nmod_poly_zero, + nmod_poly_set_coeff_ui, + nmod_poly_get_coeff_ui, + nmod_poly_length, + nmod_poly_degree, + nmod_poly_modulus, + nmod_poly_equal, + nmod_poly_is_zero, + nmod_poly_is_one, + nmod_poly_is_gen, + nmod_poly_reverse, + nmod_poly_evaluate_nmod, + nmod_poly_derivative, + nmod_poly_integral, + nmod_poly_neg, + nmod_poly_add, + nmod_poly_sub, + nmod_poly_mul, + nmod_poly_div, + nmod_poly_divrem, + nmod_poly_pow, + nmod_poly_compose, + nmod_poly_compose_mod, + nmod_poly_powmod_ui_binexp, + nmod_poly_powmod_fmpz_binexp_preinv, + nmod_poly_powmod_x_fmpz_preinv, + nmod_poly_inv_series, + nmod_poly_sqrt, + nmod_poly_xgcd, + nmod_poly_gcd, + nmod_poly_deflation, + nmod_poly_deflate, +) +from flint.flintlib.functions.nmod_poly_factor cimport ( + nmod_poly_factor_t, + nmod_poly_factor_init, + nmod_poly_factor_clear, + nmod_poly_factor, + nmod_poly_factor_with_berlekamp, + nmod_poly_factor_with_cantor_zassenhaus, + nmod_poly_factor_squarefree, +) + from flint.flint_base.flint_base cimport flint_poly from flint.utils.typecheck cimport typecheck + from flint.types.fmpz cimport fmpz, any_as_fmpz -from flint.types.fmpz_poly cimport any_as_fmpz_poly from flint.types.fmpz_poly cimport fmpz_poly -from flint.types.nmod cimport any_as_nmod from flint.types.nmod cimport nmod -from flint.flintlib.functions.nmod cimport nmod_init -from flint.flintlib.functions.nmod_poly cimport * -from flint.flintlib.functions.nmod_poly_factor cimport * -from flint.flintlib.functions.fmpz_poly cimport fmpz_poly_get_nmod_poly - from flint.utils.flint_exceptions import DomainError -cdef any_as_nmod_poly(obj, nmod_t mod): - cdef nmod_poly r - cdef mp_limb_t v - # XXX: should check that modulus is the same here, and not all over the place - if typecheck(obj, nmod_poly): - return obj - if any_as_nmod(&v, obj, mod): - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) - nmod_poly_set_coeff_ui(r.val, 0, v) +_nmod_poly_ctx_cache = {} + + +@cython.no_gc +cdef class nmod_poly_ctx: + """ + Context object for creating :class:`~.nmod_poly` initalised + with modulus :math:`N`. + + >>> nmod_poly_ctx.new(17) + nmod_poly_ctx(17) + + """ + def __init__(self, *args, **kwargs): + raise TypeError("cannot create nmod_poly_ctx directly: use nmod_poly_ctx.new()") + + @staticmethod + def new(mod): + """Get an ``nmod_poly`` context with modulus ``mod``.""" + return nmod_poly_ctx.any_as_nmod_poly_ctx(mod) + + def modulus(self): + """Get the modulus of the context. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.modulus() + 17 + + """ + return fmpz(self.mod.n) + + def is_prime(self): + """Check if the modulus is prime. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.is_prime() + True + + """ + return self._is_prime + + def zero(self): + """Return the zero ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.zero() + 0 + + """ + cdef nmod_poly r = self.new_nmod_poly() + nmod_poly_zero(r.val) return r - x = any_as_fmpz_poly(obj) - if x is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this? - fmpz_poly_get_nmod_poly(r.val, (x).val) + + def one(self): + """Return the one ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx.one() + 1 + + """ + cdef nmod_poly r = self.new_nmod_poly() + nmod_poly_set_coeff_ui(r.val, 0, 1) return r - return NotImplemented - -cdef nmod_poly_set_list(nmod_poly_t poly, list val): - cdef long i, n - cdef nmod_t mod - cdef mp_limb_t v - nmod_init(&mod, nmod_poly_modulus(poly)) # XXX - n = PyList_GET_SIZE(val) - nmod_poly_fit_length(poly, n) - for i from 0 <= i < n: - if any_as_nmod(&v, val[i], mod): - nmod_poly_set_coeff_ui(poly, i, v) - else: - raise TypeError("unsupported coefficient in list") + def __str__(self): + return f"Context for nmod_poly with modulus: {self.mod.n}" + + def __repr__(self): + return f"nmod_poly_ctx({self.mod.n})" + + def __call__(self, arg): + """Create an ``nmod_poly``. + + >>> ctx = nmod_poly_ctx.new(17) + >>> ctx(10) + 10 + >>> ctx([1,2,3]) + 3*x^2 + 2*x + 1 + + """ + return nmod_poly(arg, self) + + +@cython.no_gc cdef class nmod_poly(flint_poly): """ The nmod_poly type represents dense univariate polynomials @@ -77,24 +175,29 @@ cdef class nmod_poly(flint_poly): def __dealloc__(self): nmod_poly_clear(self.val) - def __init__(self, val=None, ulong mod=0): + def __init__(self, val=None, mod=0): cdef ulong m2 cdef mp_limb_t v + cdef nmod_poly_ctx ctx + if typecheck(val, nmod_poly): m2 = nmod_poly_modulus((val).val) if m2 != mod: raise ValueError("different moduli!") nmod_poly_init(self.val, m2) nmod_poly_set(self.val, (val).val) + self.ctx = (val).ctx else: if mod == 0: raise ValueError("a nonzero modulus is required") - nmod_poly_init(self.val, mod) + ctx = nmod_poly_ctx.any_as_nmod_poly_ctx(mod) + self.ctx = ctx + nmod_poly_init(self.val, ctx.mod.n) if typecheck(val, fmpz_poly): fmpz_poly_get_nmod_poly(self.val, (val).val) elif typecheck(val, list): - nmod_poly_set_list(self.val, val) - elif any_as_nmod(&v, val, self.val.mod): + ctx.nmod_poly_set_list(self.val, val) + elif ctx.any_as_nmod(&v, val): nmod_poly_fit_length(self.val, 1) nmod_poly_set_coeff_ui(self.val, 0, v) else: @@ -113,10 +216,12 @@ cdef class nmod_poly(flint_poly): return nmod_poly_modulus(self.val) def __richcmp__(s, t, int op): + cdef mp_limb_t v + cdef slong length cdef bint res if op != 2 and op != 3: raise TypeError("nmod_polys cannot be ordered") - if typecheck(s, nmod_poly) and typecheck(t, nmod_poly): + if typecheck(t, nmod_poly): if (s).val.mod.n != (t).val.mod.n: res = False else: @@ -125,22 +230,19 @@ cdef class nmod_poly(flint_poly): return res if op == 3: return not res - else: - if not typecheck(s, nmod_poly): - s, t = t, s - try: - t = nmod_poly([t], (s).val.mod.n) - except TypeError: - pass - if typecheck(s, nmod_poly) and typecheck(t, nmod_poly): - if (s).val.mod.n != (t).val.mod.n: - res = False - else: - res = nmod_poly_equal((s).val, (t).val) - if op == 2: - return res - if op == 3: - return not res + + # zero or constant poly can be equal to a scalar + length = nmod_poly_length(s.val) + if length <= 1 and s.ctx.any_as_nmod(&v, t): + if length == 0: + res = (v == 0) + else: + res = (v == nmod_poly_get_coeff_ui(s.val, 0)) + if op == 2: + return res + if op == 3: + return not res + return NotImplemented def __iter__(self): @@ -175,7 +277,7 @@ cdef class nmod_poly(flint_poly): cdef mp_limb_t v if i < 0: raise ValueError("cannot assign to index < 0 of polynomial") - if any_as_nmod(&v, x, self.val.mod): + if self.ctx.any_as_nmod(&v, x): nmod_poly_set_coeff_ui(self.val, i, v) else: raise TypeError("cannot set element of type %s" % type(x)) @@ -219,8 +321,7 @@ cdef class nmod_poly(flint_poly): else: length = nmod_poly_length(self.val) - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = self.ctx.new_nmod_poly() nmod_poly_reverse(res.val, self.val, length) return res @@ -242,10 +343,8 @@ cdef class nmod_poly(flint_poly): else: cu = nmod_poly_get_coeff_ui(self.val, d) - c = nmod.__new__(nmod) - c.mod = self.val.mod + c = self.ctx.new_nmod() c.val = cu - return c def inverse_series_trunc(self, slong n): @@ -264,12 +363,13 @@ cdef class nmod_poly(flint_poly): if n <= 0: raise ValueError(f"n = {n} must be positive") - if self.is_zero(): - raise ValueError("cannot invert the zero element") + if nmod_poly_get_coeff_ui(self.val, 0) == 0: + raise ZeroDivisionError("nmod_poly inverse_series_trunc: leading coefficient is zero") - cdef nmod_poly res - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly inverse_series_trunc: modulus {self.ctx.mod.n} is not prime") + + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_inv_series(res.val, self.val, n) return res @@ -288,11 +388,10 @@ cdef class nmod_poly(flint_poly): 9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1 """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = self.ctx.new_nmod_poly() nmod_poly_compose(res.val, self.val, (other).val) return res @@ -313,45 +412,42 @@ cdef class nmod_poly(flint_poly): 147*x^3 + 159*x^2 + 4*x + 7 """ cdef nmod_poly res - g = any_as_nmod_poly(other, self.val.mod) + g = self.ctx.any_as_nmod_poly(other) if g is NotImplemented: raise TypeError(f"cannot convert other = {other} to nmod_poly") - h = any_as_nmod_poly(modulus, self.val.mod) + h = self.ctx.any_as_nmod_poly(modulus) if h is NotImplemented: raise TypeError(f"cannot convert modulus = {modulus} to nmod_poly") if modulus.is_zero(): raise ZeroDivisionError("cannot reduce modulo zero") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = self.ctx.new_nmod_poly() nmod_poly_compose_mod(res.val, self.val, (other).val, (modulus).val) return res def __call__(self, other): + cdef nmod_poly r cdef mp_limb_t c - if any_as_nmod(&c, other, self.val.mod): + if self.ctx.any_as_nmod(&c, other): v = nmod(0, self.modulus()) (v).val = nmod_poly_evaluate_nmod(self.val, c) return v - t = any_as_nmod_poly(other, self.val.mod) + t = self.ctx.any_as_nmod_poly(other) if t is not NotImplemented: - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv((r).val, self.val.mod.n, self.val.mod.ninv) - nmod_poly_compose((r).val, self.val, (t).val) + r = self.ctx.new_nmod_poly() + nmod_poly_compose(r.val, self.val, (t).val) return r raise TypeError("cannot call nmod_poly with input of type %s", type(other)) def derivative(self): - cdef nmod_poly res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_derivative(res.val, self.val) return res def integral(self): - cdef nmod_poly res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly res = self.ctx.new_nmod_poly() nmod_poly_integral(res.val, self.val) return res @@ -359,20 +455,18 @@ cdef class nmod_poly(flint_poly): return self def __neg__(self): - cdef nmod_poly r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, self.val.mod.n, self.val.mod.ninv) + cdef nmod_poly r = self.ctx.new_nmod_poly() nmod_poly_neg(r.val, self.val) return r def _add_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot add nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = s.ctx.new_nmod_poly() nmod_poly_add(r.val, (s).val, (t).val) return r @@ -386,32 +480,30 @@ cdef class nmod_poly(flint_poly): cdef nmod_poly r if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot subtract nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = s.ctx.new_nmod_poly() nmod_poly_sub(r.val, (s).val, (t).val) return r def __sub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._sub_(t) def __rsub__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._sub_(s) def _mul_(s, t): cdef nmod_poly r - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot multiply nmod_polys with different moduli") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + r = s.ctx.new_nmod_poly() nmod_poly_mul(r.val, (s).val, (t).val) return r @@ -422,7 +514,7 @@ cdef class nmod_poly(flint_poly): return s._mul_(t) def __truediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = s._divmod_(t) @@ -431,7 +523,7 @@ cdef class nmod_poly(flint_poly): return res def __rtruediv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t res, r = t._divmod_(s) @@ -441,48 +533,53 @@ cdef class nmod_poly(flint_poly): def _floordiv_(s, t): cdef nmod_poly r + if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") - r = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(r.val, (t).val.mod.n, (t).val.mod.ninv) + if not s.ctx._is_prime: + raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") + + r = s.ctx.new_nmod_poly() nmod_poly_div(r.val, (s).val, (t).val) return r def __floordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._floordiv_(t) def __rfloordiv__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._floordiv_(s) def _divmod_(s, t): cdef nmod_poly P, Q + if (s).val.mod.n != (t).val.mod.n: raise ValueError("cannot divide nmod_polys with different moduli") if nmod_poly_is_zero((t).val): raise ZeroDivisionError("polynomial division by zero") - P = nmod_poly.__new__(nmod_poly) - Q = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(P.val, (t).val.mod.n, (t).val.mod.ninv) - nmod_poly_init_preinv(Q.val, (t).val.mod.n, (t).val.mod.ninv) + if not s.ctx._is_prime: + raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime") + + P = s.ctx.new_nmod_poly() + Q = s.ctx.new_nmod_poly() nmod_poly_divrem(P.val, Q.val, (s).val, (t).val) return P, Q def __divmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return s._divmod_(t) def __rdivmod__(s, t): - t = any_as_nmod_poly(t, (s).val.mod) + t = s.ctx.any_as_nmod_poly(t) if t is NotImplemented: return t return t._divmod_(s) @@ -499,8 +596,7 @@ cdef class nmod_poly(flint_poly): return self.pow_mod(exp, mod) if exp < 0: raise ValueError("negative exponent") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, (self).val.mod.n, (self).val.mod.ninv) + res = self.ctx.new_nmod_poly() nmod_poly_pow(res.val, self.val, exp) return res @@ -517,7 +613,6 @@ cdef class nmod_poly(flint_poly): >>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65 >>> g = 43*x**6 + 91*x**5 + 77*x**4 + 113*x**3 + 71*x**2 + 132*x + 60 >>> mod = x**4 + 93*x**3 + 78*x**2 + 72*x + 149 - >>> >>> f.pow_mod(123, mod) 3*x^3 + 25*x^2 + 115*x + 161 >>> f.pow_mod(2**64, mod) @@ -528,32 +623,31 @@ cdef class nmod_poly(flint_poly): """ cdef nmod_poly res - if e < 0: - raise ValueError("Exponent must be non-negative") - - modulus = any_as_nmod_poly(modulus, (self).val.mod) + modulus = self.ctx.any_as_nmod_poly(modulus) if modulus is NotImplemented: raise TypeError("cannot convert input to nmod_poly") + # For larger exponents we need an fmpz + e_fmpz = any_as_fmpz(e) + if e_fmpz is NotImplemented: + raise TypeError(f"exponent cannot be cast to an fmpz type: {e}") + + if e < 0: + raise ValueError("Exponent must be non-negative") + # Output polynomial - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + res = self.ctx.new_nmod_poly() # For small exponents, use a simple binary exponentiation method if e.bit_length() < 32: nmod_poly_powmod_ui_binexp( - res.val, self.val, e, (modulus).val + res.val, self.val, int(e), (modulus).val ) return res - # For larger exponents we need to cast e to an fmpz first - e_fmpz = any_as_fmpz(e) - if e_fmpz is NotImplemented: - raise TypeError(f"exponent cannot be cast to an fmpz type: {e}") - # To optimise powering, we precompute the inverse of the reverse of the modulus if mod_rev_inv is not None: - mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (self).val.mod) + mod_rev_inv = self.ctx.any_as_nmod_poly(mod_rev_inv) if mod_rev_inv is NotImplemented: raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial") else: @@ -580,30 +674,52 @@ cdef class nmod_poly(flint_poly): >>> (A * B).gcd(B) * 5 5*x^2 + x + 4 + The modulus must be prime. """ cdef nmod_poly res - other = any_as_nmod_poly(other, (self).val.mod) + + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to nmod_poly") if self.val.mod.n != (other).val.mod.n: raise ValueError("moduli must be the same") - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) + if not self.ctx._is_prime: + raise DomainError("nmod_poly gcd: modulus {self.ctx.mod.n} is not prime") + + res = self.ctx.new_nmod_poly() nmod_poly_gcd(res.val, self.val, (other).val) return res def xgcd(self, other): + r""" + Computes the extended gcd of self and other: (`G`, `S`, `T`) + where `G` is the ``gcd(self, other)`` and `S`, `T` are such that: + + :math:`G = \textrm{self}*S + \textrm{other}*T` + + >>> f = nmod_poly([143, 19, 37, 138, 102, 127, 95], 163) + >>> g = nmod_poly([139, 9, 35, 154, 87, 120, 24], 163) + >>> f.xgcd(g) + (x^3 + 128*x^2 + 123*x + 91, 17*x^2 + 49*x + 104, 21*x^2 + 5*x + 25) + + The modulus must be prime. + """ cdef nmod_poly res1, res2, res3 - other = any_as_nmod_poly(other, (self).val.mod) + + other = self.ctx.any_as_nmod_poly(other) if other is NotImplemented: raise TypeError("cannot convert input to fmpq_poly") - res1 = nmod_poly.__new__(nmod_poly) - res2 = nmod_poly.__new__(nmod_poly) - res3 = nmod_poly.__new__(nmod_poly) - nmod_poly_init(res1.val, (self).val.mod.n) - nmod_poly_init(res2.val, (self).val.mod.n) - nmod_poly_init(res3.val, (self).val.mod.n) + if self.val.mod.n != (other).val.mod.n: + raise ValueError("moduli must be the same") + if not self.ctx._is_prime: + raise DomainError("nmod_poly xgcd: modulus {self.ctx.mod.n} is not prime") + + res1 = self.ctx.new_nmod_poly() + res2 = self.ctx.new_nmod_poly() + res3 = self.ctx.new_nmod_poly() + nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (other).val) + return (res1, res2, res3) def factor(self, algorithm=None): @@ -628,11 +744,14 @@ cdef class nmod_poly(flint_poly): >>> nmod_poly([3,2,1,2,3], 7).factor(algorithm='cantor-zassenhaus') (3, [(x + 4, 1), (x + 2, 1), (x^2 + 4*x + 1, 1)]) + The modulus must be prime. """ if algorithm is None: algorithm = 'irreducible' elif algorithm not in ('berlekamp', 'cantor-zassenhaus'): raise ValueError(f"unknown factorization algorithm: {algorithm}") + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly factor: modulus {self.ctx.mod.n} is not prime") return self._factor(algorithm) def factor_squarefree(self): @@ -651,10 +770,14 @@ cdef class nmod_poly(flint_poly): (2, [(x, 2), (x + 5, 2), (x + 1, 3)]) """ + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly factor_squarefree: modulus {self.ctx.mod.n} is not prime") return self._factor('squarefree') def _factor(self, factor_type): cdef nmod_poly_factor_t fac + cdef nmod_poly u + cdef nmod c cdef mp_limb_t lead cdef int i @@ -674,16 +797,13 @@ cdef class nmod_poly(flint_poly): res = [None] * fac.num for 0 <= i < fac.num: - u = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv((u).val, - (self).val.mod.n, (self).val.mod.ninv) - nmod_poly_set((u).val, &fac.p[i]) + u = self.ctx.new_nmod_poly() + nmod_poly_set(u.val, &fac.p[i]) exp = fac.exp[i] res[i] = (u, exp) - c = nmod.__new__(nmod) - (c).mod = self.val.mod - (c).val = lead + c = self.ctx.new_nmod() + c.val = lead nmod_poly_factor_clear(fac) @@ -692,13 +812,17 @@ cdef class nmod_poly(flint_poly): def sqrt(nmod_poly self): """Return exact square root or ``None``. """ cdef nmod_poly res - res = nmod_poly.__new__(nmod_poly) - nmod_poly_init_preinv(res.val, self.val.mod.n, self.val.mod.ninv) - if nmod_poly_sqrt(res.val, self.val): - return res - else: + + if not self.ctx._is_prime: + raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime") + + res = self.ctx.new_nmod_poly() + + if not nmod_poly_sqrt(res.val, self.val): raise DomainError(f"Cannot compute square root of {self}") + return res + def deflation(self): cdef nmod_poly v cdef ulong n @@ -708,8 +832,7 @@ cdef class nmod_poly(flint_poly): if n == 1: return self, int(n) else: - v = nmod_poly.__new__(nmod_poly) - nmod_poly_init(v.val, self.val.mod.n) + v = self.ctx.new_nmod_poly() nmod_poly_deflate(v.val, self.val, n) return v, int(n)