From bd97acd5d0373a8b120d0e0ae666145645171561 Mon Sep 17 00:00:00 2001 From: areeves87 Date: Mon, 9 Jul 2018 10:09:23 -0700 Subject: [PATCH] Enforce strictly increasing values in adjust_times Resolves #126. * enforce strictly increasing values in old_time and new_time * update monotonicity check; move warnings out of fxn * strictly increase original_times, monotonic incr new_times * remove trailing whitespace for pep8 * update new_times indices if original_times fails check * use np.unique to check and enforce original_time strictly increasing * fix pep8 formatting issues. * fix pep8 formatting issues. * fix pep8 formatting issues. * fix pep8 formatting issues. * fix pep8 formatting issues. * convert new_times to array before subsetting with unique_idx * so that we test adjust_times with appropriate event times * so that test_adjust_times() tests the correct values * update tests for adjust_times * comment out some tests in adjust_times * PrettyMIDI.adjust_times enforces strict increase in original_times and monotonic increase in new_times * prepare to squash commits * passes all tests * passes all tests and pep8 * tolerate floating point error in time signature calculation * maintain consistent decimal usage --- pretty_midi/pretty_midi.py | 14 ++++++++++++++ tests/test_pretty_midi.py | 22 +++++++++++++++------- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pretty_midi/pretty_midi.py b/pretty_midi/pretty_midi.py index bf21759..fba8d99 100644 --- a/pretty_midi/pretty_midi.py +++ b/pretty_midi/pretty_midi.py @@ -1019,6 +1019,20 @@ def adjust_times(self, original_times, new_times): # Get original downbeat locations (we will use them to determine where # to put the first time signature change) original_downbeats = self.get_downbeats() + # Force strict increase in original_times and monotonic in new_times. + # While enforcing, give warning. + original_size = len(original_times) + original_times, unique_idx = np.unique(original_times, + return_index=True) + if ((unique_idx.size != original_size) or + any(unique_idx != np.arange(unique_idx.size))): + warnings.warn('original_times must be strictly increasing; ' + 'automatically enforcing this.') + new_times = np.asarray(new_times)[unique_idx] + if not np.all(np.diff(new_times) >= 0): + warnings.warn('new_times must be monotonic; ' + 'automatically enforcing this.') + new_times = np.maximum.accumulate(new_times) # Only include notes within start/end time of the provided times for instrument in self.instruments: instrument.notes = [copy.deepcopy(note) diff --git a/tests/test_pretty_midi.py b/tests/test_pretty_midi.py index 36ebc72..64fdaa7 100644 --- a/tests/test_pretty_midi.py +++ b/tests/test_pretty_midi.py @@ -162,9 +162,9 @@ def simple(): assert np.allclose( [n.start for n in pm.instruments[0].notes], expected_starts) pm = simple() - pm.adjust_times([0, 5, 5, 10], [5, 10, 12, 17]) + pm.adjust_times([0, 5, 5, 10], [7, 12, 13, 17]) # Original times [1, 2, 3, 4, 5, 6, 7, 8, 9] - expected_starts = [6, 7, 8, 9, 12, 13, 14, 15, 16] + expected_starts = [8, 9, 10, 11, 12, 13, 14, 15, 16] assert np.allclose( [n.start for n in pm.instruments[0].notes], expected_starts) @@ -233,11 +233,19 @@ def simple(): assert np.allclose(expected_tempi, tempi, rtol=.002) # Test that all other events were interpolated as expected - note_starts = [5., 5 + 1/1.1, 7 + .9/(2/1.5), 7 + 1.9/(2/1.5), 8.5 + .5, + note_starts = [5.0, + 5 + 1/1.1, + 6 + .9/(2/2.5), + 6 + 1.9/(2/2.5), + 8.5 + .5, 8.5 + 1.5] - note_ends = [5 + .5/1.1, 7 + .4/(2/1.5), 7 + 1.4/(2/1.5), 8.5, 9 + .5, + note_ends = [5 + .5/1.1, + 6 + .4/(2/2.5), + 6 + 1.4/(2/2.5), + 8.5, + 8.5 + 1., 10 + .5] - note_pitches = [101, 102, 103, 104, 107, 108, 109] + note_pitches = [101, 102, 103, 104, 107, 108] for note, s, e, p in zip(pm.instruments[0].notes, note_starts, note_ends, note_pitches): assert note.start == s @@ -262,11 +270,11 @@ def simple(): # downbeat location - so, start by computing the location of the first # downbeat after the start of original_times, then interpolate it first_downbeat_after = .1 + 2*3*60./100. - first_ts_time = 7 + (first_downbeat_after - 3.1)/(2/1.5) + first_ts_time = 6. + (first_downbeat_after - 3.1)/(2./2.5) ts_times = [first_ts_time, 8.5, 8.5] ts_numerators = [3, 4, 6] for ts, t, n in zip(pm.time_signature_changes, ts_times, ts_numerators): - assert ts.time == t + assert np.isclose(ts.time, t) assert ts.numerator == n ks_times = [5., 8.5, 8.5]