Skip to content

Commit

Permalink
fix: makes utility functions more robust against missing data
Browse files Browse the repository at this point in the history
  • Loading branch information
johentsch committed Jan 16, 2024
1 parent 9d5ff19 commit f408de3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 38 deletions.
60 changes: 29 additions & 31 deletions src/ms3/transformations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for transforming DataFrames as output by ms3."""
import logging
import sys
import warnings
from fractions import Fraction as frac
Expand Down Expand Up @@ -369,7 +370,13 @@ def add_weighted_grace_durations(notes, weight=1 / 2, logger=None):
return notes


def compute_chord_tones(df, bass_only=False, expand=False, cols={}, logger=None):
def compute_chord_tones(
df: pd.DataFrame,
bass_only: bool = False,
expand: bool = False,
cols: Optional[dict] = None,
logger: Optional[logging.Logger] = None,
) -> pd.DataFrame | pd.Series:
"""
Compute the chord tones for DCML harmony labels. They are returned as lists
of tonal pitch classes in close position, starting with the bass note. The
Expand All @@ -386,35 +393,18 @@ def compute_chord_tones(df, bass_only=False, expand=False, cols={}, logger=None)
Uses: :py:func:`features2tpcs`
Parameters
----------
df : :obj:`pandas.DataFrame`
Dataframe containing DCML chord labels that have been split by split_labels()
and where the keys have been propagated using propagate_keys(add_bool=True).
bass_only : :obj:`bool`, optional
Pass True if you need only the bass note.
expand : :obj:`bool`, optional
Pass True if you need chord tones and added tones in separate columns.
cols : :obj:`dict`, optional
In case the column names for ``['mc', 'numeral', 'form', 'figbass', 'changes', 'relativeroot', 'localkey',
'globalkey']`` deviate, pass a dict, such as
.. code-block:: python
{'mc': 'mc',
'numeral': 'numeral_col_name',
'form': 'form_col_name',
'figbass': 'figbass_col_name',
'changes': 'changes_col_name',
'relativeroot': 'relativeroot_col_name',
'localkey': 'localkey_col_name',
'globalkey': 'globalkey_col_name'}
You may also deactivate columns by setting them to None, e.g. {'changes': None}
Args:
df:
Dataframe containing DCML chord labels that have been split by :func:`split_labels`
and where the keys have been propagated using propagate_keys(add_bool=True).
bass_only: Pass True if you need only the bass note.
expand: Pass True if you need chord tones and added tones in separate columns. Otherwise a Series is returned.
cols:
In case the column names for ``['mc', 'numeral', 'form', 'figbass', 'changes', 'relativeroot', 'localkey',
'globalkey']`` deviate, pass a dict, such as
logger:
Returns
-------
:obj:`pandas.Series` or :obj:`pandas.DataFrame`
Returns:
For every row of `df` one tuple with chord tones, expressed as tonal pitch classes.
If `expand` is True, the function returns a DataFrame with four columns:
Two with tuples for chord tones and added tones, one with the chord root,
Expand All @@ -424,7 +414,8 @@ def compute_chord_tones(df, bass_only=False, expand=False, cols={}, logger=None)
logger = module_logger
elif isinstance(logger, str):
logger = get_logger(logger)

if cols is None:
cols = {}
df = df.copy()
# If the index is not unique, it has to be temporarily replaced
tmp_index = not df.index.is_unique
Expand Down Expand Up @@ -1898,7 +1889,14 @@ def transpose_chord_tones_by_localkey(df, by_global=False):
:obj:`pandas.DataFrame`
"""
df = df.copy()
ct_cols = ["chord_tones", "added_tones", "root", "bass_note"]
ct_cols = [
col
for col in ("chord_tones", "added_tones", "root", "bass_note")
if col in df.columns
]
assert (
ct_cols
), "Found none of the expected chord-tone columns {'chord_tones', 'added_tones', 'root', 'bass_note'}."
ct = df[ct_cols]
transpose_by = transform(
df, roman_numeral2fifths, ["localkey", "globalkey_is_minor"]
Expand Down
17 changes: 10 additions & 7 deletions src/ms3/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3534,7 +3534,7 @@ def name2pc(nn, logger=None):

def nan_eq(a, b):
    """Returns True where a and b are equal or both null.

    Works on two Series (element-wise, returning a boolean Series) or on two
    single elements (returning a single truth value).

    Comparisons involving missing values can themselves evaluate to ``<NA>``
    (e.g. for nullable dtypes or the scalar ``pd.NA``). Such undecidable
    comparisons are treated as "not equal", so that the result never contains
    missing values and can safely be negated or used as a mask.

    Args:
        a: First Series or element.
        b: Second Series or element.

    Returns:
        Boolean Series (for Series input) or a single truth value (for elements).
    """
    both_null = pd.isnull(a) & pd.isnull(b)
    equal = a == b
    if hasattr(equal, "fillna"):
        # pandas objects: the comparison of nullable dtypes yields <NA>, which would
        # otherwise survive the OR whenever only one of the two values is missing
        equal = equal.fillna(False)
    elif equal is pd.NA:
        # scalar comparison involving pd.NA; plain bools have no .fillna()
        equal = False
    return equal | both_null


def next2sequence(next_col: pd.Series, logger=None) -> Optional[List[int]]:
Expand Down Expand Up @@ -4299,7 +4299,9 @@ def split_note_name(nn, count=False, logger=None):
return accidentals, note_name


def split_scale_degree(sd, count=False, logger=None):
def split_scale_degree(
sd, count=False, logger=None
) -> Tuple[Optional[int], Optional[str]]:
"""Splits a scale degree such as 'bbVI' or 'b6' into accidentals and numeral.
sd : :obj:`str`
Expand Down Expand Up @@ -4506,8 +4508,6 @@ def adjacency_groups(
if s.isna().any():
if na_values == "group":
shifted = s.shift()
if pd.isnull(S.iloc[0]):
shifted.iloc[0] = True
beginnings = ~nan_eq(s, shifted)
else:
logger.warning(
Expand All @@ -4516,11 +4516,10 @@ def adjacency_groups(
)
s = s.dropna()
beginnings = (s != s.shift()).fillna(False)
beginnings.iloc[0] = True
reindex_flag = True
else:
beginnings = s != s.shift()
beginnings.iloc[0] = True
beginnings.iat[0] = True
if prevent_merge:
beginnings |= forced_beginnings
groups = beginnings.cumsum()
Expand Down Expand Up @@ -5064,7 +5063,9 @@ def abs2rel_key(
return acc + result_numeral


def rel2abs_key(relative: str, localkey: str, global_minor: bool = False, logger=None):
def rel2abs_key(
relative: str, localkey: str, global_minor: bool = False, logger=None
) -> Optional[str]:
"""Expresses a Roman numeral that is expressed relative to a localkey
as scale degree of the global key. For local keys {III, iii, VI, vi, VII, vii}
the result changes depending on whether the global key is major or minor.
Expand Down Expand Up @@ -5121,6 +5122,8 @@ def rel2abs_key(relative: str, localkey: str, global_minor: bool = False, logger
localkey_accidentals, localkey = split_scale_degree(
localkey, count=True, logger=logger
)
if relative is None or localkey is None:
return
resulting_accidentals = relative_accidentals + localkey_accidentals
numerals = maj_rn if relative.isupper() else min_rn
rel_num = numerals.index(relative)
Expand Down

0 comments on commit f408de3

Please sign in to comment.