Skip to content

Commit

Permalink
Improve Auto Interpretation Performance (#814)
Browse files Browse the repository at this point in the history
* cythonize merge_plateaus

* add a threshold for plateau count to prevent running forever
  • Loading branch information
jopohl authored Nov 4, 2020
1 parent 217a814 commit d300ba5
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 41 deletions.
25 changes: 2 additions & 23 deletions src/urh/ainterpretation/AutoInterpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,28 +280,7 @@ def merge_plateau_lengths(plateau_lengths, tolerance=None) -> list:
if tolerance == 0 or tolerance is None:
return plateau_lengths

result = []
if len(plateau_lengths) == 0:
return result

if plateau_lengths[0] <= tolerance:
result.append(0)

i = 0
while i < len(plateau_lengths):
if plateau_lengths[i] <= tolerance:
# Look forward to see if we need to merge a larger window e.g. for 67, 1, 10, 1, 21
n = 2
while i + n < len(plateau_lengths) and plateau_lengths[i + n] <= tolerance:
n += 2

result[-1] = sum(plateau_lengths[max(i - 1, 0):i + n])
i += n
else:
result.append(plateau_lengths[i])
i += 1

return result
return c_auto_interpretation.merge_plateaus(plateau_lengths, tolerance, max_count=10000)


def round_plateau_lengths(plateau_lengths: list):
Expand Down Expand Up @@ -343,8 +322,8 @@ def get_bit_length_from_plateau_lengths(merged_plateau_lengths) -> int:
return int(merged_plateau_lengths[0])

round_plateau_lengths(merged_plateau_lengths)
histogram = c_auto_interpretation.get_threshold_divisor_histogram(merged_plateau_lengths)

histogram = c_auto_interpretation.get_threshold_divisor_histogram(np.array(merged_plateau_lengths, dtype=np.uint64))
if len(histogram) == 0:
return 0
else:
Expand Down
49 changes: 41 additions & 8 deletions src/urh/cythonext/auto_interpretation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ from cpython cimport array
import array
import cython

from cython.parallel import prange
from libc.stdlib cimport malloc, free
from libcpp.algorithm cimport sort
from libc.stdint cimport uint64_t

cpdef tuple k_means(float[:] data, unsigned int k=2):
cdef float[:] centers = np.empty(k, dtype=np.float32)
cdef list clusters = []
Expand Down Expand Up @@ -105,7 +110,7 @@ def segment_messages_from_magnitudes(cython.floating[:] magnitudes, float noise_

return result

cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:] plateau_lengths, float threshold=0.2):
cpdef uint64_t[:] get_threshold_divisor_histogram(uint64_t[:] plateau_lengths, float threshold=0.2):
"""
Get a histogram (i.e. count) how many times a value is a threshold divisor for other values in given data
Expand All @@ -114,12 +119,10 @@ cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:
:param plateau_lengths:
:return:
"""
cdef unsigned long long num_lengths = len(plateau_lengths)
cdef uint64_t i, j, x, y, minimum, maximum, num_lengths = len(plateau_lengths)

cdef np.ndarray[np.uint64_t, ndim=1] histogram = np.zeros(int(np.max(plateau_lengths)) + 1, dtype=np.uint64)

cdef unsigned long long i, j, x, y, minimum, maximum

for i in range(0, num_lengths):
for j in range(i+1, num_lengths):
x = plateau_lengths[i]
Expand All @@ -139,6 +142,40 @@ cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:

return histogram

cpdef np.ndarray[np.uint64_t, ndim=1] merge_plateaus(np.ndarray[np.uint64_t, ndim=1] plateaus,
uint64_t tolerance,
uint64_t max_count):
cdef uint64_t j, n, L = len(plateaus), current = 0, i = 1, tmp_sum
if L == 0:
return np.zeros(0, dtype=np.uint64)

cdef np.ndarray[np.uint64_t, ndim=1] result = np.empty(L, dtype=np.uint64)
if plateaus[0] <= tolerance:
result[0] = 0
else:
result[0] = plateaus[0]

while i < L and current < max_count:
if plateaus[i] <= tolerance:
# Look ahead to see whether we need to merge a larger window e.g. for 67, 1, 10, 1, 21
n = 2
while i + n < L and plateaus[i + n] <= tolerance:
n += 2

tmp_sum = 0
for j in range(i - 1, i + n):
tmp_sum += plateaus[j]

result[current] = tmp_sum
i += n
else:
current += 1
result[current] = plateaus[i]
i += 1

