Skip to content

Commit

Permalink
Removing registry, refactored useful functionality into BaseModel. (#52)
Browse files Browse the repository at this point in the history
* Adding a function to save package/weights for BaseModel  plus extra data. Removed registry. Updated experiment name generation.

* Experiments get date, added load_from_folder.

* lower case

Co-authored-by: pseeth <prem@descript.com>
  • Loading branch information
pseeth and pseeth authored Aug 31, 2022
1 parent f2d2f64 commit 97585ed
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 213 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.0"
__version__ = "0.4.1"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
Expand Down
1 change: 0 additions & 1 deletion audiotools/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import registry
from . import tricks
from .accelerator import Accelerator
from .experiment import Experiment
Expand Down
10 changes: 6 additions & 4 deletions audiotools/ml/experiment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
Experiment tracking.
"""
import datetime
import os
import shlex
import shutil
import socket
import subprocess
from datetime import datetime
from pathlib import Path

import randomname


class Experiment:
def __init__(
Expand Down Expand Up @@ -57,8 +58,9 @@ def __exit__(self, exc_type, exc_value, traceback):

@staticmethod
def generate_exp_name():
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
return current_time + "_" + socket.gethostname()
date = datetime.datetime.now().strftime("%y%m%d")
name = f"{date}-{randomname.get_name()}"
return name

def snapshot(self, filter_fn=lambda f: True):
"""Captures a full snapshot of all the files tracked by git at the time
Expand Down
42 changes: 42 additions & 0 deletions audiotools/ml/layers/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import shutil
import tempfile
from pathlib import Path

import torch
from torch import nn
Expand Down Expand Up @@ -119,3 +120,44 @@ def _load_package(cls, path, package_name=None):
model.importer = imp

return model

def save_to_folder(
self,
folder: str,
extra_data: dict = None,
):
extra_data = {} if extra_data is None else extra_data
model_name = type(self).__name__.lower()
target_base = Path(f"{folder}/{model_name}/")
target_base.mkdir(exist_ok=True, parents=True)

package_path = target_base / f"package.pth"
weights_path = target_base / f"weights.pth"

self.save(package_path)
self.save(weights_path, package=False)

for path, obj in extra_data.items():
torch.save(obj, target_base / path)

return target_base

@classmethod
def load_from_folder(
cls,
folder: Path,
package: bool = True,
strict: bool = False,
):
folder = Path(folder) / cls.__name__.lower()
model_pth = "package.pth" if package else "weights.pth"
model_pth = folder / model_pth

model = cls.load(model_pth, strict=strict)
extra_data = {}
excluded = ["package.pth", "weights.pth"]
files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
for f in files:
extra_data[f.name] = torch.load(folder / f)

return model, extra_data
137 changes: 0 additions & 137 deletions audiotools/ml/registry.py

This file was deleted.

3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.4.0",
version="0.4.1",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down Expand Up @@ -53,6 +53,7 @@
"flatten-dict",
"markdown2",
"pytorch-ignite",
"randomname",
# Have to freeze protobuf version, https://github.com/protocolbuffers/protobuf/issues/10051
# Borrowing pin from tensorboard source: https://github.com/tensorflow/tensorboard/commit/fd4f5ff79374252e313c2e7e9b247bc49ab0d54d.
"protobuf >= 3.9.2, < 3.20",
Expand Down
4 changes: 4 additions & 0 deletions tests/ml/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ def test_base_model():
assert torch.allclose(out1, out2)

assert torch.allclose(out1, out3)

with tempfile.TemporaryDirectory() as d:
model1.save_to_folder(d, {"data": 1.0})
Model.load_from_folder(d)
69 changes: 0 additions & 69 deletions tests/ml/test_registry.py

This file was deleted.

0 comments on commit 97585ed

Please sign in to comment.