Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: User input checks added for Function class #451

Merged
merged 21 commits into from
Nov 17, 2023
Merged
Changes from 10 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
7f88c50
User input checks added
brunosorban Nov 1, 2023
64c5761
Fix code style issues with Black
lint-action Nov 1, 2023
92d7191
Formatting update
brunosorban Nov 1, 2023
f22cabc
Check user inputs updated
brunosorban Nov 12, 2023
93dacbc
Input checks added to set_source
brunosorban Nov 12, 2023
5eb5ca4
Merge incoming changes
brunosorban Nov 12, 2023
032aaa8
Fix code style issues with Black
lint-action Nov 12, 2023
a405ee3
Merge branch 'master' into bug/function-input-validation
Gui-FernandesBR Nov 12, 2023
f1eb70a
BUG: Fix mesh evaluation in Function class
Gui-FernandesBR Nov 13, 2023
8ea0124
Refactor code to improve performance and
Gui-FernandesBR Nov 13, 2023
667afa3
Merge branch 'develop' into bug/function-input-validation
Gui-FernandesBR Nov 15, 2023
634a05c
Fix code style issues with Black
lint-action Nov 15, 2023
322e332
MNT: use collections Iterable in Function type check.
phmbressan Nov 16, 2023
0f76b3e
FIX: error in shepard domain interpolation.
phmbressan Nov 16, 2023
404d152
Merge branch 'bug/multivariable-function' into bug/function-input-val…
phmbressan Nov 16, 2023
f0f0a71
TST: tests for shepard interpolation values.
phmbressan Nov 17, 2023
edd1e05
Merge branch 'develop' into bug/function-input-validation
phmbressan Nov 17, 2023
4508d42
MNT: post conflict solve refactors.
phmbressan Nov 17, 2023
10ceccb
TST: improve shepard test with multivariable case.
phmbressan Nov 17, 2023
03966a6
MNT: Code style at rocketpy/mathutils/function.py
Gui-FernandesBR Nov 17, 2023
848c860
MNT: Code style at rocketpy/mathutils/function.py
Gui-FernandesBR Nov 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 238 additions & 20 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from inspect import signature
Gui-FernandesBR marked this conversation as resolved.
Show resolved Hide resolved
from pathlib import Path

Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(
-------
None
"""
inputs, outputs, interpolation, extrapolation = self._check_user_input(
source, inputs, outputs, interpolation, extrapolation
)
# Set input and output
self.set_inputs(inputs)
self.set_outputs(outputs)
Expand Down Expand Up @@ -138,6 +142,13 @@ def set_source(self, source):
-------
self : Function
"""
_, _, _, _ = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
self.__interpolation__,
self.__extrapolation__,
)
Gui-FernandesBR marked this conversation as resolved.
Show resolved Hide resolved
# If the source is a Function
if isinstance(source, Function):
source = source.get_source()
Expand Down Expand Up @@ -186,17 +197,13 @@ def source(x):
# Check to see if dimensions match incoming data set
new_total_dim = len(source[0, :])
old_total_dim = self.__dom_dim__ + self.__img_dim__
dV = self.__inputs__ == ["Scalar"] and self.__outputs__ == ["Scalar"]

# If they don't, update default values or throw error
if new_total_dim != old_total_dim:
if dV:
# Update dimensions and inputs
self.__dom_dim__ = new_total_dim - 1
self.__inputs__ = self.__dom_dim__ * self.__inputs__
else:
# User has made a mistake inputting inputs and outputs
print("Error in input and output dimensions!")
return None
# Update dimensions and inputs
self.__dom_dim__ = new_total_dim - 1
self.__inputs__ = self.__dom_dim__ * self.__inputs__

# Do things if domDim is 1
if self.__dom_dim__ == 1:
source = source[source[:, 0].argsort()]
Expand Down Expand Up @@ -749,18 +756,83 @@ def get_value(self, *args):
Returns
-------
ans : scalar, list
Value of the Function at the specified point(s).

Examples
--------
>>> from rocketpy import Function

