Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
#10480: fast PowerSeries_poly multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
lftabera authored and Frédéric Chapoton committed Nov 2, 2014
1 parent 8b95db3 commit af07389
Show file tree
Hide file tree
Showing 2 changed files with 345 additions and 2 deletions.
301 changes: 301 additions & 0 deletions src/sage/rings/polynomial/polynomial_element.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,86 @@ cdef class Polynomial(CommutativeAlgebraElement):
return self._parent(do_karatsuba(f,g, K_threshold, 0, 0, n))
return self._parent(do_karatsuba_different_size(f,g, K_threshold))

def _mul_trunc(self, right, prec, K_threshold = None):
"""
Returns the product of two polynomials truncating orders of
the polynomial variable greater or equal to h.
It uses a Karatsuba based truncated multiplication with a threshold to
fall back to a quadratic truncation multiplication algorithm.
INPUT:
- ``self``: polynomial
- ``right``: polynomial
- ``prec``: desired precission, all terms of ``self*right`` of degree >=
prec are discarded
- ``K_threshold``: threshold for Karatsuba multiplication. If one of the
degrees of the polynomials is <= K_threshold, an truncated algorithm
based on classical multiplication is used.
REFERENCE:
- Hanrot, G.; Zimmermann, P. A long note on Mulder's short product. J.
Symbolic Comput. 37 (2004), no. 3, 391-401.
EXAMPLES::
sage: R.<a,b> = QQ[]
sage: K.<t> = R[]
sage: p = 1 + a*t + b*t^2
sage: p._mul_trunc(p,3)
(a^2 + 2*b)*t^2 + 2*a*t + 1
sage: p1 = K.random_element(ZZ.random_element(20,60))
sage: p2 = K.random_element(ZZ.random_element(20,60))
sage: h = p1*p2
sage: p1._mul_trunc(p2,10) == h
False
sage: p1._mul_trunc(p2,10) == h[:10]
True
sage: p1._mul_trunc(p2,50) == h[:50]
True
sage: p1._mul_trunc(p2,150) == h
True
This also works for noncommutative rings::
sage: A.<i,j,k> = QuaternionAlgebra(QQ, -1,-1)
sage: R.<w> = PolynomialRing(A)
sage: f = R.random_element(randint(20,50))
sage: g = R.random_element(randint(20,50))
sage: h = f*g
sage: f._mul_trunc(g,10) == h
False
sage: f._mul_trunc(g,10) == h[:10]
True
sage: f._mul_trunc(g,50) == h[:50]
True
sage: f._mul_trunc(g,110) == h
True
"""
if self.is_zero():
return self
elif right.is_zero():
return right
f = self.list()
g = right.list()
n = len(f)
m = len(g)
if n == 1:
c = f[0]
return self._parent([c*a for a in g[:prec]])
if m == 1:
c = g[0]
return self._parent([a*c for a in f[:prec]])
if K_threshold is None:
K_threshold = self._parent._Karatsuba_threshold
if n <= K_threshold or m <= K_threshold:
return self._parent(do_trunc_classical(f, g, prec))
if n == m:
return self._parent(do_trunc_karatsuba(f, g, prec, K_threshold))
return self._parent(do_trunc_karatsuba_different_size(f, g, prec, K_threshold))

