Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
Merge branch 'u/chapoton/33324' in 9.6.b1
Browse files Browse the repository at this point in the history
  • Loading branch information
fchapoton committed Feb 13, 2022
2 parents 826061a + 1603529 commit 93cc7c0
Showing 1 changed file with 95 additions and 92 deletions.
187 changes: 95 additions & 92 deletions src/sage/calculus/transforms/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,19 @@
- William Stein (2006-11) -- fix many bugs
"""

##########################################################################
# Copyright (C) 2006 David Joyner <wdjoyner@gmail.com>
#
# Distributed under the terms of the GNU General Public License (GPL):
#
# http://www.gnu.org/licenses/
# https://www.gnu.org/licenses/
##########################################################################

from sage.rings.number_field.number_field import CyclotomicField
from sage.plot.all import polygon, line, text
from sage.groups.abelian_gps.abelian_group import AbelianGroup
from sage.groups.perm_gps.permgroup_element import is_PermutationGroupElement
from sage.rings.integer_ring import ZZ
from sage.rings.integer import Integer
from sage.arith.all import factor
from sage.rings.rational_field import QQ
from sage.rings.real_mpfr import RR
from sage.functions.all import sin, cos
Expand All @@ -89,6 +86,7 @@
from sage.structure.sage_object import SageObject
from sage.structure.sequence import Sequence


class IndexedSequence(SageObject):
"""
An indexed sequence.
Expand Down Expand Up @@ -225,9 +223,9 @@ def _repr_(self):
Indexed sequence: [0, 1, 1]
indexed by Finite Field of size 3
"""
return "Indexed sequence: "+str(self.list())+"\n indexed by "+str(self.index_object())
return "Indexed sequence: " + str(self.list()) + "\n indexed by " + str(self.index_object())

def plot_histogram(self, clr=(0,0,1), eps = 0.4):
def plot_histogram(self, clr=(0, 0, 1), eps=0.4):
r"""
Plot the histogram plot of the sequence.
Expand All @@ -249,8 +247,13 @@ def plot_histogram(self, clr=(0,0,1), eps = 0.4):
I = self.index_object()
N = len(I)
S = self.list()
P = [polygon([[RR(I[i])-eps,0],[RR(I[i])-eps,RR(S[i])],[RR(I[i])+eps,RR(S[i])],[RR(I[i])+eps,0],[RR(I[i]),0]], rgbcolor=clr) for i in range(N)]
T = [text(str(I[i]),(RR(I[i]),-0.8),fontsize=15,rgbcolor=(1,0,0)) for i in range(N)]
P = [polygon([[RR(I[i]) - eps, 0],
[RR(I[i]) - eps, RR(S[i])],
[RR(I[i]) + eps, RR(S[i])],
[RR(I[i]) + eps, 0],
[RR(I[i]), 0]], rgbcolor=clr) for i in range(N)]
T = [text(str(I[i]), (RR(I[i]), -0.8), fontsize=15, rgbcolor=(1, 0, 0))
for i in range(N)]
return sum(P) + sum(T)

def plot(self):
Expand All @@ -271,9 +274,9 @@ def plot(self):
# elements must be coercible into RR
I = self.index_object()
S = self.list()
return line([[RR(I[i]),RR(S[i])] for i in range(len(I)-1)])
return line([[RR(I[i]), RR(S[i])] for i in range(len(I) - 1)])

