Skip to content

Commit

Permalink
Merge pull request #174 from NoraLoose/n-steps
Browse files Browse the repository at this point in the history
Make sure `n_steps_default` >= 3
  • Loading branch information
NoraLoose authored Nov 13, 2024
2 parents d3d52c9 + 424f78c commit bcd900d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
1 change: 1 addition & 0 deletions ci/environment_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dependencies:
- myst-parser
- myst-nb
- sphinxcontrib-srclinks
- setuptools
31 changes: 24 additions & 7 deletions gcm_filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ def _taper_target(target_spec: TargetSpec):
}


def _compute_n_steps_default(
ndim, filter_shape, filter_scale, dx_min, transition_width
):
"""Compute the default number of steps for 1D or 2D filters based on provided parameters."""

n_steps_factor = filter_params[filter_shape][ndim]["offset"] + filter_params[
filter_shape
][ndim]["factor"] * (
(np.pi / transition_width) ** filter_params[filter_shape][ndim]["exponent"]
)

filter_factor = filter_scale / dx_min

n_steps_default = max(np.ceil(n_steps_factor * filter_factor).astype(int), 3)

return n_steps_default


class FilterSpec(NamedTuple):
n_steps: int
s_max: float
Expand Down Expand Up @@ -332,20 +350,19 @@ def __post_init__(self):
raise ValueError(f"Transition width must be > 1.")

# Get default number of steps
filter_factor = self.filter_scale / self.dx_min
if self.ndim > 2:
if self.n_steps < 3:
raise ValueError(f"When ndim > 2, you must set n_steps manually")
else:
n_steps_default = self.n_steps # For ndim>2 we don't have a default
else:
n_steps_factor = filter_params[self.filter_shape][self.ndim][
"offset"
] + filter_params[self.filter_shape][self.ndim]["factor"] * (
(np.pi / self.transition_width)
** filter_params[self.filter_shape][self.ndim]["exponent"]
n_steps_default = _compute_n_steps_default(
self.ndim,
self.filter_shape,
self.filter_scale,
self.dx_min,
self.transition_width,
)
n_steps_default = np.ceil(n_steps_factor * filter_factor).astype(int)

# Set n_steps if needed and issue n_step warning, if needed
if self.n_steps < 3:
Expand Down
5 changes: 0 additions & 5 deletions readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,5 @@ build:
tools:
python: "mambaforge-4.10"

python:
install:
- method: setuptools
path: .

conda:
environment: ci/environment_docs.yml
9 changes: 8 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import xarray as xr

from gcm_filters import Filter, FilterShape, GridType
from gcm_filters.filter import FilterSpec
from gcm_filters.filter import FilterSpec, _compute_n_steps_default


def _check_equal_filter_spec(spec1, spec2):
Expand Down Expand Up @@ -85,6 +85,13 @@ def test_filter_spec(filter_args, expected_filter_spec):
# TODO: check other properties of filter_spec?


def test_default_n_steps_larger_equal_3():
n_steps_default = _compute_n_steps_default(2, FilterShape.GAUSSIAN, 1.5, 1, np.pi)

# Assert that the number of steps is greater than or equal to 3
assert n_steps_default >= 3


#################### Diffusion-based filter tests ########################################
area_weighted_regular_grids = [
GridType.REGULAR_AREA_WEIGHTED,
Expand Down

0 comments on commit bcd900d

Please sign in to comment.