def base_ring(self):
"""
Return the base ring of the parent of self.
Expand Down Expand Up @@ -7011,6 +7091,227 @@ cdef do_karatsuba(left, right, Py_ssize_t K_threshold,Py_ssize_t start_l, Py_ssi
ac[i] = ac[i] + tt1[e+i]
return bd + ac

cdef do_trunc_classical(x, y, Py_ssize_t prec):
"""
Method to compute the truncated multiplication of two polynomials
represented by lists. This is the the code that is used by _mul_trunc bellow
a threshold.
INPUT:
- ``x``: a list representing a polynomial.
- ``y``: a list representing a polynomial.
- ``prec``: desired precission, all terms of ``x*y`` of degree >=
``prec`` are discarded
Doctested indirectly in _mul_trunc.
TESTS::
sage: K = ZZ[x]
sage: f = K.random_element(8)
sage: g = K.random_element(8)
sage: (f*g)[:8] - f._mul_trunc(g, 8, 20)
0
"""
cdef Py_ssize_t i, k, start, end
cdef Py_ssize_t d1 = len(x)-1, d2 = len(y)-1
if d1 == -1:
return x
elif d2 == -1:
return y
elif d1 == 0:
c = x[0]
return [c*a for a in y[:prec]] #beware of noncommutative rings
elif d2 == 0:
c = y[0]
return [a*c for a in x[:prec]] #beware of noncommutative rings
coeffs = []
for k from 0 <= k <= min(d1+d2,prec-1):
start = 0 if k <= d2 else k-d2 # max(0, k-d2)
end = k if k <= d1 else d1 # min(k, d1)
sum = x[start] * y[k-start]
for i from start < i <= end:
sum += x[i] * y[k-i]
coeffs.append(sum)
return coeffs

cdef do_trunc_karatsuba(left, right, Py_ssize_t prec, Py_ssize_t K_threshold):
"""
Core routine for truncated karatsuba multiplicacion. This function works for
two polynomials of the same degree represented by lists.
INPUT:
- left: a list representing a polynomial.
- right: a list representing a polynomial with the same length of left.
- prec: precision. All terms of the multiplication of degree greater or
equal than prec are discarded.
- K_threshold: an integer. For lists of length <= K_threshold,
do_trunc_classical is used
OUTPUT:
- a list representing the slicing (left*right)[:prec]
Doctested indirectly in _mul_trunc.
TESTS::
sage: K.<x> = ZZ[]
sage: f = K.random_element(8) + x^9
sage: g = K.random_element(8) + x^9
sage: (f*g)[:8] - f._mul_trunc(g, 8, 2)
0
"""
cdef Py_ssize_t n, n0, n1, len_l, len_h, len_m, i
n = len(left)
if n == 0:
return left
if prec == 0:
return []
if n == 1 or prec == 1:
return [left[0]*right[0]]
if n <= K_threshold:
return do_trunc_classical(left, right, prec)
if 2*n-1 <= prec:
return do_karatsuba(left, right, K_threshold, 0, 0, n)
if prec == 2:
return [left[0]*right[0], left[0]*right[1]+left[1]*right[0]]
#prec >= 3
if n == 2:
b = left[0]
a = left[1]
d = right[0]
c = right[1]
ac = a*c
bd = b*d
return [bd,(a+b)*(c+d)-ac-bd,ac]
n0 = prec // 2
n1 = (prec+1) // 2
left_even = left[::2]
left_odd = left[1::2]
right_even = right[::2]
right_odd = right[1::2]
# l has degree bigger than h
l = do_trunc_karatsuba(left_even, right_even, n1, K_threshold)
h = do_trunc_karatsuba(left_odd, right_odd, n0, K_threshold)
left_add = list(left_even)
right_add = list(right_even)
for i from 0 <= i < len(left_odd):
left_add[i] += left_odd[i]
right_add[i] += right_odd[i]
m = do_trunc_karatsuba(left_add, right_add, n0, K_threshold)
len_l = len(l) # It can be much less than n1
len_h = len(h)
len_m = len(m)
l_h = list(l)
for i in range(len_h):
l_h[i] += h[i]
for i in range(len_m):
m[i] -= l_h[i]
m.extend([-f for f in l_h[len(m):]])
coeffs = []
len_m = len(m)
for i from 0 <= i < len_m:
coeffs.append(l[i])
coeffs.append(m[i])
for i from 0 <= i < len(coeffs)//2-1:
coeffs[2*i+2] += h[i]
return coeffs[:prec]

cdef do_trunc_karatsuba_different_size(left, right, Py_ssize_t prec, Py_ssize_t K_threshold):
"""
This algorithm is to deal with truncated karatsuba multiplication of two polynomials
of different degree.
INPUT:
- `left`: a list representing a polynomial
- `right`: a list representing a polynomial
- prec: precision. All terms of the multiplication of degree greater or
equal than prec are discarded
- `K_threshold`: an Integer, a threshold to pass to the classical
quadratic algorithm. During Karatsuba recursion, if one of the lists
has length <= K_threshold the classical product is used instead.
If `left` is a list representing a polynomial `f` of degree n and right is a list
representing a list of degree m with n < m, then we interpret `right` as
..math::
g0 + g1x^n +g2x^{2n} + ... + gqn^{nq}
where `gi` are polynomials of degree n-1, `gq`of degree <= n-1.
Then compute each product `fgi` with karatsuba multiplication and from then
reconstruct `fg`
This method is indirectly doctested in _mul_karatsuba_
TESTS::
sage: K.<x> = ZZ[]
sage: f = K.random_element(27) + x^28
sage: g = K.random_element(33) + x^34
sage: (f*g)[:20] - f._mul_trunc(g, 20, 2)
0
"""
cdef Py_ssize_t n, m, q, r, mi
n = min(len(left), prec)
m = min(len(right), prec)
if n == 0 or m == 0:
return []
if n == 1:
c = left[0]
return [c*a for a in right[:prec]]
if m == 1:
c = right[0]
return [a*c for a in left[:prec]] #beware of noncommutative rings
if n <= K_threshold or m <= K_threshold:
return do_trunc_classical(left,right, prec)
if n+m-1 <= prec:
return do_karatsuba_different_size(left,right, K_threshold)
if n == m:
return do_trunc_karatsuba(left[:n] ,right[:n], prec, K_threshold)
if n > m:
#left is the bigger list
#n is the bigger number
q = n // m
r = n % m
output = do_karatsuba(left, right, K_threshold, 0, 0, m)
for i from 1 <= i < q:
mi = m*i
carry = do_karatsuba(left, right, K_threshold, mi, 0, m)
for j in range(m-1):
output[mi+j] += carry[j]
output.extend(carry[m-1:])
if r:
mi = m*q
carry = do_trunc_karatsuba_different_size(left[mi:], right, prec-mi, K_threshold)
for j from 0 <= j < min(len(carry),m-1):
output[mi+j] += carry[j]
output.extend(carry[m-1:])
return output[:prec]
else:
# n < m, I need to repeat the code due to the case
# of noncommutative rings.
q = m // n
r = m % n
output = do_karatsuba(left, right, K_threshold, 0, 0, n)
for i from 1 <= i < q:
mi = n*i
carry = do_karatsuba(left, right, K_threshold, 0, mi, n)
for j in range(n-1):
output[mi+j] += carry[j]
output.extend(carry[n-1:])
if r:
mi = n*q
carry = do_trunc_karatsuba_different_size(left, right[mi:], prec-mi, K_threshold)
for j from 0 <= j < min(len(carry),n-1):
output[mi+j] += carry[j]
output.extend(carry[n-1:])
return output[:prec]

cpdef Polynomial_generic_dense _new_constant_dense_poly(list coeffs, Parent P, sample):
cdef Polynomial_generic_dense f = <Polynomial_generic_dense>PY_NEW_SAME_TYPE(sample)
Expand Down
46 changes: 44 additions & 2 deletions src/sage/rings/power_series_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import arith
from sage.libs.all import PariError
from power_series_ring_element import is_PowerSeries
import rational_field
from polynomial.polynomial_element import Polynomial_generic_dense

cdef class PowerSeries_poly(PowerSeries):

Expand Down Expand Up @@ -602,17 +603,58 @@ cdef class PowerSeries_poly(PowerSeries):
"""
Return the product of two power series.
If the underlying polynomial ring is Polynomial_generic_dense use a
truncated multiplication algorithm.
EXAMPLES::
sage: k.<w> = ZZ[[]]
sage: (1+17*w+15*w^3+O(w^5))*(19*w^10+O(w^12))
19*w^10 + 323*w^11 + O(w^12)
sage: O(w)*O(w)
O(w^2)
sage: O(w^2)*(w^3+O(w^5))
O(w^5)
sage: O(w) * (1+w)
O(w^1)
sage: L.<t> = QQ[I][[]]
sage: p1 = L.random_element(50)
sage: p2 = L.random_element(50)
sage: p1*p2 == p1.polynomial()*p2.polynomial()+O(t^50)
True
TESTS::
sage: K.<w> = Qp(3)[[]]
sage: f = 3^2*w + O(w^3)
sage: g = O(3)*w
sage: fg = f*g
sage: fg
0
sage: fg == 0
False
sage: fg.polynomial()
(O(3^3))*w^2
"""
prec = self._mul_prec(right_r)
if prec == 0:
return self._parent(self._parent.one_element(), prec=0)

# Avoid the case prec=+Infinity in the Polynomial_generic_dense case
if prec == infinity or not IS_INSTANCE(self.__f, Polynomial_generic_dense):
return PowerSeries_poly(self._parent,
self.__f * (<PowerSeries_poly>right_r).__f,
prec = prec,
check = True)
# check, since truncation may be needed

# If the underlying polynomial ring is
# Polynomial_generic_dense use a truncated multiplication.
return PowerSeries_poly(self._parent,
self.__f * (<PowerSeries_poly>right_r).__f,
self.__f._mul_trunc((<PowerSeries_poly>right_r).__f, prec),
prec = prec,
check = True) # check, since truncation may be needed
check = True)

cpdef RingElement _imul_(self, RingElement right_r):
"""
Expand Down

0 comments on commit af07389

Please sign in to comment.