diff --git a/audiotools/__init__.py b/audiotools/__init__.py index 62fcb9cd..1df5af1e 100644 --- a/audiotools/__init__.py +++ b/audiotools/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" from .core import AudioSignal, STFTParams, Meter, util from . import metrics from . import data diff --git a/audiotools/ml/__init__.py b/audiotools/ml/__init__.py index f5b74a7c..dda11efe 100644 --- a/audiotools/ml/__init__.py +++ b/audiotools/ml/__init__.py @@ -1,4 +1,3 @@ -from . import registry from . import tricks from .accelerator import Accelerator from .experiment import Experiment diff --git a/audiotools/ml/experiment.py b/audiotools/ml/experiment.py index 7aaf7d0a..eaed9929 100644 --- a/audiotools/ml/experiment.py +++ b/audiotools/ml/experiment.py @@ -1,14 +1,15 @@ """ Experiment tracking. """ +import datetime import os import shlex import shutil -import socket import subprocess -from datetime import datetime from pathlib import Path +import randomname + class Experiment: def __init__( @@ -57,8 +58,9 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def generate_exp_name(): - current_time = datetime.now().strftime("%b%d_%H-%M-%S") - return current_time + "_" + socket.gethostname() + date = datetime.datetime.now().strftime("%y%m%d") + name = f"{date}-{randomname.get_name()}" + return name def snapshot(self, filter_fn=lambda f: True): """Captures a full snapshot of all the files tracked by git at the time diff --git a/audiotools/ml/layers/base.py b/audiotools/ml/layers/base.py index ec949876..196ea4f6 100644 --- a/audiotools/ml/layers/base.py +++ b/audiotools/ml/layers/base.py @@ -1,6 +1,7 @@ import inspect import shutil import tempfile +from pathlib import Path import torch from torch import nn @@ -119,3 +120,44 @@ def _load_package(cls, path, package_name=None): model.importer = imp return model + + def save_to_folder( + self, + folder: str, + extra_data: dict = None, + ): + extra_data = {} if extra_data is None else extra_data + model_name = type(self).__name__.lower() + target_base = Path(f"{folder}/{model_name}/") + target_base.mkdir(exist_ok=True, parents=True) + + package_path = target_base / f"package.pth" + weights_path = target_base / f"weights.pth" + + self.save(package_path) + self.save(weights_path, package=False) + + for path, obj in extra_data.items(): + torch.save(obj, target_base / path) + + return target_base + + @classmethod + def load_from_folder( + cls, + folder: Path, + package: bool = True, + strict: bool = False, + ): + folder = Path(folder) / cls.__name__.lower() + model_pth = "package.pth" if package else "weights.pth" + model_pth = folder / model_pth + + model = cls.load(model_pth, strict=strict) + extra_data = {} + excluded = ["package.pth", "weights.pth"] + files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded] + for f in files: + extra_data[f.name] = torch.load(folder / f) + + return model, extra_data diff --git a/audiotools/ml/registry.py b/audiotools/ml/registry.py deleted file mode 100644 index 4df61aa7..00000000 --- a/audiotools/ml/registry.py +++ /dev/null @@ -1,137 +0,0 @@ -# Registering models derived from BaseModel in a -# local filestore or in a GCP bucket -import datetime -import glob -import shlex -import subprocess -import tempfile -from pathlib import Path -from typing import Type - -import rich -from flatten_dict import unflatten -from rich.text import Text -from rich.tree import Tree - -from .layers.base import BaseModel - - -def convert_to_tree(d, tree: Tree): - for k in d: - if not isinstance(d[k], dict): - prefix = "✅ " if d[k] else "❌ " - style = "green" if d[k] else "red" - tree.add(Text(prefix + k, style=style)) - else: - convert_to_tree(d[k], tree.add(k)) - return tree - - -class BaseModelRegistry: - def __init__( - self, - location: str, - cache: str = None, - ): - cache = location if cache is None else cache - self.location = str(location) - self.cache = Path(cache) - - def copy(self, src, dst): # pragma: no cover - raise NotImplementedError() - - def upload_model( - self, - model: Type[BaseModel], - domain: str, - version: str = None, - ): - model_name = type(model).__name__.lower() - if version is None: - version = datetime.datetime.now().strftime("%Y%m%d") - target_base = f"{self.location}/{domain}/{version}/{model_name}/" - Path(target_base).mkdir(exist_ok=True, parents=True) - - with tempfile.TemporaryDirectory() as tmpdir: - package_path = Path(tmpdir) / f"package.pth" - weights_path = Path(tmpdir) / f"weights.pth" - - model.save(package_path) - model.save(weights_path, package=False) - - self.copy(package_path, target_base + "package.pth") - self.copy(weights_path, target_base + "weights.pth") - - return target_base - - def upload( - self, - local_path: str, - domain: str, - path: str, - ): - remote_path = f"{self.location}/{domain}/{path}" - self.copy(local_path, remote_path) - - def download( - self, - domain: str, - path: str, - overwrite: bool = False, - ): - # Check if model exists locally. - local_path = self.cache / domain / path - remote_path = f"{self.location}/{domain}/{path}" - local_path.parent.mkdir(exist_ok=True, parents=True) - - if not local_path.exists() or overwrite: - self.copy(remote_path, local_path) - - return local_path - - def get_files( - self, - domain: str, - ): # pragma: no cover - raise NotImplementedError() - - def list_models(self, domain: str): - files = self.get_files(domain) - - def exists(f): - local_path = self.cache / str(f).split(self.location)[-1] - return local_path.exists() - - _files = unflatten({str(f): exists(f) for f in files}, splitter="path") - tree = convert_to_tree(_files, Tree(self.location)) - rich.print(tree) - return [str(f).split(domain + "/")[-1] for f in files] - - -class LocalModelRegistry(BaseModelRegistry): - def copy(self, src, dst): - Path(dst).parent.mkdir(exist_ok=True, parents=True) - command = f"cp -r {str(src)} {str(dst)}" - subprocess.check_call(shlex.split(command)) - - def get_files(self, domain: str): - base_path = f"{self.location}/{domain}" - files = glob.glob(f"{base_path}/**", recursive=True) - files = [Path(f).relative_to(self.location) for f in files if Path(f).is_file()] - return files - - -class GCPModelRegistry(BaseModelRegistry): # pragma: no cover - def copy(self, src, dst): - command = f"gsutil -m cp -r {str(src)} {str(dst)}" - subprocess.check_call(shlex.split(command)) - - def get_files(self, domain: str): - base_path = f"{self.location}/{domain}" - command = f"gsutil ls {base_path}/**" - - files = ( - subprocess.check_output(shlex.split(command)).decode("utf-8").splitlines() - ) - files = [Path(f).relative_to(self.location) for f in files] - return files diff --git a/setup.py b/setup.py index 495d0b35..bb410534 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="audiotools", - version="0.4.0", + version="0.4.1", classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education", @@ -53,6 +53,7 @@ "flatten-dict", "markdown2", "pytorch-ignite", + "randomname", # Have to freeze protobuf version, https://github.com/protocolbuffers/protobuf/issues/10051 # Borrowing pin from tensorboard source: https://github.com/tensorflow/tensorboard/commit/fd4f5ff79374252e313c2e7e9b247bc49ab0d54d. "protobuf >= 3.9.2, < 3.20", diff --git a/tests/ml/test_model.py b/tests/ml/test_model.py index a0cfc29e..fbb08ccc 100644 --- a/tests/ml/test_model.py +++ b/tests/ml/test_model.py @@ -79,3 +79,7 @@ def test_base_model(): assert torch.allclose(out1, out2) assert torch.allclose(out1, out3) + + with tempfile.TemporaryDirectory() as d: + model1.save_to_folder(d, {"data": 1.0}) + Model.load_from_folder(d) diff --git a/tests/ml/test_registry.py b/tests/ml/test_registry.py deleted file mode 100644 index 18d64876..00000000 --- a/tests/ml/test_registry.py +++ /dev/null @@ -1,69 +0,0 @@ -import datetime -import json -import tempfile -from pathlib import Path - -from torch import nn - -from audiotools import ml - - -class Generator(ml.BaseModel): - def __init__(self): - super().__init__() - self.linear = nn.Linear(1, 1) - - def forward(self, x): - return self.linear(x) - - -class Discriminator(ml.BaseModel): - def __init__(self): - super().__init__() - self.linear = nn.Linear(1, 1) - - def forward(self, x): - return self.linear(x) - - -def test_local_registry(): - ml.BaseModel.EXTERN += ["test_registry"] - - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - registry = ml.registry.LocalModelRegistry(tmpdir / "remote", tmpdir / "cache") - - generator = Generator() - discriminator = Discriminator() - - version = datetime.datetime.now().strftime("%Y%m%d") - gen_path = registry.upload_model(generator, "domain") - disc_path = registry.upload_model(discriminator, "domain") - - assert version in gen_path - assert version in disc_path - - version = "test" - gen_path = registry.upload_model(generator, "domain", version=version) - disc_path = registry.upload_model(discriminator, "domain", version=version) - - assert version in gen_path - assert version in disc_path - - models = registry.list_models("domain") - for model in models: - registry.download("domain", model) - registry.list_models("domain") - - with open(tmpdir / "metadata.json", "w") as f: - d = {"test": "test"} - json.dump(d, f) - registry.upload( - tmpdir / "metadata.json", "domain", f"{version}/metadata/metadata_a.json" - ) - registry.upload( - tmpdir / "metadata.json", "domain", f"{version}/metadata/metadata_b.json" - ) - registry.list_models("domain") - registry.download("domain", f"{version}/metadata") - registry.list_models("domain")