Testing with callable source (1 dimension):
>>> f = Function(lambda x: x**2)
>>> f.get_value(2)
4
>>> f.get_value(2.5)
6.25
>>> f.get_value([1, 2, 3])
[1, 4, 9]
>>> f.get_value([1, 2.5, 4.0])
[1, 6.25, 16.0]

Testing with callable source (2 dimensions):
>>> f2 = Function(lambda x, y: x**2 + y**2)
>>> f2.get_value(1, 2)
5
>>> f2.get_value([1, 2, 3], [1, 2, 3])
[2, 8, 18]
>>> f2.get_value([5], [5])
[50]

Testing with ndarray source (1 dimension):
>>> f3 = Function(
... [(0, 0), (1, 1), (1.5, 2.25), (2, 4), (2.5, 6.25), (3, 9), (4, 16)]
... )
>>> f3.get_value(2)
4.0
>>> f3.get_value(2.5)
6.25
>>> f3.get_value([1, 2, 3])
[1.0, 4.0, 9.0]
>>> f3.get_value([1, 2.5, 4.0])
[1.0, 6.25, 16.0]

Testing with ndarray source (2 dimensions):
>>> f4 = Function(
... [(0, 0, 0), (1, 1, 1), (1, 2, 2), (2, 4, 8), (3, 9, 27)]
... )
>>> f4.get_value(1, 1)
1.0
>>> f4.get_value(2, 4)
8.0
>>> abs(f4.get_value(1, 1.5) - 1.5) < 1e-2 # the interpolation is not perfect
True
>>> f4.get_value(3, 9)
27.0
"""
if len(args) != self.__dom_dim__:
raise ValueError(
f"This Function takes {self.__dom_dim__} arguments, {len(args)} given."
)

# Return value for Function of function type
if callable(self.source):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
if isinstance(args[0][0], (tuple, list)):
return [self.source(*arg) for arg in args[0]]
else:
return [self.source(arg) for arg in args[0]]
elif len(args) == 1 and isinstance(args[0], np.ndarray):
return self.source(args[0])
# if the function is 1-D:
if self.__dom_dim__ == 1:
# if the args is a simple number (int or float)
if isinstance(args[0], (int, float)):
return self.source(args[0])
# if the arguments are iterable, we map and return a list
if isinstance(args[0], (list, tuple, np.ndarray)):
phmbressan marked this conversation as resolved.
Show resolved Hide resolved
return list(map(self.source, args[0]))

# if the function is n-D:
else:
return self.source(*args)
# if each arg is a simple number (int or float)
if all(isinstance(arg, (int, float)) for arg in args):
return self.source(*args)
# if each arg is iterable, we map and return a list
if all(isinstance(arg, (list, tuple, np.ndarray)) for arg in args):
return [self.source(*arg) for arg in zip(*args)]

# Returns value for shepard interpolation
elif self.__interpolation__ == "shepard":
if isinstance(args[0], (list, tuple)):
Expand Down Expand Up @@ -1217,10 +1289,10 @@ def plot2D(
x = np.linspace(lower[0], upper[0], samples[0])
y = np.linspace(lower[1], upper[1], samples[1])
mesh_x, mesh_y = np.meshgrid(x, y)
mesh_x_flat, mesh_y_flat = mesh_x.flatten(), mesh_y.flatten()
mesh = [[mesh_x_flat[i], mesh_y_flat[i]] for i in range(len(mesh_x_flat))]
mesh = np.column_stack((mesh_x.flatten(), mesh_y.flatten()))

# Evaluate function at all mesh nodes and convert it to matrix
z = np.array(self.get_value(mesh)).reshape(mesh_x.shape)
z = np.array(self.get_value(mesh[:, 0], mesh[:, 1])).reshape(mesh_x.shape)
phmbressan marked this conversation as resolved.
Show resolved Hide resolved
# Plot function
if disp_type == "surface":
surf = axes.plot_surface(
Expand Down Expand Up @@ -2612,6 +2684,152 @@ def compose(self, func, extrapolate=False):
extrapolation=self.__extrapolation__,
)

@staticmethod
def _check_user_input(
Gui-FernandesBR marked this conversation as resolved.
Show resolved Hide resolved
source,
inputs,
outputs,
interpolation,
extrapolation,
):
"""
Validates and processes the user input parameters for creating or
modifying a Function object. This function ensures the inputs, outputs,
interpolation, and extrapolation parameters are compatible with the
given source. It converts the source to a numpy array if necessary, sets
default values and raises warnings or errors for incompatible or
ill-defined parameters.

