Skip to content

Commit

Permalink
fix: don't check shapes for CRPS calc funcs;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Dec 21, 2023
1 parent d2a475a commit b6858eb
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions pypots/utils/metrics/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def _check_inputs(
predictions: Union[np.ndarray, torch.Tensor, list],
targets: Union[np.ndarray, torch.Tensor, list],
masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
check_shape: bool = True,
):
# check type
assert isinstance(predictions, type(targets)), (
Expand All @@ -27,9 +28,10 @@ def _check_inputs(
# check shape
prediction_shape = predictions.shape
target_shape = targets.shape
assert (
prediction_shape == target_shape
), f"shape of `predictions` and `targets` must match, but got {prediction_shape} and {target_shape}"
if check_shape:
assert (
prediction_shape == target_shape
), f"shape of `predictions` and `targets` must match, but got {prediction_shape} and {target_shape}"
# check NaN
assert not lib.isnan(
predictions
Expand All @@ -44,11 +46,11 @@ def _check_inputs(
f"types of `masks`, `predictions`, and `targets` must match, but got"
f"`masks`: {type(masks)}, `targets`: {type(targets)}"
)
# check shape
# check shape, masks shape must match targets
mask_shape = masks.shape
assert mask_shape == target_shape, (
f"shape of `masks` must match `predictions` and `targets` shape, "
f"but got `mask`: {mask_shape} that is different from {prediction_shape}"
f"shape of `masks` must match `targets` shape, "
f"but got `mask`: {mask_shape} that is different from `targets`: {target_shape}"
)
# check NaN
assert not lib.isnan(
Expand Down Expand Up @@ -311,7 +313,7 @@ def calc_quantile_crps(
"""
# check shapes and values of inputs
_ = _check_inputs(predictions, targets, masks)
_ = _check_inputs(predictions, targets, masks, check_shape=False)

if isinstance(predictions, np.ndarray):
predictions = torch.from_numpy(predictions)
Expand Down Expand Up @@ -370,7 +372,7 @@ def calc_quantile_crps_sum(
"""
# check shapes and values of inputs
_ = _check_inputs(predictions, targets, masks)
_ = _check_inputs(predictions, targets, masks, check_shape=False)

if isinstance(predictions, np.ndarray):
predictions = torch.from_numpy(predictions)
Expand Down

0 comments on commit b6858eb

Please sign in to comment.