diff --git a/Dockerfile b/Dockerfile index 46586ba..598e1ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime ENV DEBIAN_FRONTEND=noninteractive diff --git a/README.md b/README.md index 5ad4c45..15614b7 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,22 @@ cd torchsig pip install . ``` +## Generating the Datasets +If you'd like to generate the named datasets without messing with your current Python environment, you can build the development container and use it to generate data at the location of your choosing. + +``` +docker build -t torchsig -f Dockerfile . +docker run -u $(id -u ${USER}):$(id -g ${USER}) -v `pwd`:/workspace/code/torchsig torchsig python3 torchsig/scripts/generate_sig53.py --root=/workspace/code/torchsig/data --all=True +``` + +If you do not need to use Docker, you can also just generate using the regular command-line interface + +``` +python3 torchsig/scripts/generate_sig53.py --root=torchsig/data --all=True +``` + +Then, be sure to point scripts looking for ```root``` to ```torchsig/data```. 
+ ## Using the Dockerfile If you have Docker installed along with compatible GPUs and drivers, you can try: diff --git a/examples/02_example_sig53_classifier.py b/examples/02_example_sig53_classifier.py index 852472c..a0f9051 100644 --- a/examples/02_example_sig53_classifier.py +++ b/examples/02_example_sig53_classifier.py @@ -17,7 +17,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning import LightningModule, Trainer from sklearn.metrics import classification_report -from cm_plotter import plot_confusion_matrix +from torchsig.utils.cm_plotter import plot_confusion_matrix from torchsig.datasets.sig53 import Sig53 from torchsig.datasets.modulations import ModulationsDataset from torch.utils.data import DataLoader diff --git a/pyproject.toml b/pyproject.toml index 4883711..3e3997b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "sympy", "numba", "torchmetrics", + "click" ] dynamic = ["version"] diff --git a/scripts/generate_sig53.py b/scripts/generate_sig53.py new file mode 100644 index 0000000..2e7f38b --- /dev/null +++ b/scripts/generate_sig53.py @@ -0,0 +1,64 @@ +from torchsig.utils.writer import DatasetCreator, DatasetLoader +from torchsig.datasets.modulations import ModulationsDataset +from torchsig.datasets import conf +from typing import List +import click +import os + + +def generate(path: str, configs: List[conf.Sig53Config]): + for config in configs: + ds = ModulationsDataset( + level=config.level, + num_samples=config.num_samples, + num_iq_samples=config.num_iq_samples, + use_class_idx=config.use_class_idx, + include_snr=config.include_snr, + eb_no=config.eb_no, + ) + loader = DatasetLoader( + ds, + seed=12345678, + num_workers=os.cpu_count() // 2, + batch_size=os.cpu_count() // 2, + ) + creator = DatasetCreator( + ds, + seed=12345678, + path="{}".format(os.path.join(path, config.name)), + loader=loader, + ) + creator.create() + + +@click.command() +@click.option("--root", default="sig53", 
help="Path to generate sig53 datasets")
+@click.option("--all", default=True, help="Generate all versions of sig53 dataset.")
+@click.option(
+    "--impaired",
+    default=False,
+    help="Generate impaired dataset. Ignored if --all=True (default)",
+)
+def main(root: str, all: bool, impaired: bool):
+    if not os.path.isdir(root):
+        os.mkdir(root)
+
+    configs = [
+        conf.Sig53CleanTrainConfig,
+        conf.Sig53CleanValConfig,
+        conf.Sig53ImpairedTrainConfig,
+        conf.Sig53ImpairedValConfig,
+    ]
+    if all:
+        generate(root, configs)
+        return
+
+    if impaired:
+        generate(root, configs[2:])
+        return
+
+    generate(root, configs[:2])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_sig53.py b/scripts/train_sig53.py
new file mode 100644
index 0000000..8888c7a
--- /dev/null
+++ b/scripts/train_sig53.py
@@ -0,0 +1,188 @@
+from torchsig.transforms.target_transforms import DescToClassIndex
+from torchsig.models.iq_models.efficientnet.efficientnet import efficientnet_b4
+from torchsig.transforms.transforms import (
+    RandomPhaseShift,
+    Normalize,
+    ComplexTo2D,
+    Compose,
+)
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import LightningModule, Trainer
+from sklearn.metrics import classification_report
+from torchsig.utils.cm_plotter import plot_confusion_matrix
+from torchsig.datasets.sig53 import Sig53
+from torch.utils.data import DataLoader
+from matplotlib import pyplot as plt
+from torch import optim
+from tqdm import tqdm
+import torch.nn.functional as F
+import numpy as np
+import click
+import torch
+import os
+
+
+class ExampleNetwork(LightningModule):
+    def __init__(self, model, data_loader, val_data_loader):
+        super(ExampleNetwork, self).__init__()
+        self.mdl: torch.nn.Module = model
+        self.data_loader: DataLoader = data_loader
+        self.val_data_loader: DataLoader = val_data_loader
+
+        # Hyperparameters
+        self.lr = 0.001
+        self.batch_size = data_loader.batch_size
+
+    def forward(self, x: torch.Tensor):
+        return self.mdl(x.float())
+
+    def predict(self, x: torch.Tensor):
+        with torch.no_grad():
+            out = self.forward(x.float())
+        return out
+
+    def configure_optimizers(self):
+        return optim.Adam(self.parameters(), lr=self.lr)
+
+    def train_dataloader(self):
+        return self.data_loader
+
+    def val_dataloader(self):
+        return self.val_data_loader
+
+    def training_step(self, batch: torch.Tensor, batch_nb: int):
+        x, y = batch
+        y = torch.squeeze(y.to(torch.int64))
+        loss = F.cross_entropy(self(x.float()), y)
+        self.log("loss", loss, on_step=True, prog_bar=True, logger=True)
+        return loss
+
+    def validation_step(self, batch: torch.Tensor, batch_nb: int):
+        x, y = batch
+        y = torch.squeeze(y.to(torch.int64))
+        loss = F.cross_entropy(self(x.float()), y)
+        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
+        return loss
+
+
+@click.command()
+@click.option("--root", default="data/sig53", help="Path to train/val datasets")
+@click.option("--impaired", default=False, help="Impaired or clean datasets")
+def main(root: str, impaired: bool):
+    class_list = list(Sig53._idx_to_name_dict.values())
+    transform = Compose(
+        [
+            RandomPhaseShift(phase_offset=(-1, 1)),
+            Normalize(norm=np.inf),
+            ComplexTo2D(),
+        ]
+    )
+    target_transform = DescToClassIndex(class_list=class_list)
+
+    sig53_train = Sig53(
+        root,
+        train=True,
+        impaired=impaired,
+        transform=transform,
+        target_transform=target_transform,
+        use_signal_data=True,
+    )
+
+    sig53_val = Sig53(
+        root,
+        train=False,
+        impaired=impaired,
+        transform=transform,
+        target_transform=target_transform,
+        use_signal_data=True,
+    )
+
+    # Create dataloaders
+    train_dataloader = DataLoader(
+        dataset=sig53_train,
+        batch_size=os.cpu_count(),
+        num_workers=os.cpu_count() // 2,
+        shuffle=True,
+        drop_last=True,
+    )
+    val_dataloader = DataLoader(
+        dataset=sig53_val,
+        batch_size=os.cpu_count(),
+        num_workers=os.cpu_count() // 2,
+        shuffle=False,
+        drop_last=True,
+    )
+
+    model = efficientnet_b4(pretrained=False)
+
+    device = torch.device("cuda" 
if torch.cuda.is_available() else "cpu") + model = model.to(device) + + example_model = ExampleNetwork(model, train_dataloader, val_dataloader) + example_model = example_model.to(device) + + # Setup checkpoint callbacks + checkpoint_filename = "{}/checkpoint".format(os.getcwd()) + checkpoint_callback = ModelCheckpoint( + filename=checkpoint_filename, + save_top_k=True, + monitor="val_loss", + mode="min", + ) + + # Create and fit trainer + epochs = 500 + trainer = Trainer( + max_epochs=epochs, callbacks=checkpoint_callback, devices=1, accelerator="gpu" + ) + trainer.fit(example_model) + + # Load best checkpoint + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load( + checkpoint_filename + ".ckpt", map_location=lambda storage, loc: storage + ) + example_model.load_state_dict(checkpoint["state_dict"]) + example_model = example_model.to(device=device).eval() + + # Infer results over validation set + num_test_examples = len(sig53_val) + num_classes = len(list(Sig53._idx_to_name_dict.values())) + y_raw_preds = np.empty((num_test_examples, num_classes)) + y_preds = np.zeros((num_test_examples,)) + y_true = np.zeros((num_test_examples,)) + + for i in tqdm(range(0, num_test_examples)): + # Retrieve data + idx = i # Use index if evaluating over full dataset + data, label = sig53_val[idx] + # Infer + data = torch.from_numpy(np.expand_dims(data, 0)).float().to(device) + pred_tmp = example_model.predict(data) + pred_tmp = pred_tmp.cpu().numpy() if torch.cuda.is_available() else pred_tmp + # Argmax + y_preds[i] = np.argmax(pred_tmp) + # Store label + y_true[i] = label + + acc = np.sum(np.asarray(y_preds) == np.asarray(y_true)) / len(y_true) + plot_confusion_matrix( + y_true, + y_preds, + classes=class_list, + normalize=True, + title="Example Modulations Confusion Matrix\nTotal Accuracy: {:.2f}%".format( + acc * 100 + ), + text=False, + rotate_x_text=90, + figsize=(16, 9), + ) + plt.savefig("{}/02_sig53_classifier.png".format(os.getcwd())) + + 
print("Classification Report:") + print(classification_report(y_true, y_preds)) + + +if __name__ == "__main__": + main() diff --git a/torchsig/utils/cm_plotter.py b/torchsig/utils/cm_plotter.py new file mode 100644 index 0000000..ca5cd42 --- /dev/null +++ b/torchsig/utils/cm_plotter.py @@ -0,0 +1,66 @@ +from sklearn.metrics import confusion_matrix +from matplotlib import pyplot as plt +from typing import Optional +import numpy as np + + +def plot_confusion_matrix( + y_true: np.array, + y_pred: np.array, + classes: list, + normalize: bool = True, + title: Optional[str] = None, + text: bool = True, + rotate_x_text: int = 90, + figsize: tuple = (16, 9), + cmap: plt.cm = plt.cm.Blues, +): + """Function to help plot confusion matrices + + https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html + """ + if not title: + if normalize: + title = "Normalized confusion matrix" + else: + title = "Confusion matrix, without normalization" + + # Compute confusion matrix + cm = confusion_matrix(y_true, y_pred) + if normalize: + cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] + + fig, ax = plt.subplots() + im = ax.imshow(cm, interpolation="none", cmap=cmap) + ax.figure.colorbar(im, ax=ax) + ax.set( + xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + xticklabels=classes, + yticklabels=classes, + title=title, + ylabel="True label", + xlabel="Predicted label", + ) + ax.set_xticklabels(classes, rotation=rotate_x_text) + ax.figure.set_size_inches(figsize) + + # Loop over data dimensions and create text annotations. + fmt = ".2f" if normalize else "d" + thresh = cm.max() / 2.0 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + if text: + ax.text( + j, + i, + format(cm[i, j], fmt), + ha="center", + va="center", + color="white" if cm[i, j] > thresh else "black", + ) + if len(classes) == 2: + plt.axis([-0.5, 1.5, 1.5, -0.5]) + fig.tight_layout() + + return ax