Skip to content

Commit

Permalink
update pyproject.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinedaurat committed Jan 10, 2025
1 parent eb61251 commit 7b4950d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 14 deletions.
4 changes: 1 addition & 3 deletions mimikit/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,14 @@ def save(cls,
net_dict = network.state_dict()
opt_dict = optimizer.state_dict() if optimizer is not None else {}
cls.network.set_ds_kwargs(net_dict)
# if optimizer is not None:
# cls.optimizer.set_ds_kwargs(opt_dict)
os.makedirs(os.path.split(filename)[0], exist_ok=True)

bank = cls(filename, mode="w")
bank.network.attrs["config"] = network.config.serialize()
bank.network.add("state_dict", h5m.TensorDict.format(net_dict))

if optimizer is not None:
# bank.optimizer.add("state_dict", h5m.TensorDict.format(opt_dict))
# opt is saved in a separate file
torch.save(opt_dict, os.path.splitext(filename)[0] + ".opt")

if training_config is not None:
Expand Down
1 change: 0 additions & 1 deletion mimikit/loops/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class MMKCheckpoint(Callback):
def __init__(self,
epochs=None,
root_dir=''
# todo: save_optimizer
):
super().__init__()
self.epochs = epochs
Expand Down
3 changes: 1 addition & 2 deletions mimikit/views/clusterizer_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,5 @@ def on_reset_selected_labels(ev):
W.HBox(children=(reset_label_button, bounce),
layout=dict(margin="8px auto",
)),
self.labels_grid,
W.HTML("<h4>Selected Labels Segments Table: </h4>")
self.labels_grid
)
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,18 @@ classifiers = [
]
dependencies = [
"ffmpeg-python",
"h5mapper>=0.3.1",
"ipywidgets==7.7.1",
"librosa>=0.9.1",
"h5mapper>=0.3.3",
"ipywidgets>=8.1.0",
"librosa>=0.10.0,<1",
"matplotlib",
"numba",
"numpy>=1.19.1",
"numpy>=1.19.1,<2",
"omegaconf>=2.3.0",
"pandas>=1.1.3",
"peaksjs_widget",
"peaksjs_widget>=0.2.1",
"pyamg",
"pydub",
"pypbind",
"qgrid",
"scikit-learn>=1.0.0",
"scipy>=1.4.1",
"soundfile>=0.10.2",
Expand All @@ -60,8 +59,8 @@ dynamic = [

[project.optional-dependencies]
colab = [
"torchaudio==2.0.1+cu118",
"pytorch-lightning>=2.0.2",
"torchaudio==2.2.2",
"pytorch-lightning>=2.5.0",
]
torch = [
"torch==2.2.2",
Expand Down

0 comments on commit 7b4950d

Please sign in to comment.