Skip to content

Commit

Permalink
Refactor fallback logic into its own function
Browse files Browse the repository at this point in the history
  • Loading branch information
arkrow committed Aug 8, 2024
1 parent 1d4ab1f commit 4ae760b
Showing 1 changed file with 59 additions and 24 deletions.
83 changes: 59 additions & 24 deletions python/pymusiclooper/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,57 @@ def _db_diff(power_db_f1: np.ndarray, power_db_f2: np.ndarray) -> float:
def _norm(a: np.ndarray) -> float:
return np.sqrt(np.sum(np.abs(a) ** 2, axis=0))


@njit(cache=True)
def _find_candidate_pairs_fallback(
chroma: np.ndarray,
power_db: np.ndarray,
beats: np.ndarray,
deviation: np.ndarray,
min_loop_duration: int,
max_loop_duration: int,
acceptable_loudness_difference: float,
) -> List[Tuple[int, int, float, float]]:
"""Python fallback function.
Generates a list of all valid candidate loop pairs using combinations of beat indices,
by comparing the notes using the chroma spectrogram and their loudness difference
Args:
chroma (np.ndarray): The chroma spectrogram
power_db (np.ndarray): The power spectrogram in dB
beats (np.ndarray): The frame indices of detected beats
min_loop_duration (int): Minimum loop duration (in frames)
max_loop_duration (int): Maximum loop duration (in frames)
Returns:
List[Tuple[int, int, float, float]]: A list of tuples containing each candidate loop pair data in the following format (loop_start, loop_end, note_distance, loudness_difference)
"""
candidate_pairs = []

for idx, loop_end in enumerate(beats):
for loop_start in beats:
loop_length = loop_end - loop_start
if loop_length < min_loop_duration:
break
if loop_length > max_loop_duration:
continue
note_distance = _norm(chroma[..., loop_end] - chroma[..., loop_start])

if note_distance <= deviation[idx]:
loudness_difference = _db_diff(
power_db[..., loop_end], power_db[..., loop_start]
)
if loudness_difference <= acceptable_loudness_difference:
loop_pair = (
int(loop_start),
int(loop_end),
note_distance,
loudness_difference,
)
candidate_pairs.append(loop_pair)
return candidate_pairs


def _find_candidate_pairs(
chroma: np.ndarray,
power_db: np.ndarray,
Expand Down Expand Up @@ -313,30 +363,15 @@ def _find_candidate_pairs(
)
except Exception:
# Python fallback
candidate_pairs = []

for idx, loop_end in enumerate(beats):
for loop_start in beats:
loop_length = loop_end - loop_start
if loop_length < min_loop_duration:
break
if loop_length > max_loop_duration:
continue
note_distance = _norm(chroma[..., loop_end] - chroma[..., loop_start])

if note_distance <= deviation[idx]:
loudness_difference = _db_diff(
power_db[..., loop_end], power_db[..., loop_start]
)
loop_pair = (
int(loop_start),
int(loop_end),
note_distance,
loudness_difference,
)
if loudness_difference <= ACCEPTABLE_LOUDNESS_DIFFERENCE:
candidate_pairs.append(loop_pair)
return candidate_pairs
return _find_candidate_pairs_fallback(
chroma,
power_db,
beats,
deviation,
min_loop_duration,
max_loop_duration,
ACCEPTABLE_LOUDNESS_DIFFERENCE,
)


def _assess_and_filter_loop_pairs(
Expand Down

0 comments on commit 4ae760b

Please sign in to comment.