return result[:current+1]


cpdef np.ndarray[np.uint64_t, ndim=1] get_plateau_lengths(float[:] rect_data, float center, int percentage=25):
if len(rect_data) == 0 or center is None:
return np.array([], dtype=np.uint64)
Expand Down Expand Up @@ -171,10 +208,6 @@ cpdef np.ndarray[np.uint64_t, ndim=1] get_plateau_lengths(float[:] rect_data, fl
return np.array(result, dtype=np.uint64)


from cython.parallel import prange
from libc.stdlib cimport malloc, free
from libcpp.algorithm cimport sort

cdef float median(double[:] data, unsigned long start, unsigned long data_len, unsigned int k=3) nogil:
cdef unsigned long i, j

Expand Down
22 changes: 12 additions & 10 deletions tests/auto_interpretation/test_bit_length_detection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import unittest
import numpy as np

from urh.ainterpretation import AutoInterpretation


class TestAutoInterpretation(unittest.TestCase):
def __run_merge(self, data):
return list(AutoInterpretation.merge_plateau_lengths(np.array(data, dtype=np.uint64)))

def test_merge_plateau_lengths(self):
self.assertEqual(AutoInterpretation.merge_plateau_lengths([]), [])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([42]), [42])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 100]), [100, 100, 100])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 49, 1, 50, 100]), [100, 100, 100])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 48, 2, 50, 100]), [100, 100, 100])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 67, 1, 10, 1, 21]), [100, 100, 100])
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 67, 1, 10, 1, 21, 100, 50, 1, 49]),
[100, 100, 100, 100, 100])
self.assertEqual(self.__run_merge([100, 49, 1, 50, 100]), [100, 100, 100])
self.assertEqual(self.__run_merge([100, 48, 2, 50, 100]), [100, 100, 100])
self.assertEqual(self.__run_merge([100, 100, 67, 1, 10, 1, 21]), [100, 100, 100])
self.assertEqual(self.__run_merge([100, 100, 67, 1, 10, 1, 21, 100, 50, 1, 49]), [100, 100, 100, 100, 100])

def test_estimate_tolerance_from_plateau_lengths(self):
self.assertEqual(AutoInterpretation.estimate_tolerance_from_plateau_lengths([]), None)
Expand All @@ -34,19 +36,19 @@ def test_tolerant_greatest_common_divisor(self):
def test_get_bit_length_from_plateau_length(self):
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths([]), 0)
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths([42]), 42)
plateau_lengths = [2, 1, 2, 73, 1, 26, 100, 40, 1, 59, 100, 47, 1, 52, 67, 1, 10, 1, 21, 33, 1, 66, 100, 5, 1, 3, 1, 48, 1, 27, 1, 8]
plateau_lengths = np.array([2, 1, 2, 73, 1, 26, 100, 40, 1, 59, 100, 47, 1, 52, 67, 1, 10, 1, 21, 33, 1, 66, 100, 5, 1, 3, 1, 48, 1, 27, 1, 8], dtype=np.uint64)
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 100)


plateau_lengths = [1, 292, 331, 606, 647, 286, 645, 291, 334, 601, 339, 601, 338, 602, 337, 603, 338, 604, 336, 605, 337, 600, 338, 605, 646]
plateau_lengths = np.array([1, 292, 331, 606, 647, 286, 645, 291, 334, 601, 339, 601, 338, 602, 337, 603, 338, 604, 336, 605, 337, 600, 338, 605, 646], dtype=np.uint64)
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 300)

plateau_lengths = [3, 8, 8, 8, 8, 8, 8, 8, 8, 16, 8, 8, 16, 32, 8, 8, 8, 8, 8, 24, 8, 24, 8, 24, 8, 24, 8, 24, 16, 16, 24, 8]
plateau_lengths = np.array([3, 8, 8, 8, 8, 8, 8, 8, 8, 16, 8, 8, 16, 32, 8, 8, 8, 8, 8, 24, 8, 24, 8, 24, 8, 24, 8, 24, 16, 16, 24, 8], dtype=np.uint64)
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 8)

def test_get_bit_length_from_merged_plateau_lengths(self):
merged_lengths = [40, 40, 40, 40, 40, 30, 50, 30, 90, 40, 40, 80, 160, 30, 50, 30]
merged_lengths = np.array([40, 40, 40, 40, 40, 30, 50, 30, 90, 40, 40, 80, 160, 30, 50, 30], dtype=np.uint64)
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 40)

0 comments on commit d300ba5

Please sign in to comment.