Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 26, 2024
1 parent 1990697 commit 6859f9e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
12 changes: 4 additions & 8 deletions examples/network_compression/mnist_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,9 @@ def forward(self, x):

def wavelet_loss(self):
if self.wavelet is None:
return torch.tensor(0.0), torch.tensor(0.0)
return torch.tensor(0.0)
else:
acl, _, _ = self.fc1.wavelet.alias_cancellation_loss()
prl, _, _ = self.fc1.wavelet.perfect_reconstruction_loss()
return acl, prl
return self.fc1.get_wavelet_loss()


def train(args, model, device, train_loader, optimizer, epoch):
Expand All @@ -79,8 +77,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
output = model(data)
nll_loss = F.nll_loss(output, target)
if args.compression == "Wavelet":
acl, prl = model.wavelet_loss()
wvl = acl + prl
wvl = model.wavelet_loss()
loss = nll_loss + wvl * args.wave_loss_weight
else:
wvl = torch.tensor(0.0)
Expand Down Expand Up @@ -117,8 +114,7 @@ def test(args, model, device, test_loader, test_writer, epoch):
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
acl, prl = model.wavelet_loss()
wvl_loss = acl + prl
wvl_loss = model.wavelet_loss()

print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
Expand Down
13 changes: 9 additions & 4 deletions src/ptwt/nn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Neural network modules."""

import numpy as np
import pywt
import torch
Expand All @@ -20,11 +22,14 @@ class WaveletLayer(torch.nn.Module):
.. note::
Originally created by moritz (wolter@cs.uni-bonn.de)
at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py
Originally created by moritz (wolter@cs.uni-bonn.de) at
https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py
"""

def __init__(self, depth: int, init_wavelet, scales, p_drop=0.5):
def __init__(
self, depth: int, init_wavelet: pywt.Wavelet, scales: int, p_drop: float = 0.5
):
"""Initialize the wavelet layer."""
super().__init__()
self.scales = scales
self.wavelet = init_wavelet
Expand Down Expand Up @@ -98,7 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
step6 = self._mul_s(step5)
return step6

def extra_repr(self) -> str:
def extra_repr(self) -> str: # noqa:D102
return "depth={}".format(self.depth)

def get_wavelet_loss(self) -> torch.Tensor:
Expand Down

0 comments on commit 6859f9e

Please sign in to comment.