integrating Matt's changes into code base
	modified:   torchsig/datasets/synthetic.py
	modified:   torchsig/datasets/wideband.py
	modified:   torchsig/models/model_utils/layer_tools.py
	modified:   torchsig/models/model_utils/model_utils_1d/conversions_to_1d.py
	modified:   torchsig/transforms/transforms.py
	modified:   torchsig/utils/dsp.py
	modified:   torchsig/utils/visualize.py
	modified:   torchsig/utils/writer.py
pvallance committed Jul 9, 2024
1 parent 0ba79f8 commit d8f32f9
Showing 8 changed files with 268 additions and 158 deletions.
199 changes: 176 additions & 23 deletions torchsig/datasets/synthetic.py

Large diffs are not rendered by default.

181 changes: 63 additions & 118 deletions torchsig/datasets/wideband.py
@@ -217,30 +217,17 @@ def __init__(
self.meta["upper_freq"] = self.meta["center_freq"] + self.meta["bandwidth"] / 2

def generate_iq(self):
# Read mod_index to determine which synthetic dataset to read from
self.meta["class_name"] = self.classes()
self.meta["class_index"] = self.class_list.index(self.meta["class_name"])

# estimate the approximate samples per symbol (used as an oversampling estimate for ConstellationDataset and FSK/MSK modulations)
approx_samp_per_sym = int(np.ceil(self.meta["bandwidth"] ** -1)) if self.meta["bandwidth"] < 1.0 else int(np.ceil(self.meta["bandwidth"]))
approx_bandwidth = approx_samp_per_sym**-1 if self.meta["bandwidth"] < 1.0 else int(np.ceil(self.meta["bandwidth"]))

# Determine the new rate of the requested signal in order to compute how many samples to request
if "ofdm" in self.meta["class_name"]:
occupied_bandwidth = 0.5
elif "g" in self.meta["class_name"]:
if "m" in self.meta["class_name"]:
occupied_bandwidth = approx_bandwidth * (1 - 0.5 + self.meta["excess_bandwidth"])
else:
occupied_bandwidth = approx_bandwidth * (1 + 0.25 + self.meta["excess_bandwidth"])
elif "fsk" in self.meta["class_name"]:
occupied_bandwidth = approx_bandwidth * (1 + 1)
elif "msk" in self.meta["class_name"]:
occupied_bandwidth = approx_bandwidth
else:
occupied_bandwidth = approx_bandwidth * (1 + self.meta["excess_bandwidth"])
# self.meta["num_samples"] is the total number of IQ samples used in the snapshot. the duration (self.meta["duration"]) represents
# what proportion of the snapshot that the modulated waveform will occupy. the duration is on a range of [0,1]. for example, a
# duration of 0.75 means that the modulated waveform will occupy 75% of the total length of the snapshot.
self.meta["duration"] = self.meta["stop"] - self.meta["start"]
new_rate = occupied_bandwidth / self.meta["bandwidth"]
# TODO: why the 1.1
num_iq_samples = int(np.ceil(self.meta["num_samples"] * self.meta["duration"] / new_rate * 1.1) )
# calculate how many IQ samples are needed from the modulator based on the duration
num_iq_samples = int(np.ceil(self.meta["num_samples"] * self.meta["duration"]) )

# Create modulated burst
if "ofdm" in self.meta["class_name"]:
@@ -262,77 +249,41 @@ def generate_iq(self):
sidelobe_suppression_methods=sidelobe_suppression_methods,
dc_subcarrier=("on", "off"),
time_varying_realism=("on", "off"),
center_freq=self.meta["center_freq"],
bandwidth=self.meta["bandwidth"]
)
elif "g" in self.meta["class_name"]:
elif "fsk" in self.meta["class_name"] or "msk" in self.meta["class_name"]: # FSK, GFSK, MSK, GMSK
modulated_burst = FSKDataset(
modulations=[self.meta["class_name"]],
num_iq_samples=num_iq_samples,
num_samples_per_class=1,
iq_samples_per_symbol=approx_samp_per_sym,
random_data=True,
random_pulse_shaping=True,
center_freq=self.meta["center_freq"],
bandwidth=self.meta["bandwidth"]
)
elif "fsk" in self.meta["class_name"] or "msk" in self.meta["class_name"]:
modulated_burst = FSKDataset(
modulations=[self.meta["class_name"]],
num_iq_samples=num_iq_samples,
num_samples_per_class=1,
iq_samples_per_symbol=approx_samp_per_sym,
random_data=True,
random_pulse_shaping=False,
)
else:
else: # QAM/PSK and related
modulated_burst = ConstellationDataset(
constellations=[self.meta["class_name"]],
num_iq_samples=num_iq_samples,
num_samples_per_class=1,
iq_samples_per_symbol=approx_samp_per_sym,
random_data=True,
random_pulse_shaping=False,  # TODO: restore True once the pulse shaping code is fixed
center_freq=self.meta["center_freq"],
)

# Extract IQ samples from dataset example
iq_samples = modulated_burst[0][0]
# Resample to target bandwidth * oversample to avoid freq wrap during shift
if (self.meta["center_freq"] + self.meta["bandwidth"] / 2 > 0.4 or self.meta["center_freq"] - self.meta["bandwidth"] / 2 < -0.4):
oversample = 2 if self.meta["bandwidth"] < 1.0 else int(np.ceil(self.meta["bandwidth"] * 2))
else:
oversample = 1
# TODO : this is poor resampling.
up_rate = np.floor(new_rate * 100 * oversample).astype(np.int32)
down_rate = 100
iq_samples = sp.resample_poly(iq_samples, up_rate, down_rate)

# Freq shift to desired center freq
time_vector = np.arange(iq_samples.shape[0], dtype=float)
iq_samples = iq_samples * np.exp(2j * np.pi * self.meta["center_freq"] / oversample * time_vector)

if oversample == 1:
# Prune to length
iq_samples = iq_samples[:int(self.meta["num_samples"] * self.meta["duration"])]

else:
# Pre-prune to reduce filtering cost
iq_samples = iq_samples[:int(self.meta["num_samples"] * self.meta["duration"] * oversample)]
taps = low_pass(cutoff=1 / oversample / 2, transition_bandwidth=(0.5 - 1 / oversample / 2) / 4)
filtered = sp.convolve(iq_samples, taps, mode="full")
lidx = (len(filtered) - len(iq_samples)) // 2
ridx = lidx + len(iq_samples)
iq_samples = filtered[lidx:ridx]

# Decimate back down to correct sample rate
iq_samples = sp.resample_poly(iq_samples, 1, oversample)

# limit the number of samples to the desired duration
iq_samples = iq_samples[:int(self.meta["num_samples"] * self.meta["duration"])]

# Set power
iq_samples = iq_samples / np.sqrt(np.mean(np.abs(iq_samples) ** 2))
iq_samples = np.sqrt(self.meta["bandwidth"]) * (10 ** (self.meta["snr"] / 20.0)) * iq_samples / np.sqrt(2)

if iq_samples.shape[0] > 50:
window = np.blackman(50) / np.max(np.blackman(50))
iq_samples[:25] *= window[:25]
iq_samples[-25:] *= window[-25:]

# Zero-pad to fit num_iq_samples
leading_silence = int(self.meta["num_samples"] * self.meta["start"])
trailing_silence = self.meta["num_samples"] - len(iq_samples) - leading_silence
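
A condensed, hedged sketch of the shift/filter/decimate/scale pipeline above (the function name and its defaults are illustrative; low_pass is the helper from torchsig.utils.dsp changed later in this commit):

import numpy as np
import scipy.signal as sp
from torchsig.utils.dsp import low_pass

def shift_and_scale(iq, center_freq, bandwidth, snr_db, oversample=2):
    # frequency-shift to the target center frequency (signal is still oversampled)
    n = np.arange(iq.shape[0], dtype=float)
    iq = iq * np.exp(2j * np.pi * center_freq / oversample * n)
    # low-pass so no energy wraps on decimation; "same" matches the full-convolve-and-trim above
    taps = low_pass(cutoff=1 / oversample / 2,
                    transition_bandwidth=(0.5 - 1 / oversample / 2) / 4)
    iq = sp.convolve(iq, taps, mode="same")
    iq = sp.resample_poly(iq, 1, oversample)   # decimate back to the original rate
    # normalize to unit power, then scale for occupied bandwidth and target SNR
    iq = iq / np.sqrt(np.mean(np.abs(iq) ** 2))
    return np.sqrt(bandwidth) * (10 ** (snr_db / 20.0)) * iq / np.sqrt(2)
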
@@ -604,10 +555,7 @@ def __init__(
self.start = to_distribution(start, random_generator=self.random_generator)

# Generate the index by creating a set of bursts.
self.index = [
(collection, idx)
for idx, collection in enumerate(self._generate_burst_collections())
]
self.index = [(collection, idx) for idx, collection in enumerate(self._generate_burst_collections())]

def _generate_burst_collections(self) -> List[List[SignalBurst]]:
dataset = []
@@ -685,7 +633,7 @@ def __init__(
self.index = []
self.pregenerate = False
if pregenerate:
print("Pregenerating dataset...")
#print("Pregenerating dataset...")
for idx in tqdm(range(self.num_samples)):
self.index.append(self.__getitem__(idx))
self.pregenerate = pregenerate
@@ -925,9 +873,7 @@ def __init__(
]
)

print(self.transform)
self.target_transform = target_transform

self.num_signals = to_distribution(num_signals, random_generator=self.random_generator)
self.snrs = to_distribution(snrs, random_generator=self.random_generator)

@@ -1037,7 +983,6 @@ def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]:
silence_multiple = to_distribution(literal_eval(self.metadata.iloc[meta_idx].silence_multiple), random_generator=self.random_generator)()
stops_in_frame = False
if hop_random_var < self.metadata.iloc[meta_idx].freq_hopping_prob:
# if 1:
# override bandwidth with smaller options for freq hoppers
if ofdm_present:
bandwidth = self.random_generator.uniform(0.0125, 0.025)
@@ -1097,51 +1042,51 @@ def __getitem__(self, item: int) -> Tuple[np.ndarray, Any]:
stop = 1.0

# Handle overlaps
overlap = False
minimum_freq_spacing = 0.05
for source in signal_sources:
for signal in source.index[0][0]:
meta = signal.meta
# Check time overlap
if (
(start > meta["start"] and start < meta["stop"])
or (
start + burst_duration > meta["stop"]
and stop < meta["stop"]
)
or (meta["start"] > start and meta["start"] < stop)
or (meta["stop"] > start and meta["stop"] < stop)
or (start == 0.0 and meta["start"] == 0.0)
or (stop == 1.0 and meta["stop"] == 1.0)
):
# Check freq overlap
if (
(
low_freq > meta["lower_freq"] - minimum_freq_spacing
and low_freq < meta["upper_freq"] + minimum_freq_spacing
)
or (
high_freq > meta["lower_freq"] - minimum_freq_spacing
and high_freq
< meta["upper_freq"] + minimum_freq_spacing
)
or (
meta["lower_freq"] - minimum_freq_spacing > low_freq
and meta["lower_freq"] - minimum_freq_spacing
< high_freq
)
or (
meta["upper_freq"] + minimum_freq_spacing > low_freq
and meta["upper_freq"] + minimum_freq_spacing
< high_freq
)
):
# Overlaps in both time and freq, skip
overlap = True
if overlap:
overlap_counter += 1
# print('skipping signal')
continue
# overlap = False
# minimum_freq_spacing = 0.05
# for source in signal_sources:
# for signal in source.index[0][0]:
# meta = signal.meta
# # Check time overlap
# if (
# (start > meta["start"] and start < meta["stop"])
# or (
# start + burst_duration > meta["stop"]
# and stop < meta["stop"]
# )
# or (meta["start"] > start and meta["start"] < stop)
# or (meta["stop"] > start and meta["stop"] < stop)
# or (start == 0.0 and meta["start"] == 0.0)
# or (stop == 1.0 and meta["stop"] == 1.0)
# ):
# # Check freq overlap
# if (
# (
# low_freq > meta["lower_freq"] - minimum_freq_spacing
# and low_freq < meta["upper_freq"] + minimum_freq_spacing
# )
# or (
# high_freq > meta["lower_freq"] - minimum_freq_spacing
# and high_freq
# < meta["upper_freq"] + minimum_freq_spacing
# )
# or (
# meta["lower_freq"] - minimum_freq_spacing > low_freq
# and meta["lower_freq"] - minimum_freq_spacing
# < high_freq
# )
# or (
# meta["upper_freq"] + minimum_freq_spacing > low_freq
# and meta["upper_freq"] + minimum_freq_spacing
# < high_freq
# )
# ):
# # Overlaps in both time and freq, skip
# overlap = True
# if overlap:
# overlap_counter += 1
# # print('skipping signal')
# continue

# Add signal to signal sources
snrs_db = self.snrs()
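
The disabled overlap check above reduces to a pairwise interval-intersection test in time and (guard-banded) frequency; a hedged distillation with made-up numbers:

def overlaps(a_lo, a_hi, b_lo, b_hi):
    # strict interval intersection
    return a_lo < b_hi and b_lo < a_hi

guard = 0.05  # minimum_freq_spacing in the disabled block
# candidate burst [0.2, 0.6] x [-0.10, -0.05] vs. existing burst [0.5, 0.9] x [-0.07, 0.00]
time_clash = overlaps(0.2, 0.6, 0.5, 0.9)                         # True
freq_clash = overlaps(-0.10, -0.05, -0.07 - guard, 0.00 + guard)  # True
skip = time_clash and freq_clash                                  # True -> would have been skipped
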
2 changes: 1 addition & 1 deletion torchsig/models/model_utils/layer_tools.py
@@ -9,7 +9,7 @@ def get_layer_list(model):
arr = [m for m in model.modules()]
if len(arr) > 1:
for module in arr[1:]:
final_arr += get_module_list(module)
final_arr += (module)
return final_arr
else:
return arr
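
One caution on the new line: (module) is not a tuple — the parentheses are redundant — so final_arr += (module) tries to iterate module and raises TypeError for a plain nn.Module. A hedged guess at the intent is a one-element extension:

final_arr += [module]   # or equivalently: final_arr.append(module)
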
1 change: 1 addition & 0 deletions torchsig/models/model_utils/model_utils_1d/conversions_to_1d.py
@@ -54,6 +54,7 @@ def convert_2d_model_to_1d(model):
type_factory_pairs = [
('Conv2d', conv2d_to_conv1d),
('BatchNorm2d', batchNorm2d_to_GBN1d),
('BatchNormAct2d', batchNorm2d_to_batchNorm1d),
('SqueezeExcite', squeezeExcite_to_squeezeExcite1d),
('SelectAdaptivePool2d',make_fast_avg_pooling_layer),
]
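
The table pairs a layer's class name with a factory that builds its 1d replacement. A hedged sketch of how such a table is typically applied — replace_layers_1d is illustrative, not the helper this file actually uses, and it assumes each factory takes the 2d layer and returns its 1d substitute:

def replace_layers_1d(module, type_factory_pairs):
    table = dict(type_factory_pairs)
    for name, child in module.named_children():
        factory = table.get(type(child).__name__)
        if factory is not None:
            setattr(module, name, factory(child))          # swap the 2d layer for its 1d counterpart
        else:
            replace_layers_1d(child, type_factory_pairs)   # recurse into unconverted children
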
9 changes: 2 additions & 7 deletions torchsig/transforms/transforms.py
@@ -281,10 +281,7 @@ def __call__(self, target: Any) -> Any:
class SignalTransform(Transform):
"""An abstract base class which explicitly only operates on Signal data"""

def __init__(
self,
**kwargs,
) -> None:
def __init__(self, **kwargs,) -> None:
super(SignalTransform, self).__init__(**kwargs)

def __call__(self, signal: Signal) -> Signal:
@@ -530,7 +527,7 @@ class RandomResample(SignalTransform):
keep_samples (:obj:`bool`):
Despite returning a different number of samples being an issue, return however many samples
are returned from resample_poly
are returned from resampler
Note:
When rate_ratio is > 1.0, the resampling algorithm produces more samples than the original tensor.
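
A hedged sketch of the behavior the note describes, using scipy's resample_poly with a rational approximation of the ratio (values illustrative):

import numpy as np
import scipy.signal as sp
from fractions import Fraction

rate_ratio = 1.5                                     # new_rate / old_rate
frac = Fraction(rate_ratio).limit_denominator(100)   # -> Fraction(3, 2)
x = np.zeros(1024, dtype=np.complex64)
y = sp.resample_poly(x, frac.numerator, frac.denominator)
len(y)   # -> 1536: rate_ratio > 1.0 returns more samples than the input
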
@@ -1280,8 +1277,6 @@ def transform_data(self, signal: Signal, params: tuple) -> Signal:
return signal

def transform_meta(self, signal: Signal, params: Tuple) -> Signal:
# for meta in signal["metadata"]:
# print(f'meta[start] -> {meta["start"]} meta[stop] -> {meta["stop"]} meta[duration] -> {meta["duration"]}')
return signal

class ContinuousWavelet(SignalTransform):
20 changes: 19 additions & 1 deletion torchsig/utils/dsp.py
@@ -17,7 +17,6 @@ def low_pass(cutoff: float, transition_bandwidth: float) -> np.ndarray:
transition_bandwidth (float): width of the transition region
"""
transition_bandwidth = (0.5 - cutoff) / 4
num_taps = estimate_filter_length(transition_bandwidth)
return sp.firwin(
num_taps,
@@ -28,6 +27,25 @@
fs=1,
)

def polyphase_prototype_filter ( num_branches: int ) -> np.ndarray:
# design a low-pass filter
cutoff = 1/(2*num_branches)
transitionBandwidth = 1/(4*num_branches)
prototypeFilterPFB = low_pass(cutoff=cutoff, transition_bandwidth=transitionBandwidth)
# increase gain to account for change in sample rate
prototypeFilterPFB *= num_branches
return prototypeFilterPFB

def irrational_rate_resampler ( input_signal: np.ndarray, resampler_rate: float ) -> np.ndarray:
# TODO: needs to be estimated, not a fixed value
numBranchesPFB = 10000
resamplerUpRate = numBranchesPFB
resamplerDownRate = int(np.round(numBranchesPFB/resampler_rate))
# design the PFB prototype filter
prototypeFilterPFB = polyphase_prototype_filter ( numBranchesPFB )
# apply the PFB via upfirdn()
output = sp.upfirdn(prototypeFilterPFB, input_signal, up=resamplerUpRate, down=resamplerDownRate)
return output
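
A usage sketch for the resampler above (illustrative): a requested rate of 0.75 maps to up = 10000 and down = round(10000 / 0.75) = 13333, an achieved rate of 10000/13333 ≈ 0.7500188, so arbitrary — including irrational — rates are approximated to roughly one part in 10^4:

import numpy as np

x = np.exp(2j * np.pi * 0.01 * np.arange(10_000))      # complex tone at 0.01 cycles/sample
y = irrational_rate_resampler(x, resampler_rate=0.75)  # the function defined above
print(len(x), len(y))                                  # length scales by ~0.75, plus filter transients
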

def estimate_filter_length(
transition_bandwidth: float, attenuation_db: int = 120, sample_rate: float = 1.0
10 changes: 4 additions & 6 deletions torchsig/utils/visualize.py
@@ -8,6 +8,7 @@
import numpy as np
import pywt
import torch
import pdb


class Visualizer:
@@ -922,17 +923,14 @@ def mask_class_to_outline(tensor: np.ndarray) -> Tuple[List[List[int]], List[Any]
for idx in range(batch_size):
label = tensor[idx].numpy()
class_idx_curr = []
pdb.set_trace()
for individual_burst_idx in range(label.shape[0]):
if np.count_nonzero(label[individual_burst_idx]) > 0:
class_idx_curr.append(individual_burst_idx)
label[individual_burst_idx] = label[
individual_burst_idx
] - ndimage.binary_erosion(label[individual_burst_idx])
label[individual_burst_idx] = label[individual_burst_idx] - ndimage.binary_erosion(label[individual_burst_idx])
label = np.sum(label, axis=0)
label[label > 0] = 1
label = ndimage.binary_dilation(label, structure=struct, iterations=2).astype(
label.dtype
)
label = ndimage.binary_dilation(label, structure=struct, iterations=2).astype(label.dtype)
label = np.ma.masked_where(label == 0, label)
class_idx.append(class_idx_curr)
labels.append(label)
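
The erosion-and-subtract step above reduces a filled mask to its boundary; a minimal, self-contained sketch:

import numpy as np
from scipy import ndimage

mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:6, 2:6] = 1                              # a filled 4x4 burst
outline = mask - ndimage.binary_erosion(mask)   # only the one-pixel boundary ring survives
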
4 changes: 2 additions & 2 deletions torchsig/utils/writer.py
@@ -103,7 +103,7 @@ def write(self, batch):
for label_idx, label in enumerate(labels):
txn.put(
pickle.dumps(last_idx + label_idx),
pickle.dumps(tuple(label.numpy())),
pickle.dumps(tuple(label)),
db=self.label_db,
)
if isinstance(labels, list):
@@ -116,7 +116,7 @@
for element_idx in range(len(data)):
txn.put(
pickle.dumps(last_idx + element_idx),
pickle.dumps(data[element_idx].numpy()),
pickle.dumps(data[element_idx]),
db=self.data_db,
)
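
The dropped .numpy() calls suggest the batch now arrives as numpy arrays or other directly picklable objects. A defensive sketch that accepts both — to_picklable is hypothetical, not part of the commit:

import pickle
import numpy as np
import torch

def to_picklable(x):
    # convert tensors so the LMDB payload stays framework-agnostic;
    # numpy arrays and built-ins pass through untouched
    return x.numpy() if isinstance(x, torch.Tensor) else x

payload = pickle.dumps(to_picklable(torch.zeros(4)))   # tensor input
payload = pickle.dumps(to_picklable(np.zeros(4)))      # ndarray input
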