def dft(self, chi = lambda x: x):
def dft(self, chi=lambda x: x):
r"""
A discrete Fourier transform "over `\QQ`" using exact
`N`-th roots of unity.
Expand Down Expand Up @@ -322,34 +325,34 @@ def dft(self, chi = lambda x: x):
implemented Group (permutation, matrix), call .characters()
and test if the index list is the set of conjugacy classes.
"""
J = self.index_object() ## index set of length N
J = self.index_object() # index set of length N
N = len(J)
S = self.list()
F = self.base_ring() ## elements must be coercible into QQ(zeta_N)
F = self.base_ring() # elements must be coercible into QQ(zeta_N)
if not(J[0] in ZZ):
G = J[0].parent() ## if J is not a range it is a group G
if J[0] in ZZ and F.base_ring().fraction_field()==QQ:
## assumes J is range(N)
G = J[0].parent() # if J is not a range it is a group G
if J[0] in ZZ and F.base_ring().fraction_field() == QQ:
# assumes J is range(N)
zeta = CyclotomicField(N).gen()
FT = [sum([S[i]*chi(zeta**(i*j)) for i in J]) for j in J]
elif not(J[0] in ZZ) and G.is_abelian() and F == ZZ or (F.is_field() and F.base_ring()==QQ):
FT = [sum([S[i] * chi(zeta**(i * j)) for i in J]) for j in J]
elif (J[0] not in ZZ) and G.is_abelian() and F == ZZ or (F.is_field() and F.base_ring() == QQ):
if is_PermutationGroupElement(J[0]):
## J is a CyclicPermGp
# J is a CyclicPermGp
n = G.order()
a = list(factor(n))
a = list(n.factor())
invs = [x[0]**x[1] for x in a]
G = AbelianGroup(len(a),invs)
## assumes J is AbelianGroup(...)
G = AbelianGroup(len(a), invs)
# assumes J is AbelianGroup(...)
Gd = G.dual_group()
FT = [sum([S[i]*chid(G.list()[i]) for i in range(N)])
FT = [sum([S[i] * chid(G.list()[i]) for i in range(N)])
for chid in Gd]
elif not(J[0] in ZZ) and G.is_finite() and F == ZZ or (F.is_field() and F.base_ring()==QQ):
## assumes J is the list of conj class representatives of a
## PermutationGroup(...) or Matrixgroup(...)
elif (J[0] not in ZZ) and G.is_finite() and F == ZZ or (F.is_field() and F.base_ring() == QQ):
# assumes J is the list of conj class representatives of a
# PermutationGroup(...) or Matrixgroup(...)
chi = G.character_table()
FT = [sum([S[i]*chi[i,j] for i in range(N)]) for j in range(N)]
FT = [sum([S[i] * chi[i, j] for i in range(N)]) for j in range(N)]
else:
raise ValueError("list elements must be in QQ(zeta_"+str(N)+")")
raise ValueError(f"list elements must be in QQ(zeta_{N})")
return IndexedSequence(FT, J)

def idft(self):
Expand All @@ -370,15 +373,15 @@ def idft(self):
sage: it == s
True
"""
F = self.base_ring() ## elements must be coercible into QQ(zeta_N)
J = self.index_object() ## must be = range(N)
F = self.base_ring() # elements must be coercible into QQ(zeta_N)
J = self.index_object() # must be = range(N)
N = len(J)
S = self.list()
zeta = CyclotomicField(N).gen()
iFT = [sum([S[i]*zeta**(-i*j) for i in J]) for j in J]
if not(J[0] in ZZ) or F.base_ring().fraction_field() != QQ:
iFT = [sum([S[i] * zeta**(-i * j) for i in J]) for j in J]
if (J[0] not in ZZ) or F.base_ring().fraction_field() != QQ:
raise NotImplementedError("Sorry this type of idft is not implemented yet.")
return IndexedSequence(iFT,J)*(Integer(1)/N)
return IndexedSequence(iFT, J) * (Integer(1) / N)

def dct(self):
"""
Expand All @@ -390,17 +393,17 @@ def dct(self):
sage: A = [exp(-2*pi*i*I/5) for i in J]
sage: s = IndexedSequence(A,J)
sage: s.dct()
Indexed sequence: [1/16*(sqrt(5) + I*sqrt(-2*sqrt(5) + 10) + ...
Indexed sequence: [0, 1/16*(sqrt(5) + I*sqrt(-2*sqrt(5) + 10) + ...
indexed by [0, 1, 2, 3, 4]
"""
from sage.symbolic.constants import pi
F = self.base_ring() ## elements must be coercible into RR
J = self.index_object() ## must be = range(N)
F = self.base_ring() # elements must be coercible into RR
J = self.index_object() # must be = range(N)
N = len(J)
S = self.list()
PI = F(pi)
FT = [sum([S[i]*cos(2*PI*i/N) for i in J]) for j in J]
return IndexedSequence(FT,J)
PI = 2 * F(pi) / N
FT = [sum([S[i] * cos(PI * i * j) for i in J]) for j in J]
return IndexedSequence(FT, J)

def dst(self):
"""
Expand All @@ -414,17 +417,17 @@ def dst(self):
sage: s = IndexedSequence(A,J)
sage: s.dst() # discrete sine
Indexed sequence: [1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I]
indexed by [0, 1, 2, 3, 4]
Indexed sequence: [0.000000000000000, 1.11022302462516e-16 - 2.50000000000000*I, ...]
indexed by [0, 1, 2, 3, 4]
"""
from sage.symbolic.constants import pi
F = self.base_ring() ## elements must be coercible into RR
J = self.index_object() ## must be = range(N)
F = self.base_ring() # elements must be coercible into RR
J = self.index_object() # must be = range(N)
N = len(J)
S = self.list()
PI = F(pi)
FT = [sum([S[i]*sin(2*PI*i/N) for i in J]) for j in J]
return IndexedSequence(FT,J)
PI = 2 * F(pi) / N
FT = [sum([S[i] * sin(PI * i * j) for i in J]) for j in J]
return IndexedSequence(FT, J)

def convolution(self, other):
r"""
Expand Down Expand Up @@ -471,19 +474,18 @@ def convolution(self, other):
raise TypeError("IndexedSequences must have same index set")
M = len(S)
N = len(T)
if M < N: ## first, extend by 0 if necessary
a = [S[i] for i in range(M)]+[F(0) for i in range(2*N)]
b = T+[E(0) for i in range(2*M)]
if M > N: ## python trick - a[-j] is really j from the *right*
b = [T[i] for i in range(N)]+[E(0) for i in range(2*M)]
a = S+[F(0) for i in range(2*M)]
if M==N: ## so need only extend by 0 to the *right*
a = S+[F(0) for i in range(2*M)]
b = T+[E(0) for i in range(2*M)]
N = max(M,N)
c = [sum([a[i]*b[j-i] for i in range(N)]) for j in range(2*N-1)]
#print([[b[j-i] for i in range(N)] for j in range(N)])
return c
if M < N: # first, extend by 0 if necessary
a = [S[i] for i in range(M)] + [F(0) for i in range(2 * N)]
b = T + [E(0) for i in range(2 * M)]
if M > N: # python trick - a[-j] is really j from the *right*
b = [T[i] for i in range(N)] + [E(0) for i in range(2 * M)]
a = S + [F(0) for i in range(2 * M)]
if M == N: # so need only extend by 0 to the *right*
a = S + [F(0) for i in range(2 * M)]
b = T + [E(0) for i in range(2 * M)]
N = max(M, N)
return [sum([a[i] * b[j - i] for i in range(N)])
for j in range(2 * N - 1)]

def convolution_periodic(self, other):
r"""
Expand Down Expand Up @@ -531,17 +533,17 @@ def convolution_periodic(self, other):
M = len(S)
N = len(T)
if M < N: # first, extend by 0 if necessary
a = [S[i] for i in range(M)]+[F(0) for i in range(N-M)]
a = [S[i] for i in range(M)] + [F(0) for i in range(N - M)]
b = other
if M > N:
b = [T[i] for i in range(N)]+[E(0) for i in range(M-N)]
b = [T[i] for i in range(N)] + [E(0) for i in range(M - N)]
a = self
if M == N:
a = S
b = T
N = max(M, N)
c = [sum([a[i]*b[(j-i)%N] for i in range(N)]) for j in range(2*N-1)]
return c
return [sum([a[i] * b[(j - i) % N] for i in range(N)])
for j in range(2 * N - 1)]

def __mul__(self, other):
"""
Expand All @@ -563,7 +565,7 @@ def __mul__(self, other):
S1 = [S[i] * other for i in range(len(self.index_object()))]
return IndexedSequence(S1, self.index_object())

def __eq__(self,other):
def __eq__(self, other):
"""
Implements boolean equals.
Expand All @@ -587,16 +589,17 @@ def __eq__(self,other):
T = other.list()
I = self.index_object()
J = other.index_object()
if I!=J:
if I != J:
return False
for i in I:
try:
if abs(S[i]-T[i]) > 10**(-8): ## tests if they differ as reals -- WHY 10^(-8)???
if abs(S[i] - T[i]) > 10**(-8):
# tests if they differ as reals -- WHY 10^(-8)???
return False
except TypeError:
pass
#if F!=E: ## omitted this test since it
# return 0 ## doesn't take into account coercions -- WHY???
# if F != E: # omitted this test since it
# return 0 # doesn't take into account coercions -- WHY???
return True

def fft(self):
Expand All @@ -623,14 +626,14 @@ def fft(self):
I = CC.gen()

# elements must be coercible into RR
J = self.index_object() ## must be = range(N)
J = self.index_object() # must be = range(N)
N = len(J)
S = self.list()
a = FastFourierTransform(N)
for i in range(N):
a[i] = S[i]
a.forward_transform()
return IndexedSequence([a[j][0]+I*a[j][1] for j in J],J)
return IndexedSequence([a[j][0] + I * a[j][1] for j in J], J)

def ifft(self):
"""
Expand Down Expand Up @@ -660,16 +663,16 @@ def ifft(self):
I = CC.gen()

# elements must be coercible into RR
J = self.index_object() ## must be = range(N)
J = self.index_object() # must be = range(N)
N = len(J)
S = self.list()
a = FastFourierTransform(N)
for i in range(N):
a[i] = S[i]
a.inverse_transform()
return IndexedSequence([a[j][0]+I*a[j][1] for j in J],J)
return IndexedSequence([a[j][0] + I * a[j][1] for j in J], J)

def dwt(self,other="haar",wavelet_k=2):
def dwt(self, other="haar", wavelet_k=2):
r"""
Wraps the gsl ``WaveletTransform.forward`` in :mod:`~sage.calculus.transforms.dwt`
(written by Joshua Kantor). Assumes the length of the sample is a
Expand Down Expand Up @@ -709,28 +712,28 @@ def dwt(self,other="haar",wavelet_k=2):
indexed by [0, 1, 2, 3, 4, 5, 6, 7]
"""
# elements must be coercible into RR
J = self.index_object() ## must be = range(N)
N = len(J) ## must be 1 minus a power of 2
J = self.index_object() # must be = range(N)
N = len(J) # must be 1 minus a power of 2
S = self.list()
if other == "haar" or other == "haar_centered":
if wavelet_k in [2]:
a = WaveletTransform(N,other,wavelet_k)
a = WaveletTransform(N, other, wavelet_k)
else:
raise ValueError("wavelet_k must be = 2")
if other == "debauchies" or other == "debauchies_centered":
if wavelet_k in [4,6,8,10,12,14,16,18,20]:
a = WaveletTransform(N,other,wavelet_k)
if other == "daubechies" or other == "daubechies_centered":
if wavelet_k in [4, 6, 8, 10, 12, 14, 16, 18, 20]:
a = WaveletTransform(N, other, wavelet_k)
else:
raise ValueError("wavelet_k must be in {4,6,8,10,12,14,16,18,20}")
if other == "bspline" or other == "bspline_centered":
if wavelet_k in [103,105,202,204,206,208,301,305,307,309]:
a = WaveletTransform(N,other,103)
if wavelet_k in [103, 105, 202, 204, 206, 208, 301, 305, 307, 309]:
a = WaveletTransform(N, other, 103)
else:
raise ValueError("wavelet_k must be in {103,105,202,204,206,208,301,305,307,309}")
for i in range(N):
a[i] = S[i]
a.forward_transform()
return IndexedSequence([RR(a[j]) for j in J],J)
return IndexedSequence([RR(a[j]) for j in J], J)

def idwt(self, other="haar", wavelet_k=2):
r"""
Expand Down Expand Up @@ -786,26 +789,26 @@ def idwt(self, other="haar", wavelet_k=2):
True
"""
# elements must be coercible into RR
J = self.index_object() ## must be = range(N)
N = len(J) ## must be 1 minus a power of 2
J = self.index_object() # must be = range(N)
N = len(J) # must be 1 minus a power of 2
S = self.list()
k = wavelet_k
if other=="haar" or other=="haar_centered":
if other == "haar" or other == "haar_centered":
if k in [2]:
a = WaveletTransform(N,other,wavelet_k)
a = WaveletTransform(N, other, wavelet_k)
else:
raise ValueError("wavelet_k must be = 2")
if other=="debauchies" or other=="debauchies_centered":
if k in [4,6,8,10,12,14,16,18,20]:
a = WaveletTransform(N,other,wavelet_k)
if other == "daubechies" or other == "daubechies_centered":
if k in [4, 6, 8, 10, 12, 14, 16, 18, 20]:
a = WaveletTransform(N, other, wavelet_k)
else:
raise ValueError("wavelet_k must be in {4,6,8,10,12,14,16,18,20}")
if other=="bspline" or other=="bspline_centered":
if k in [103,105,202,204,206,208,301,305,307,309]:
a = WaveletTransform(N,other,103)
if other == "bspline" or other == "bspline_centered":
if k in [103, 105, 202, 204, 206, 208, 301, 305, 307, 309]:
a = WaveletTransform(N, other, 103)
else:
raise ValueError("wavelet_k must be in {103,105,202,204,206,208,301,305,307,309}")
for i in range(N):
a[i] = S[i]
a.backward_transform()
return IndexedSequence([RR(a[j]) for j in J],J)
return IndexedSequence([RR(a[j]) for j in J], J)

0 comments on commit 93cc7c0

Please sign in to comment.