From b3398e47604e3d4620ed252a858fcbceff5f0be7 Mon Sep 17 00:00:00 2001 From: Weiqi Gao Date: Fri, 12 Jan 2024 18:57:08 +0800 Subject: [PATCH] Extend torchaudio support to 2.1.x (#3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update README.md * Update README.md * Update README.md * Update README.md * minor fixes for 4.1.0a1 (#552) * minor fixes for 4.1.0a1 print out the exception when calling callback ensures all threads can be stopped when interrupting separation add release data for 4.0.1 * Fix model_idx_in_bag always zero * fix linter * Fix can't separate empty audio * Calls callback when skipping empty audio * Add description for aborting * Does not ignore callback exception * Fix linter * Does not ignore exception * Disable torchaudio 2.2+ * Uses epsilon to deal with empty audio * Reraises exception in callback * Ensure the pool stops when encountering exception * Update windows.md for latest instructions * Minor documentation updates (#565) * Minor documentation updates * Update readme * Update api.md * Fix segment defined in bag can't override model * merge from adefossez/demucs * Update README.md * Extend torchaudio support to 2.1.x * Use correct import statement * Calculate FFT on CPU also when device is XPU (Intel GPU) --------- Co-authored-by: Alexandre Défossez Co-authored-by: William Dye --- demucs/api.py | 1 + demucs/audio.py | 1 + demucs/audio_legacy.py | 17 +++++++++++++++++ demucs/hdemucs.py | 10 ++++++---- demucs/htdemucs.py | 9 +++++---- demucs/repitch.py | 1 + demucs/spec.py | 8 ++++---- demucs/train.py | 1 + demucs/wav.py | 1 + requirements.txt | 2 +- requirements_minimal.txt | 2 +- 11 files changed, 39 insertions(+), 14 deletions(-) create mode 100644 demucs/audio_legacy.py diff --git a/demucs/api.py b/demucs/api.py index 20079a6b..ee8a5126 100644 --- a/demucs/api.py +++ b/demucs/api.py @@ -22,6 +22,7 @@ import subprocess +from . import audio_legacy import torch as th import torchaudio as ta diff --git a/demucs/audio.py b/demucs/audio.py index 31b29b3c..600bd55b 100644 --- a/demucs/audio.py +++ b/demucs/audio.py @@ -10,6 +10,7 @@ import lameenc import julius import numpy as np +from . import audio_legacy import torch import torchaudio as ta import typing as tp diff --git a/demucs/audio_legacy.py b/demucs/audio_legacy.py new file mode 100644 index 00000000..ab6bdce4 --- /dev/null +++ b/demucs/audio_legacy.py @@ -0,0 +1,17 @@ +# This file is to extend support for torchaudio 2.1 + +import importlib +import os +import sys +import warnings + +if not "torchaudio" in sys.modules: + os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0" +elif os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1": + if sys.modules["torchaudio"].__version__ >= "2.1": + os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0" + importlib.reload(sys.modules["torchaudio"]) + warnings.warn( + "TORCHAUDIO_USE_BACKEND_DISPATCHER is set to 0 and torchaudio is reloaded.", + ImportWarning, + ) diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py index 711d4715..9992b60a 100644 --- a/demucs/hdemucs.py +++ b/demucs/hdemucs.py @@ -776,16 +776,18 @@ def forward(self, mix): # demucs issue #435 ##432 # NOTE: in this case z already is on cpu # TODO: remove this when mps supports complex numbers - x_is_mps = x.device.type == "mps" - if x_is_mps: + x_is_mps_xpu = x.device.type in ["mps", "xpu"] + x_device = x.device + if x_is_mps_xpu: x = x.cpu() zout = self._mask(z, x) x = self._ispec(zout, length) # back to mps device - if x_is_mps: - x = x.to('mps') + if x_is_mps_xpu: + x = x.to(x_device) + if self.hybrid: xt = xt.view(B, S, -1, length) diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py index 5d2eaaa1..56568608 100644 --- a/demucs/htdemucs.py +++ b/demucs/htdemucs.py @@ -629,8 +629,9 @@ def forward(self, mix): # demucs issue #435 ##432 # NOTE: in this case z already is on cpu # TODO: remove this when mps supports complex numbers - x_is_mps = x.device.type == "mps" - if x_is_mps: + x_is_mps_xpu = x.device.type in ["mps", "xpu"] + x_device = x.device + if x_is_mps_xpu: x = x.cpu() zout = self._mask(z, x) @@ -643,8 +644,8 @@ def forward(self, mix): x = self._ispec(zout, length) # back to mps device - if x_is_mps: - x = x.to("mps") + if x_is_mps_xpu: + x = x.to(x_device) if self.use_train_segment: if self.training: diff --git a/demucs/repitch.py b/demucs/repitch.py index ebef7364..b69c0d25 100644 --- a/demucs/repitch.py +++ b/demucs/repitch.py @@ -9,6 +9,7 @@ import subprocess as sp import tempfile +from . import audio_legacy import torch import torchaudio as ta diff --git a/demucs/spec.py b/demucs/spec.py index 29250459..d8f6ee5e 100644 --- a/demucs/spec.py +++ b/demucs/spec.py @@ -11,8 +11,8 @@ def spectro(x, n_fft=512, hop_length=None, pad=0): *other, length = x.shape x = x.reshape(-1, length) - is_mps = x.device.type == 'mps' - if is_mps: + is_mps_xpu = x.device.type in ['mps', 'xpu'] + if is_mps_xpu: x = x.cpu() z = th.stft(x, n_fft * (1 + pad), @@ -32,8 +32,8 @@ def ispectro(z, hop_length=None, length=None, pad=0): n_fft = 2 * freqs - 2 z = z.view(-1, freqs, frames) win_length = n_fft // (1 + pad) - is_mps = z.device.type == 'mps' - if is_mps: + is_mps_xpu = z.device.type in ['mps', 'xpu'] + if is_mps_xpu: z = z.cpu() x = th.istft(z, n_fft, diff --git a/demucs/train.py b/demucs/train.py index 9aa7b64b..e045b83f 100644 --- a/demucs/train.py +++ b/demucs/train.py @@ -15,6 +15,7 @@ import hydra from hydra.core.global_hydra import GlobalHydra from omegaconf import OmegaConf +from . import audio_legacy import torch from torch import nn import torchaudio diff --git a/demucs/wav.py b/demucs/wav.py index 6acb9b5d..ca1e23a3 100644 --- a/demucs/wav.py +++ b/demucs/wav.py @@ -15,6 +15,7 @@ import musdb import julius +from . import audio_legacy import torch as th from torch import distributed import torchaudio as ta diff --git a/requirements.txt b/requirements.txt index 294290d3..d4832a2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ openunmix pyyaml submitit torch>=1.8.1 -torchaudio>=0.8,<2.1 +torchaudio>=0.8,<2.2 tqdm treetable soundfile>=0.10.3;sys_platform=="win32" diff --git a/requirements_minimal.txt b/requirements_minimal.txt index 1940bf01..dcae84bc 100644 --- a/requirements_minimal.txt +++ b/requirements_minimal.txt @@ -6,5 +6,5 @@ lameenc>=1.2 openunmix pyyaml torch>=1.8.1 -torchaudio>=0.8,<2.1 +torchaudio>=0.8,<2.2 tqdm