Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add Function.low_pass_filter method #508

Merged
merged 13 commits into from
Dec 17, 2023
32 changes: 32 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,38 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True):
extrapolation="zero",
)

def low_pass_filter(self, alpha):
"""Implements a low pass filter with a moving average filter

Parameters
----------
alpha : float
Attenuation coefficient, 0 <= alpha <= 1
For a given dataset, the larger alpha is, the more closely the
filtered function returned will match the function the smaller
alpha is, the smoother the filtered function returned will be
(but with a phase shift)

Returns
-------
Function
The function with the incoming source filtered
"""
filtered_signal = np.zeros_like(self.source)
filtered_signal[0] = self.source[0]

for i in range(1, len(self.source)):
# for each point of our dataset, we apply a exponential smoothing
filtered_signal[i] = (
alpha * self.source[i] + (1 - alpha) * filtered_signal[i - 1]
)

return Function(
source=filtered_signal,
interpolation=self.__interpolation__,
extrapolation=self.__extrapolation__,
)

# Define all presentation methods
def __call__(self, *args):
"""Plot the Function if no argument is given. If an
Expand Down
40 changes: 40 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,43 @@ def test_pow_arithmetic_priority(other):
assert isinstance(other**func_lambda, Function)
assert isinstance(func_array**other, Function)
assert isinstance(other**func_array, Function)


@pytest.mark.parametrize("alpha", [0.1, 0.5, 0.9])
def test_low_pass_filter(alpha):
"""Test the low_pass_filter method of the Function class.

Parameters
----------
alpha : float
Attenuation coefficient, 0 < alpha < 1.
"""
# Create a test function, sinus here
source = np.array(
[(1, np.sin(1)), (2, np.sin(2)), (3, np.sin(3)), (4, np.sin(4)), (5, np.sin(5))]
)
func = Function(source)

# Apply low pass filter
filtered_func = func.low_pass_filter(alpha)

# Check that the method works as intended and returns the right object with no issue
assert isinstance(filtered_func, Function), "The returned type is not a Function"
assert np.array_equal(
filtered_func.source[0], source[0]
), "The initial value is not the expected value"
assert len(filtered_func.source) == len(
source
), "The filtered Function and the Function have different lengths"
assert (
filtered_func.__interpolation__ == func.__interpolation__
), "The interpolation method was unexpectedly changed"
assert (
filtered_func.__extrapolation__ == func.__extrapolation__
), "The extrapolation method was unexpectedly changed"
for i in range(1, len(source)):
expected = alpha * source[i][1] + (1 - alpha) * filtered_func.source[i - 1][1]
assert np.isclose(filtered_func.source[i][1], expected, atol=1e-6), (
f"The filtered value at index {i} is not the expected value. "
f"Expected: {expected}, Actual: {filtered_func.source[i][1]}"
)
Loading