Parameters
----------
source : list, np.ndarray, or Function
Gui-FernandesBR marked this conversation as resolved.
Show resolved Hide resolved
The source data or Function object. If a list or ndarray, it should
contain numeric data. If a Function, its inputs and outputs are
checked against the provided inputs and outputs.
inputs : list of str or None
The names of the input variables. If None, defaults are generated
based on the dimensionality of the source.
outputs : str or list of str
The name(s) of the output variable(s). If a list is provided, it
must have a single element.
interpolation : str or None
The method of interpolation to be used. For multidimensional sources
it defaults to 'shepard' if not provided.
extrapolation : str or None
The method of extrapolation to be used. For multidimensional sources
it defaults to 'natural' if not provided.

Returns
-------
tuple
A tuple containing the processed inputs, outputs, interpolation, and
extrapolation parameters.

Raises
------
ValueError
If the dimensionality of the source does not match the combined
dimensions of inputs and outputs. If the outputs list has more than
one element.
TypeError
If the source is not a list, np.ndarray, or Function object.
Warning
If inputs or outputs do not match for a Function source, or if
defaults are used for inputs, interpolation,and extrapolation for a
multidimensional source.

Examples
--------
>>> from rocketpy import Function
>>> source = np.array([(1, 1), (2, 4), (3, 9)])
>>> inputs = "x"
>>> outputs = ["y"]
>>> interpolation = 'linear'
>>> extrapolation = 'zero'
>>> inputs, outputs, interpolation, extrapolation = Function._check_user_input(
... source, inputs, outputs, interpolation, extrapolation
... )
>>> inputs
['x']
>>> outputs
['y']
>>> interpolation
'linear'
>>> extrapolation
'zero'
"""
# check output type and dimensions
if isinstance(outputs, str):
outputs = [outputs]
if isinstance(inputs, str):
inputs = [inputs]

elif len(outputs) > 1:
raise ValueError(
"Output must either be a string or have dimension 1, "
+ f"it currently has dimension ({len(outputs)})."
)
phmbressan marked this conversation as resolved.
Show resolved Hide resolved

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, (list, np.ndarray)):
# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)

# check dimensions
source_dim = source.shape[1]

# check interpolation and extrapolation
if source_dim > 2:
# check for inputs and outputs
if inputs == ["Scalar"]:
inputs = [f"Input {i+1}" for i in range(source_dim - 1)]
warnings.warn(
f"Inputs not set, defaulting to {inputs} for "
+ "multidimensional functions.",
)

if interpolation not in [None, "shepard"]:
interpolation = "shepard"
warnings.warn(
(
"Interpolation method for multidimensional functions is set"
"to 'shepard', currently other methods are not supported."
),
)

if extrapolation is None:
extrapolation = "natural"
warnings.warn(
"Extrapolation not set, defaulting to 'natural' "
+ "for multidimensional functions.",
)

# check input dimensions
in_out_dim = len(inputs) + len(outputs)
if source_dim != in_out_dim:
raise ValueError(
"Source dimension ({source_dim}) does not match input "
+ f"and output dimension ({in_out_dim})."
)

# if function, check for inputs and outputs
if isinstance(source, Function):
# check inputs
if inputs is not None and inputs != source.get_inputs():
warnings.warn(
f"Inputs do not match source inputs, setting inputs to {inputs}.",
)

# check outputs
if outputs is not None and outputs != source.get_outputs():
warnings.warn(
f"Outputs do not match source outputs, setting outputs to {outputs}.",
)

return inputs, outputs, interpolation, extrapolation


class PiecewiseFunction(Function):
def __new__(
Expand Down