Skip to content

Commit

Permalink
More documentation and examples (#134)
Browse files Browse the repository at this point in the history
* More documentation and examples

* Added script to train
  • Loading branch information
gvanhoy authored Jun 9, 2023
1 parent 634dc2f commit beb19c6
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

ENV DEBIAN_FRONTEND=noninteractive

Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ cd torchsig
pip install .
```

## Generating the Datasets
If you'd like to generate the named datasets without messing with your current Python environment, you can build the development container and use it to generate data at the location of your choosing.

```
docker build -t torchsig -f Dockerfile .
docker run -u $(id -u ${USER}):$(id -g ${USER}) -v `pwd`:/workspace/code/torchsig torchsig python3 torchsig/scripts/generate_sig53.py --root=/workspace/code/torchsig/data --all=True
```

If you do not need to use Docker, you can also generate the datasets using the regular command-line interface:

```
python3 torchsig/scripts/generate_sig53.py --root=torchsig/data --all=True
```

Then, be sure to point scripts looking for ```root``` to ```torchsig/data```.

## Using the Dockerfile
If you have Docker installed along with compatible GPUs and drivers, you can try:

Expand Down
2 changes: 1 addition & 1 deletion examples/02_example_sig53_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from sklearn.metrics import classification_report
from cm_plotter import plot_confusion_matrix
from torchsig.utils.cm_plotter import plot_confusion_matrix
from torchsig.datasets.sig53 import Sig53
from torchsig.datasets.modulations import ModulationsDataset
from torch.utils.data import DataLoader
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"sympy",
"numba",
"torchmetrics",
"click"
]
dynamic = ["version"]

Expand Down
64 changes: 64 additions & 0 deletions scripts/generate_sig53.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from torchsig.utils.writer import DatasetCreator, DatasetLoader
from torchsig.datasets.modulations import ModulationsDataset
from torchsig.datasets import conf
from typing import List
import click
import os


def generate(path: str, configs: List[conf.Sig53Config]):
    """Generate and write each Sig53 dataset variant described by *configs*.

    Args:
        path: Root directory under which the datasets are written; each
            config gets its own subdirectory named after ``config.name``.
        configs: Sig53 config objects describing the variants to generate.
    """
    # os.cpu_count() may return None (undeterminable) and on a single-core
    # host `cpu_count() // 2` is 0 — both would give the loader an invalid
    # num_workers/batch_size. Clamp to at least 1.
    workers = max((os.cpu_count() or 2) // 2, 1)
    for config in configs:
        ds = ModulationsDataset(
            level=config.level,
            num_samples=config.num_samples,
            num_iq_samples=config.num_iq_samples,
            use_class_idx=config.use_class_idx,
            include_snr=config.include_snr,
            eb_no=config.eb_no,
        )
        loader = DatasetLoader(
            ds,
            seed=12345678,
            num_workers=workers,
            batch_size=workers,
        )
        creator = DatasetCreator(
            ds,
            seed=12345678,
            path=os.path.join(path, config.name),
            loader=loader,
        )
        creator.create()


@click.command()
@click.option("--root", default="sig53", help="Path to generate sig53 datasets")
@click.option("--all", default=True, help="Generate all versions of sig53 dataset.")
@click.option(
    "--impaired",
    default=False,
    help="Generate impaired dataset. Ignored if --all=True (default)",
)
def main(root: str, all: bool, impaired: bool):
    """CLI entry point: generate the requested Sig53 dataset variants under *root*."""
    # BUG FIX: the original called os.root.isdir(root) — `os.root` does not
    # exist, so every invocation raised AttributeError. os.path.isdir was
    # intended; makedirs(exist_ok=True) also creates missing parent dirs.
    if not os.path.isdir(root):
        os.makedirs(root, exist_ok=True)

    # Order matters: [clean train, clean val, impaired train, impaired val];
    # the slices below rely on it.
    configs = [
        conf.Sig53CleanTrainConfig,
        conf.Sig53CleanValConfig,
        conf.Sig53ImpairedTrainConfig,
        conf.Sig53ImpairedValConfig,
    ]
    if all:
        generate(root, configs)
        return

    if impaired:
        generate(root, configs[2:])
        return

    generate(root, configs[:2])


if __name__ == "__main__":
    main()
188 changes: 188 additions & 0 deletions scripts/train_sig53.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from torchsig.transforms.target_transforms import DescToClassIndex
from torchsig.models.iq_models.efficientnet.efficientnet import efficientnet_b4
from torchsig.transforms.transforms import (
RandomPhaseShift,
Normalize,
ComplexTo2D,
Compose,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from sklearn.metrics import classification_report
from torchsig.utils.cm_plotter import plot_confusion_matrix
from torchsig.datasets.sig53 import Sig53
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import click
import torch
import os


class ExampleNetwork(LightningModule):
    """Minimal LightningModule wrapping a classification backbone.

    Trains the wrapped model with cross-entropy on (iq_data, class_index)
    batches served by the supplied train/validation dataloaders.
    """

    def __init__(self, model, data_loader, val_data_loader):
        super().__init__()
        self.mdl: torch.nn.Module = model
        self.data_loader: DataLoader = data_loader
        self.val_data_loader: DataLoader = val_data_loader

        # Hyperparameters
        self.lr = 0.001
        self.batch_size = data_loader.batch_size

    def forward(self, x: torch.Tensor):
        # Backbone expects float32 input.
        return self.mdl(x.float())

    def predict(self, x: torch.Tensor):
        # Inference-only helper: no gradient bookkeeping needed.
        with torch.no_grad():
            return self.forward(x.float())

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        return self.data_loader

    def val_dataloader(self):
        return self.val_data_loader

    def training_step(self, batch: torch.Tensor, batch_nb: int):
        data, target = batch
        # Targets arrive as a column of floats/ints; cross_entropy wants a
        # 1-D int64 tensor.
        target = torch.squeeze(target.to(torch.int64))
        loss = F.cross_entropy(self(data.float()), target)
        self.log("loss", loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch: torch.Tensor, batch_nb: int):
        data, target = batch
        target = torch.squeeze(target.to(torch.int64))
        loss = F.cross_entropy(self(data.float()), target)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss


@click.command()
@click.option("--root", default="data/sig53", help="Path to train/val datasets")
@click.option("--impaired", default=False, help="Impaired or clean datasets")
def main(root: str, impaired: bool):
    """Train an EfficientNet-B4 Sig53 classifier, then report validation
    accuracy, a confusion-matrix plot, and a classification report."""
    class_list = list(Sig53._idx_to_name_dict.values())

    # Augmentation + normalization; ComplexTo2D splits complex IQ into
    # two real channels for the CNN.
    transform = Compose(
        [
            RandomPhaseShift(phase_offset=(-1, 1)),
            Normalize(norm=np.inf),
            ComplexTo2D(),
        ]
    )
    target_transform = DescToClassIndex(class_list=class_list)

    sig53_train = Sig53(
        root,
        train=True,
        impaired=impaired,
        transform=transform,
        target_transform=target_transform,
        use_signal_data=True,
    )

    sig53_val = Sig53(
        root,
        train=False,
        impaired=impaired,
        transform=transform,
        target_transform=target_transform,
        use_signal_data=True,
    )

    # Create dataloaders. os.cpu_count() can return None; fall back to 1.
    num_cpus = os.cpu_count() or 1
    train_dataloader = DataLoader(
        dataset=sig53_train,
        batch_size=num_cpus,
        num_workers=num_cpus // 2,
        shuffle=True,
        drop_last=True,
    )
    val_dataloader = DataLoader(
        dataset=sig53_val,
        batch_size=num_cpus,
        num_workers=num_cpus // 2,
        shuffle=False,
        drop_last=True,
    )

    model = efficientnet_b4(pretrained=False)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)

    example_model = ExampleNetwork(model, train_dataloader, val_dataloader)
    example_model = example_model.to(device)

    # Setup checkpoint callbacks
    checkpoint_filename = "{}/checkpoint".format(os.getcwd())
    checkpoint_callback = ModelCheckpoint(
        filename=checkpoint_filename,
        save_top_k=1,  # BUG FIX: save_top_k expects an int, not the bool True
        monitor="val_loss",
        mode="min",
    )

    # Create and fit trainer. Derive the accelerator from the device probe
    # above instead of hard-coding "gpu", so CPU-only hosts can still run.
    epochs = 500
    trainer = Trainer(
        max_epochs=epochs,
        callbacks=checkpoint_callback,
        devices=1,
        accelerator="gpu" if use_cuda else "cpu",
    )
    trainer.fit(example_model)

    # Load best checkpoint (mapped to CPU first; .to(device) moves it back)
    checkpoint = torch.load(
        checkpoint_filename + ".ckpt", map_location=lambda storage, loc: storage
    )
    example_model.load_state_dict(checkpoint["state_dict"])
    example_model = example_model.to(device=device).eval()

    # Infer results over validation set
    num_test_examples = len(sig53_val)
    y_preds = np.zeros((num_test_examples,))
    y_true = np.zeros((num_test_examples,))

    for i in tqdm(range(num_test_examples)):
        data, label = sig53_val[i]
        # Add a batch dimension and move to the model's device.
        data = torch.from_numpy(np.expand_dims(data, 0)).float().to(device)
        # .cpu() is a no-op on CPU tensors, so this is safe on any device.
        pred = example_model.predict(data).cpu().numpy()
        y_preds[i] = np.argmax(pred)
        y_true[i] = label

    acc = np.sum(y_preds == y_true) / len(y_true)
    plot_confusion_matrix(
        y_true,
        y_preds,
        classes=class_list,
        normalize=True,
        title="Example Modulations Confusion Matrix\nTotal Accuracy: {:.2f}%".format(
            acc * 100
        ),
        text=False,
        rotate_x_text=90,
        figsize=(16, 9),
    )
    plt.savefig("{}/02_sig53_classifier.png".format(os.getcwd()))

    print("Classification Report:")
    print(classification_report(y_true, y_preds))


if __name__ == "__main__":
    main()
66 changes: 66 additions & 0 deletions torchsig/utils/cm_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from typing import Optional
import numpy as np


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list,
    normalize: bool = True,
    title: Optional[str] = None,
    text: bool = True,
    rotate_x_text: int = 90,
    figsize: tuple = (16, 9),
    cmap: plt.cm = plt.cm.Blues,
):
    """Plot a confusion matrix for integer-encoded labels.

    Adapted from:
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    Args:
        y_true: Ground-truth class indices, shape (N,).
        y_pred: Predicted class indices, shape (N,).
        classes: Class names for tick labels; class ``i`` maps to ``classes[i]``.
        normalize: Normalize each row so it sums to 1 (per-true-class rates).
        title: Plot title; a default is picked based on ``normalize``.
        text: Annotate each cell with its value.
        rotate_x_text: Rotation of the x tick labels, in degrees.
        figsize: Figure size in inches.
        cmap: Matplotlib colormap for the matrix image.

    Returns:
        The matplotlib Axes containing the plot.
    """
    if not title:
        title = (
            "Normalized confusion matrix"
            if normalize
            else "Confusion matrix, without normalization"
        )

    # BUG FIX: pass labels explicitly so the matrix always has one row/column
    # per entry in `classes`, even for classes absent from y_true/y_pred;
    # otherwise the tick labels below would be misaligned with the matrix.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard against 0/0 for classes that never occur in y_true.
        cm = cm.astype("float") / np.where(row_sums == 0, 1, row_sums)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="none", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )
    ax.set_xticklabels(classes, rotation=rotate_x_text)
    ax.figure.set_size_inches(figsize)

    # Annotate cells; the `text` check is hoisted out of the nested loop.
    if text:
        fmt = ".2f" if normalize else "d"
        thresh = cm.max() / 2.0
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(
                    j,
                    i,
                    format(cm[i, j], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i, j] > thresh else "black",
                )
    if len(classes) == 2:
        plt.axis([-0.5, 1.5, 1.5, -0.5])
    fig.tight_layout()

    return ax

0 comments on commit beb19c6

Please sign in to comment.