Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure n_steps_default >= 3 #174

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/environment_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dependencies:
- myst-parser
- myst-nb
- sphinxcontrib-srclinks
- setuptools
31 changes: 24 additions & 7 deletions gcm_filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ def _taper_target(target_spec: TargetSpec):
}


def _compute_n_steps_default(
ndim, filter_shape, filter_scale, dx_min, transition_width
):
"""Compute the default number of steps for 1D or 2D filters based on provided parameters."""

n_steps_factor = filter_params[filter_shape][ndim]["offset"] + filter_params[
filter_shape
][ndim]["factor"] * (
(np.pi / transition_width) ** filter_params[filter_shape][ndim]["exponent"]
)

filter_factor = filter_scale / dx_min

n_steps_default = max(np.ceil(n_steps_factor * filter_factor).astype(int), 3)

return n_steps_default


class FilterSpec(NamedTuple):
n_steps: int
s_max: float
Expand Down Expand Up @@ -332,20 +350,19 @@ def __post_init__(self):
raise ValueError(f"Transition width must be > 1.")

# Get default number of steps
filter_factor = self.filter_scale / self.dx_min
if self.ndim > 2:
if self.n_steps < 3:
raise ValueError(f"When ndim > 2, you must set n_steps manually")
else:
n_steps_default = self.n_steps # For ndim>2 we don't have a default
else:
n_steps_factor = filter_params[self.filter_shape][self.ndim][
"offset"
] + filter_params[self.filter_shape][self.ndim]["factor"] * (
(np.pi / self.transition_width)
** filter_params[self.filter_shape][self.ndim]["exponent"]
n_steps_default = _compute_n_steps_default(
self.ndim,
self.filter_shape,
self.filter_scale,
self.dx_min,
self.transition_width,
)
n_steps_default = np.ceil(n_steps_factor * filter_factor).astype(int)

# Set n_steps if needed and issue n_step warning, if needed
if self.n_steps < 3:
Expand Down
5 changes: 0 additions & 5 deletions readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,5 @@ build:
tools:
python: "mambaforge-4.10"

python:
install:
- method: setuptools
path: .

conda:
environment: ci/environment_docs.yml
9 changes: 8 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import xarray as xr

from gcm_filters import Filter, FilterShape, GridType
from gcm_filters.filter import FilterSpec
from gcm_filters.filter import FilterSpec, _compute_n_steps_default


def _check_equal_filter_spec(spec1, spec2):
Expand Down Expand Up @@ -85,6 +85,13 @@ def test_filter_spec(filter_args, expected_filter_spec):
# TODO: check other properties of filter_spec?


def test_default_n_steps_larger_equal_3():
n_steps_default = _compute_n_steps_default(2, FilterShape.GAUSSIAN, 1.5, 1, np.pi)

# Assert that the number of steps is greater than or equal to 3
assert n_steps_default >= 3


#################### Diffusion-based filter tests ########################################
area_weighted_regular_grids = [
GridType.REGULAR_AREA_WEIGHTED,
Expand Down
Loading