diff --git a/torchsig/datasets/synthetic.py b/torchsig/datasets/synthetic.py index a98b8bc..dd38bb5 100644 --- a/torchsig/datasets/synthetic.py +++ b/torchsig/datasets/synthetic.py @@ -280,7 +280,12 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: symbols = const[symbol_nums] zero_padded = np.zeros((self.iq_samples_per_symbol * len(symbols),), dtype=np.complex64) zero_padded[::self.iq_samples_per_symbol] = symbols - self.pulse_shape_filter = self._rrc_taps(11, signal_description.excess_bandwidth) + + # estimate total filter length for pulse shape + AdB = 72 # sidelobe attenuation level, 72 dB -> 12 bit dynamic range + pulse_shape_filter_length = estimate_filter_length(AdB,1,signal_description.excess_bandwidth) + pulse_shape_filter_span = int((pulse_shape_filter_length-1)/2) # convert filter length into the span + self.pulse_shape_filter = self._rrc_taps(pulse_shape_filter_span, signal_description.excess_bandwidth) xp = cp if self.use_gpu else np filtered = xp.convolve(xp.array(zero_padded), xp.array(self.pulse_shape_filter), "same") @@ -296,7 +301,8 @@ def _rrc_taps(self, size_in_symbols: int, alpha: float = .35) -> np.ndarray: n = np.arange(-M * Ns, M * Ns + 1) taps = np.zeros(int(2 * M * Ns + 1)) for i in range(int(2 * M * Ns + 1)): - if abs(1 - 16 * alpha ** 2 * (n[i] / Ns) ** 2) <= np.finfo(np.float64).eps / 2: + # handle the discontinuity at t=+-Ns/(4*alpha) + if (n[i]*4*alpha == Ns or n[i]*4*alpha == -Ns): taps[i] = 1 / 2. * ((1 + alpha) * np.sin((1 + alpha) * np.pi / (4. * alpha)) - (1 - alpha) * np.cos( (1 - alpha) * np.pi / (4. * alpha)) + (4 * alpha) / np.pi * np.sin( (1 - alpha) * np.pi / (4. * alpha))) @@ -925,3 +931,23 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: np.random.set_state(orig_state) # return numpy back to its previous state return modulated[-self.num_iq_samples:] + + +def estimate_filter_length ( AdB, fs, transitionBandwidth ): + # estimate the length of an FIR filter using harris' approximaion, + # N ~= (sampling rate/transition bandwidth)*(sidelobe attenuation in dB / 22) + # fred harris, Multirate Signal Processing for Communication Systems, + # Second Edition, p.59 + filter_length = int(np.round((fs/transitionBandwidth)*(AdB/22))) + + # odd-length filters are desirable because they do not introduce a half-sample delay + if (np.mod(filter_length,2) == 0): + filter_length += 1 + + return filter_length + + + + + +