diff --git a/tal/__init__.py b/tal/__init__.py index e4765a1..5b36555 100644 --- a/tal/__init__.py +++ b/tal/__init__.py @@ -2,4 +2,4 @@ from tal.plot import * from tal.reconstruct import * -__version__ = '0.6.4' +__version__ = '0.7.0' diff --git a/tal/enums.py b/tal/enums.py index 7d3d940..e15bfda 100644 --- a/tal/enums.py +++ b/tal/enums.py @@ -16,6 +16,14 @@ class HFormat(Enum): T_Si = 3 # confocal or not T_Si_Li = 4 + def time_dim(self) -> int: + assert self in [HFormat.T_Sx_Sy, + HFormat.T_Lx_Ly_Sx_Sy, + HFormat.T_Si, + HFormat.T_Si_Li], \ + f'Unexpected HFormat {self}' + return 0 + class GridFormat(Enum): UNKNOWN = 0 @@ -29,6 +37,12 @@ class VolumeFormat(Enum): X_Y_Z_3 = 2 X_Y_3 = 3 + def xyz_dim_is_last(self) -> bool: + assert self in [VolumeFormat.N_3, + VolumeFormat.X_Y_Z_3, + VolumeFormat.X_Y_3] + return True + class CameraSystem(Enum): STEADY = 0 # focused light diff --git a/tal/reconstruct/__init__.py b/tal/reconstruct/__init__.py index 3e93478..abfc740 100644 --- a/tal/reconstruct/__init__.py +++ b/tal/reconstruct/__init__.py @@ -1,5 +1,6 @@ -from tal.reconstruct.pf import * from tal.reconstruct.bp import * +from tal.reconstruct.pf import * +from tal.reconstruct.pf_dev import * from tal.io.capture_data import NLOSCaptureData from tal.enums import HFormat from typing import Union @@ -10,12 +11,17 @@ def filter_H(data: _Data, filter_name: str, data_format: HFormat = HFormat.UNKNOWN, + border: str = 'zero', plot_filter: bool = False, return_filter: bool = False, **kwargs) -> NLOSCaptureData.HType: """ Filter a captured data signal (H) using specified filter_name * data_format should be non-null if not specified through data + * border sets the behaviour for the edges of the convolution + - 'erase': the filtered signal has the edges set to zero + - 'zero': before filtering, pad the signal with zeros + - 'edge': before filtering, pad the signal with edge values * If plot_filter=True, shows a plot of the resulting filter * If return_filter=True, returns the filter (K) else, returns the filtered signal (H * K) @@ -24,10 +30,11 @@ def filter_H(data: _Data, * wl_mean: Mean of the Gaussian in the frequency domain * wl_sigma: STD of the Gaussian in the frequency domain * delta_t: Time interval, must be non-null if not specified through data + FIXME(diego): is this sentence true? vvv probably not e.g. mean = 3, sigma = 0.5 will filter frequencies of ~2-4m """ from tal.reconstruct.filters import filter_H_impl - return filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **kwargs) + return filter_H_impl(data, filter_name, data_format, border, plot_filter, return_filter, **kwargs) def get_volume_min_max_resolution(minimal_pos, maximal_pos, resolution): diff --git a/tal/reconstruct/bp/__init__.py b/tal/reconstruct/bp/__init__.py index 8e71c4b..c8864b8 100644 --- a/tal/reconstruct/bp/__init__.py +++ b/tal/reconstruct/bp/__init__.py @@ -1,72 +1,28 @@ from tal.io.capture_data import NLOSCaptureData -from tal.enums import HFormat, GridFormat, VolumeFormat, CameraSystem -from typing import Union - +from tal.enums import VolumeFormat, CameraSystem +from tal.reconstruct.utils import convert_to_N_3, convert_reconstruction_from_N_3 import numpy as np -_Data = Union[NLOSCaptureData, NLOSCaptureData.HType] def solve(data: NLOSCaptureData, volume_xyz: NLOSCaptureData.VolumeXYZType = None, volume_format: VolumeFormat = None, - camera_system: CameraSystem = CameraSystem.STEADY) -> np.array: + camera_system: CameraSystem = CameraSystem.STEADY) -> np.array: # FIXME type """ NOTE: Does _NOT_ attempt to compensate effects caused by attenuation: - cos decay i.e. {sensor|laser}_grid_normals are ignored - 1/d^2 decay + TODO(diego): docs """ - if data.H_format == HFormat.T_Si: - H = data.H - elif data.H_format == HFormat.T_Sx_Sy: - nt, nsx, nsy = data.H.shape - H = data.H.reshape(nt, nsx * nsy) - else: - raise AssertionError(f'H_format {data.H_format} not implemented') - - if data.sensor_grid_format == GridFormat.N_3: - sensor_grid_xyz = data.sensor_grid_xyz - elif data.sensor_grid_format == GridFormat.X_Y_3: - try: - assert nsx == data.sensor_grid_xyz.shape[0] and nsy == data.sensor_grid_xyz.shape[1], \ - 'sensor_grid_xyz.shape does not match with H.shape' - except NameError: - # nsx, nsy not defined, OK - nsx, nsy, _ = data.sensor_grid_xyz.shape - pass - sensor_grid_xyz = data.sensor_grid_xyz.reshape(nsx * nsy, 3) - else: - raise AssertionError( - f'sensor_grid_format {data.sensor_grid_format} not implemented') - - if volume_format == VolumeFormat.X_Y_Z_3: - nvx, nvy, nvz, _ = volume_xyz.shape - volume_xyz_n3 = volume_xyz.reshape((-1, 3)) - elif volume_format == VolumeFormat.N_3: - volume_xyz_n3 = volume_xyz - else: - raise AssertionError('volume_format must be specified') + H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3 = \ + convert_to_N_3(data, volume_xyz, volume_format) from tal.reconstruct.bp.backprojection import backproject reconstructed_volume_n3 = backproject( - H, data.laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3, + H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3, camera_system, data.t_accounts_first_and_last_bounces, data.t_start, data.delta_t, data.laser_xyz, data.sensor_xyz, progress=True) - if camera_system.is_transient(): - assert data.H_format == HFormat.T_Si or data.H_format == HFormat.T_Sx_Sy or data.H_format == HFormat.T_Lx_Ly_Sx_Sy, \ - 'Cannot find time dimension given H_format' - time_dim = (data.H.shape[0],) - else: - time_dim = () - - if volume_format == VolumeFormat.X_Y_Z_3: - reconstructed_volume = reconstructed_volume_n3.reshape( - time_dim + (nvx, nvy, nvz)) - elif volume_format == VolumeFormat.N_3: - reconstructed_volume = reconstructed_volume_n3 - else: - raise AssertionError('volume_format must be specified') - - return reconstructed_volume + return convert_reconstruction_from_N_3(data, reconstructed_volume_n3, volume_xyz, volume_format, camera_system) diff --git a/tal/reconstruct/bp/backprojection.py b/tal/reconstruct/bp/backprojection.py index 9d26f05..4775690 100644 --- a/tal/reconstruct/bp/backprojection.py +++ b/tal/reconstruct/bp/backprojection.py @@ -3,15 +3,15 @@ from tqdm import tqdm -def backproject(H, laser_grid_xyz, sensor_grid_xyz, volume_xyz, +def backproject(H_0, laser_grid_xyz, sensor_grid_xyz, volume_xyz, camera_system, t_accounts_first_and_last_bounces, t_start, delta_t, laser_xyz=None, sensor_xyz=None, progress=False): # TODO(diego): extend for multiple laser points - assert H.ndim == 2 and laser_grid_xyz.size == 3, \ + assert H_0.ndim == 2 and laser_grid_xyz.size == 3, \ 'backproject only supports one laser point' - nt, ns = H.shape + nt, ns = H_0.shape assert sensor_grid_xyz.shape[0] == ns, 'H does not match with sensor_grid_xyz' assert (not t_accounts_first_and_last_bounces or (laser_xyz is not None and sensor_xyz is not None)), \ 't_accounts_first_and_last_bounces requires laser_xyz and sensor_xyz' @@ -19,9 +19,9 @@ def backproject(H, laser_grid_xyz, sensor_grid_xyz, volume_xyz, nv, _ = volume_xyz.shape if camera_system.is_transient(): - f = np.zeros((nt, nv), dtype=H.dtype) + H_1 = np.zeros((nt, nv), dtype=H_0.dtype) else: - f = np.zeros(nv, dtype=H.dtype) + H_1 = np.zeros(nv, dtype=H_0.dtype) # d_1: laser origin to laser illuminated point # d_2: laser illuminated point to x_v @@ -50,12 +50,14 @@ def backproject(H, laser_grid_xyz, sensor_grid_xyz, volume_xyz, d_3 = np.linalg.norm(x_v - x_s) t_i = int((d_1 + d_2 + d_3 + d_4 - t_start) / delta_t) if camera_system.is_transient(): - p = np.copy(H[:, s_i]) - p[t_i+nt-1:] = 0.0 - p[:t_i] = 0.0 - f[:, i_v] += np.roll(p, -t_i) + p = np.copy(H_0[:, s_i]) + if t_i > 0: + p[:t_i] = 0.0 + elif t_i < 0: + p[t_i+nt-1:] = 0.0 + H_1[:, i_v] += np.roll(p, -t_i) else: if t_i >= 0 and t_i < nt: - f[i_v] += H[t_i, s_i] + H_1[i_v] += H_0[t_i, s_i] - return f + return H_1 diff --git a/tal/reconstruct/filters.py b/tal/reconstruct/filters.py index 0f35962..0d6237b 100644 --- a/tal/reconstruct/filters.py +++ b/tal/reconstruct/filters.py @@ -3,7 +3,7 @@ import numpy as np -def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **kwargs): +def filter_H_impl(data, filter_name, data_format, border, plot_filter, return_filter, **kwargs): if isinstance(data, NLOSCaptureData): assert data.H_format != HFormat.UNKNOWN or data_format is not None, \ 'H format must be specified through the NLOSCaptureData object or the data_format argument' @@ -19,15 +19,11 @@ def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, ** H = data delta_t = None - if H_format == HFormat.T_Sx_Sy: - nt, nsx, nsy = H.shape - elif H_format == HFormat.T_Lx_Ly_Sx_Sy: - nt, nlx, nly, nsx, nsy = H.shape - K_shape = (nt, 1, 1, 1, 1) - else: - raise AssertionError('Unknown H_format') + assert border in ['erase', 'zero', 'edge'], \ + 'border must be one of "erase", "zero" or "edge"' - padding = 0 + nt = H.shape[H_format.time_dim()] + nt_pad = nt if filter_name == 'pf': wl_mean = kwargs.get('wl_mean', None) @@ -39,40 +35,65 @@ def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, ** assert delta_t is not None, \ 'For the "pf" filter, delta_t must be specified through an NLOSCaptureData object or the delta_t argument' - t_max = delta_t * (nt * 2 - 1) - t = np.linspace(start=0, stop=t_max, num=nt * 2) + t_6sigma = int(np.round(6 * wl_sigma / delta_t)) # used for padding + if t_6sigma % 2 == 1: + t_6sigma += 1 # its easier if padding is even + if return_filter: + nt_pad = t_6sigma + else: + nt_pad = nt + 2 * (t_6sigma - 1) + t_max = delta_t * (nt_pad - 1) + t = np.linspace(start=0, stop=t_max, num=nt_pad) # vvv Gaussian envelope (x = t - t_max/2, mu = 0, sigma = wl_sigma) - K = (1 / (wl_sigma * np.sqrt(2 * np.pi))) * \ - np.exp(-((t - t_max / 2) / wl_sigma) ** 2 / 2) * \ + gaussian_envelope = np.exp(-((t - t_max / 2) / wl_sigma) ** 2 / 2) + K = gaussian_envelope / np.sum(gaussian_envelope) * \ np.exp(2j * np.pi * t / wl_mean) # ^^^ Pulse inside the Gaussian envelope (complex exponential) # center at zero (not in freq. domain but fftshift works) - K = np.fft.fftshift(K) + K = np.fft.ifftshift(K) else: raise AssertionError( 'Unknown filter_name. Check the documentation for available filters') if plot_filter: import matplotlib.pyplot as plt - K_show = np.fft.ifftshift(K) - plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.real(K_show), c='b') - plt.plot(t[:len(K_show)] - t[len(K_show) // 2], - np.imag(K_show), c='b', linestyle='--') + K_show = np.fft.fftshift(K) + cut = (nt_pad - t_6sigma) // 2 + if cut > 0: + K_show = K_show[cut:-cut] + plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.real(K_show), + c='b') + plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.imag(K_show), + c='b', linestyle='--') + plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.abs(K_show), + c='r') plt.show() if return_filter: return K - # pad with identical, inverted signal - if H_format == HFormat.T_Sx_Sy or H_format == HFormat.T_Lx_Ly_Sx_Sy: - H = np.resize(H, (nt * 2, *H.shape[1:])) - H[nt:, ...] = H[:nt, ...][::-1, ...] - K_shape = (nt * 2,) + (1,) * (H.ndim - 1) + padding = (nt_pad - nt) + assert padding % 2 == 0 + padding //= 2 + + # pad with edge values + if H_format.time_dim() == 0: + mode = None + if border == 'edge': + mode = 'edge' + if border == 'zero' or border == 'erase': + mode = 'constant' + if mode: + H_pad = np.pad(H, + ((padding, padding),) + # first dim (time) + ((0, 0),) * (H.ndim - 1), # other dims + mode=mode) + K_shape = (nt_pad,) + (1,) * (H.ndim - 1) else: raise AssertionError('Unknown H_format') - H_fft = np.fft.fft(H, axis=0) + H_fft = np.fft.fft(H_pad, axis=0) K_fft = np.fft.fft(K) H_fft *= K_fft.reshape(K_shape) del K_fft @@ -80,7 +101,11 @@ def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, ** del H_fft # undo padding - if H_format == HFormat.T_Sx_Sy or H_format == HFormat.T_Lx_Ly_Sx_Sy: - return HoK[:nt, ...] + if H_format.time_dim() == 0: + HoK = HoK[padding:-padding, ...] + if border == 'erase': + HoK[:padding//2] = 0 + HoK[-padding//2:] = 0 + return HoK else: raise AssertionError('Unknown H_format') diff --git a/tal/reconstruct/pf_dev/__init__.py b/tal/reconstruct/pf_dev/__init__.py new file mode 100644 index 0000000..b6566e9 --- /dev/null +++ b/tal/reconstruct/pf_dev/__init__.py @@ -0,0 +1,32 @@ +from tal.io.capture_data import NLOSCaptureData +from tal.enums import VolumeFormat, CameraSystem +from tal.reconstruct.utils import convert_to_N_3, convert_reconstruction_from_N_3 +import numpy as np + + +def solve(data: NLOSCaptureData, + wl_mean: float, + wl_sigma: float, + edges: str = 'zero', + volume_xyz: NLOSCaptureData.VolumeXYZType = None, + volume_format: VolumeFormat = None, + camera_system: CameraSystem = CameraSystem.STEADY) -> np.array: # FIXME type + """ + NOTE: Does _NOT_ attempt to compensate effects caused by attenuation: + - cos decay i.e. {sensor|laser}_grid_normals are ignored + - 1/d^2 decay + TODO(diego): docs + """ + H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3 = \ + convert_to_N_3(data, volume_xyz, volume_format) + + from tal.reconstruct.pf_dev.phasor_fields import backproject_pf_multi_frequency + reconstructed_volume_n3 = backproject_pf_multi_frequency( + H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3, + camera_system, data.t_accounts_first_and_last_bounces, + data.t_start, data.delta_t, + wl_mean, wl_sigma, edges, + data.laser_xyz, data.sensor_xyz, + progress=True) + + return convert_reconstruction_from_N_3(data, reconstructed_volume_n3, volume_xyz, volume_format, camera_system) diff --git a/tal/reconstruct/pf_dev/phasor_fields.py b/tal/reconstruct/pf_dev/phasor_fields.py new file mode 100644 index 0000000..eaec801 --- /dev/null +++ b/tal/reconstruct/pf_dev/phasor_fields.py @@ -0,0 +1,131 @@ +import numpy as np +import matplotlib.pyplot as plt +import tal +from tal.enums import HFormat +from tqdm import tqdm + + +def backproject_pf_single_frequency(H_0, delta_t, d, wavelength, progress=True): + """ + H_0: (nt, ns) + delta_t: scalar, in meters + d: (ns, nv) + wavelength: scalar, in meters + """ + assert H_0.ndim == 2, 'Incorrect H format' + + ns, nv = d.shape + nt, ns_ = H_0.shape + assert ns == ns_, 'Incorrect shape' + + propagator = np.exp(2j * np.pi * d / wavelength) + + t = delta_t * np.linspace(start=0, stop=nt - 1, num=nt) + e = np.exp(-2j * np.pi * t / wavelength) + H_0_w = np.sum(H_0 * e.reshape((nt, 1)), axis=0).reshape((ns, 1)) + # FIXME(diego): implement convolution in frequency domain + H_1_w = np.sum(H_0_w * propagator, axis=0) + + return H_1_w + + +def backproject_pf_multi_frequency( + H_0, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3, + camera_system, t_accounts_first_and_last_bounces, + t_start, delta_t, + wl_mean, wl_sigma, edges, + laser_xyz=None, sensor_xyz=None, progress=False): + assert H_0.ndim == 2, 'Incorrect H format' + assert volume_xyz_n3.ndim == 2 and volume_xyz_n3.shape[1] == 3, \ + 'Incorrect volume_xyz format, should be N_3' + assert laser_grid_xyz.size == 3, 'Only supports one laser position' + assert (not t_accounts_first_and_last_bounces or (laser_xyz is not None and sensor_xyz is not None)), \ + 't_accounts_first_and_last_bounces requires laser_xyz and sensor_xyz' + + nt, ns = H_0.shape + nv, _ = volume_xyz_n3.shape + + """ Phasor fields filter """ + + # FIXME(diego): obtain filter parameters directly in frequency + # domain (avoid conversion from time domain) + + t_6sigma = int(np.ceil(6 * wl_sigma / delta_t)) + padding = 2 * t_6sigma + # FIXME(diego) if we want to convert a circular convolution to linear, + # this should be nt + t_6sigma - 1 instead of nt + 4 * t_6sigma or even nt + 2 * t_6sigma + # I have found cases where it fails even with nt + 2 * t_6sigma (Z, 0th pixel) + nf = nt + 2 * padding + + t_max = delta_t * (nf - 1) + t = np.linspace(start=0, stop=t_max, num=nf) + + gaussian_envelope = np.exp(-((t - t_max / 2) / wl_sigma) ** 2 / 2) + K = gaussian_envelope / np.sum(gaussian_envelope) * \ + np.exp(2j * np.pi * t / wl_mean) + K = np.fft.fftshift(K) # center at zero + + mean_idx = (nf * delta_t) / wl_mean + sigma_idx = (nf * delta_t) / (wl_sigma * 6) + # shift to center at zero, easier for low negative frequencies + freq_min_idx = nf // 2 + int(np.floor(mean_idx - 3 * sigma_idx)) + freq_max_idx = nf // 2 + int(np.ceil(mean_idx + 3 * sigma_idx)) + K_fftshift = np.fft.fftshift(np.fft.fft(K)) + K_fftfreq = np.fft.fftshift(np.fft.fftfreq(nf, d=delta_t)) + + weights = K_fftshift[freq_min_idx:freq_max_idx+1] + freqs = K_fftfreq[freq_min_idx:freq_max_idx+1] + nw = len(weights) + + if edges == 'zero': + H_0 = np.pad(H_0, ((padding, padding), (0, 0)), 'constant') + else: + raise AssertionError('Implemented only for edges="zero"') + + """ Propagation of specific frequencies """ + + # d_3 + d = np.linalg.norm( + sensor_grid_xyz.reshape((ns, 1, 3)) - + volume_xyz_n3.reshape((1, nv, 3)), + axis=2) + d -= t_start + if t_accounts_first_and_last_bounces: + # d_1 + d += np.linalg.norm(laser_xyz.reshape(3) - + laser_grid_xyz.reshape(3)) + # d_4 + d += np.linalg.norm(sensor_xyz.reshape((1, 1, 3)) - + sensor_grid_xyz.reshape((ns, 1, 3)), axis=2) + if camera_system.bp_accounts_for_d_2(): + # d_2 + d += np.linalg.norm( + laser_grid_xyz.reshape((1, 1, 3)) - + volume_xyz_n3.reshape((1, nv, 3)), + axis=2) + + if camera_system.is_transient(): + H_1 = np.zeros((nt, nv), dtype=np.complex64) + else: + H_1 = np.zeros(nv, dtype=np.complex64) + + t = delta_t * np.linspace(start=0, stop=nf - 1, num=nf) + + iterator = zip(freqs, weights) + if progress: + iterator = tqdm(iterator, total=nw) + + for frequency, weight in iterator: + wavelength = np.inf if np.isclose(frequency, 0) else 1 / frequency + H_1_w = weight * \ + backproject_pf_single_frequency( + H_0, delta_t, d, wavelength, progress=progress) + e = np.exp(2 * np.pi * 1j * t * frequency) / nf + H_1_i = H_1_w.reshape((1, nv)) * e.reshape((nf, 1)) + + if camera_system.is_transient(): + H_1 += H_1_i[padding: -padding, ...] + else: + H_1 += H_1_i[padding] + + return H_1 diff --git a/tal/reconstruct/utils.py b/tal/reconstruct/utils.py new file mode 100644 index 0000000..894700f --- /dev/null +++ b/tal/reconstruct/utils.py @@ -0,0 +1,59 @@ +from tal.enums import HFormat, GridFormat, VolumeFormat, CameraSystem +from tal.io.capture_data import NLOSCaptureData +import numpy as np + + +def convert_to_N_3(data: NLOSCaptureData, + volume_xyz: NLOSCaptureData.VolumeXYZType, + volume_format: VolumeFormat): + if data.H_format == HFormat.T_Si: + H = data.H + elif data.H_format == HFormat.T_Sx_Sy: + nt, nsx, nsy = data.H.shape + H = data.H.reshape(nt, nsx * nsy) + else: + raise AssertionError(f'H_format {data.H_format} not implemented') + + # FIXME(diego) also convert laser data and confocal/exhaustive H measurements + + if data.sensor_grid_format == GridFormat.N_3: + sensor_grid_xyz = data.sensor_grid_xyz + elif data.sensor_grid_format == GridFormat.X_Y_3: + try: + assert nsx == data.sensor_grid_xyz.shape[0] and nsy == data.sensor_grid_xyz.shape[1], \ + 'sensor_grid_xyz.shape does not match with H.shape' + except NameError: + # nsx, nsy not defined, OK + nsx, nsy, _ = data.sensor_grid_xyz.shape + pass + sensor_grid_xyz = data.sensor_grid_xyz.reshape(nsx * nsy, 3) + else: + raise AssertionError( + f'sensor_grid_format {data.sensor_grid_format} not implemented') + + assert H.shape[1] == sensor_grid_xyz.shape[0], \ + 'H.shape does not match with sensor_grid_xyz.shape. Different number of points than measurements.' + + if volume_format == VolumeFormat.X_Y_Z_3: + volume_xyz_n3 = volume_xyz.reshape((-1, 3)) + elif volume_format == VolumeFormat.N_3: + volume_xyz_n3 = volume_xyz + else: + raise AssertionError('volume_format must be specified') + + return (H, data.laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3) + + +def convert_reconstruction_from_N_3(data: NLOSCaptureData, + reconstructed_volume_n3: np.ndarray, # FIXME type + volume_xyz: NLOSCaptureData.VolumeXYZType, + volume_format: VolumeFormat, + camera_system: CameraSystem): + if camera_system.is_transient(): + time_dim = (data.H.shape[data.H_format.time_dim()],) + else: + time_dim = () + + assert volume_format.xyz_dim_is_last(), 'Unexpected volume_format' + return reconstructed_volume_n3.reshape( + time_dim + volume_xyz.shape[:-1])