From 4f738ae896aaf12eb3961d8029cff96eaba6ada5 Mon Sep 17 00:00:00 2001 From: Cloudac7 <812556867@qq.com> Date: Thu, 11 Apr 2024 17:18:56 +0800 Subject: [PATCH] fix: unexpected 2dim mean and var when using 1dim data --- catflow/utils/statistics.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/catflow/utils/statistics.py b/catflow/utils/statistics.py index 41c008a..cf8d0ef 100644 --- a/catflow/utils/statistics.py +++ b/catflow/utils/statistics.py @@ -4,9 +4,11 @@ from numpy.typing import ArrayLike -def block_average(data: ArrayLike, block_size: int, axis: int = 0) -> Tuple[float, float]: +def block_average(data: ArrayLike, block_size: int, axis: int = 0): + reshape_flag = False data = np.array(data) if data.ndim == 1: + reshape_flag = True data = np.reshape(data, (-1, 1)) # Reshape data into a 2D array N_b = data.shape[axis] // block_size if N_b == 0: @@ -21,6 +23,9 @@ def block_average(data: ArrayLike, block_size: int, axis: int = 0) -> Tuple[floa blocked_data = np.mean(reshaped_data, axis=axis+1) mean = np.mean(blocked_data, axis=axis) var = np.std(blocked_data, ddof=1, axis=axis) / np.sqrt(N_b) + if reshape_flag: + mean = mean[0] + var = var[0] return mean, var