From 6859f9eedb4ffa3e7dcb528078ef175b59f06c5e Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Fri, 26 Jan 2024 10:40:24 +0100 Subject: [PATCH] Clean up --- examples/network_compression/mnist_compression.py | 12 ++++-------- src/ptwt/nn.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/network_compression/mnist_compression.py b/examples/network_compression/mnist_compression.py index 917e4a33..53030eeb 100644 --- a/examples/network_compression/mnist_compression.py +++ b/examples/network_compression/mnist_compression.py @@ -64,11 +64,9 @@ def forward(self, x): def wavelet_loss(self): if self.wavelet is None: - return torch.tensor(0.0), torch.tensor(0.0) + return torch.tensor(0.0) else: - acl, _, _ = self.fc1.wavelet.alias_cancellation_loss() - prl, _, _ = self.fc1.wavelet.perfect_reconstruction_loss() - return acl, prl + return self.fc1.get_wavelet_loss() def train(args, model, device, train_loader, optimizer, epoch): @@ -79,8 +77,7 @@ def train(args, model, device, train_loader, optimizer, epoch): output = model(data) nll_loss = F.nll_loss(output, target) if args.compression == "Wavelet": - acl, prl = model.wavelet_loss() - wvl = acl + prl + wvl = model.wavelet_loss() loss = nll_loss + wvl * args.wave_loss_weight else: wvl = torch.tensor(0.0) @@ -117,8 +114,7 @@ def test(args, model, device, test_loader, test_writer, epoch): correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - acl, prl = model.wavelet_loss() - wvl_loss = acl + prl + wvl_loss = model.wavelet_loss() print( "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( diff --git a/src/ptwt/nn.py b/src/ptwt/nn.py index 15c37cb3..51876116 100644 --- a/src/ptwt/nn.py +++ b/src/ptwt/nn.py @@ -1,3 +1,5 @@ +"""Neural network modules.""" + import numpy as np import pywt import torch @@ -20,11 +22,14 @@ class WaveletLayer(torch.nn.Module): .. note:: - Originally created by moritz (wolter@cs.uni-bonn.de) - at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py + Originally created by moritz (wolter@cs.uni-bonn.de) at + https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py """ - def __init__(self, depth: int, init_wavelet, scales, p_drop=0.5): + def __init__( + self, depth: int, init_wavelet: pywt.Wavelet, scales: int, p_drop: float = 0.5 + ): + """Initialize the wavelet layer.""" super().__init__() self.scales = scales self.wavelet = init_wavelet @@ -98,7 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: step6 = self._mul_s(step5) return step6 - def extra_repr(self) -> str: + def extra_repr(self) -> str: # noqa:D102 return "depth={}".format(self.depth) def get_wavelet_loss(self) -> torch.Tensor: