Skip to content

Commit

Permalink
Merge pull request #508 from RocketPy-Team/enh/filter-method
Browse files Browse the repository at this point in the history
ENH: add Function.low_pass_filter method
  • Loading branch information
kalounis authored Dec 17, 2023
2 parents ea13c73 + 7eb4a73 commit 8a7ab31
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
32 changes: 32 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,38 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True):
extrapolation="zero",
)

def low_pass_filter(self, alpha):
"""Implements a low pass filter with a moving average filter
Parameters
----------
alpha : float
Attenuation coefficient, 0 <= alpha <= 1
For a given dataset, the larger alpha is, the more closely the
filtered function returned will match the function the smaller
alpha is, the smoother the filtered function returned will be
(but with a phase shift)
Returns
-------
Function
The function with the incoming source filtered
"""
filtered_signal = np.zeros_like(self.source)
filtered_signal[0] = self.source[0]

for i in range(1, len(self.source)):
# for each point of our dataset, we apply a exponential smoothing
filtered_signal[i] = (
alpha * self.source[i] + (1 - alpha) * filtered_signal[i - 1]
)

return Function(
source=filtered_signal,
interpolation=self.__interpolation__,
extrapolation=self.__extrapolation__,
)

# Define all presentation methods
def __call__(self, *args):
"""Plot the Function if no argument is given. If an
Expand Down
40 changes: 40 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,43 @@ def test_pow_arithmetic_priority(other):
assert isinstance(other**func_lambda, Function)
assert isinstance(func_array**other, Function)
assert isinstance(other**func_array, Function)


@pytest.mark.parametrize("alpha", [0.1, 0.5, 0.9])
def test_low_pass_filter(alpha):
"""Test the low_pass_filter method of the Function class.
Parameters
----------
alpha : float
Attenuation coefficient, 0 < alpha < 1.
"""
# Create a test function, sinus here
source = np.array(
[(1, np.sin(1)), (2, np.sin(2)), (3, np.sin(3)), (4, np.sin(4)), (5, np.sin(5))]
)
func = Function(source)

# Apply low pass filter
filtered_func = func.low_pass_filter(alpha)

# Check that the method works as intended and returns the right object with no issue
assert isinstance(filtered_func, Function), "The returned type is not a Function"
assert np.array_equal(
filtered_func.source[0], source[0]
), "The initial value is not the expected value"
assert len(filtered_func.source) == len(
source
), "The filtered Function and the Function have different lengths"
assert (
filtered_func.__interpolation__ == func.__interpolation__
), "The interpolation method was unexpectedly changed"
assert (
filtered_func.__extrapolation__ == func.__extrapolation__
), "The extrapolation method was unexpectedly changed"
for i in range(1, len(source)):
expected = alpha * source[i][1] + (1 - alpha) * filtered_func.source[i - 1][1]
assert np.isclose(filtered_func.source[i][1], expected, atol=1e-6), (
f"The filtered value at index {i} is not the expected value. "
f"Expected: {expected}, Actual: {filtered_func.source[i][1]}"
)

0 comments on commit 8a7ab31

Please sign in to comment.