Skip to content

Commit

Permalink
feat: faster and better filters, improve bp module, new pf_dev module
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoroyo committed Nov 30, 2022
1 parent 723fd45 commit bcffe80
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 92 deletions.
2 changes: 1 addition & 1 deletion tal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from tal.plot import *
from tal.reconstruct import *

__version__ = '0.6.4'
__version__ = '0.7.0'
14 changes: 14 additions & 0 deletions tal/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ class HFormat(Enum):
T_Si = 3 # confocal or not
T_Si_Li = 4

def time_dim(self) -> int:
assert self in [HFormat.T_Sx_Sy,
HFormat.T_Lx_Ly_Sx_Sy,
HFormat.T_Si,
HFormat.T_Si_Li], \
f'Unexpected HFormat {self}'
return 0


class GridFormat(Enum):
UNKNOWN = 0
Expand All @@ -29,6 +37,12 @@ class VolumeFormat(Enum):
X_Y_Z_3 = 2
X_Y_3 = 3

def xyz_dim_is_last(self) -> bool:
assert self in [VolumeFormat.N_3,
VolumeFormat.X_Y_Z_3,
VolumeFormat.X_Y_3]
return True


class CameraSystem(Enum):
STEADY = 0 # focused light
Expand Down
11 changes: 9 additions & 2 deletions tal/reconstruct/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tal.reconstruct.pf import *
from tal.reconstruct.bp import *
from tal.reconstruct.pf import *
from tal.reconstruct.pf_dev import *
from tal.io.capture_data import NLOSCaptureData
from tal.enums import HFormat
from typing import Union
Expand All @@ -10,12 +11,17 @@
def filter_H(data: _Data,
filter_name: str,
data_format: HFormat = HFormat.UNKNOWN,
border: str = 'zero',
plot_filter: bool = False,
return_filter: bool = False,
**kwargs) -> NLOSCaptureData.HType:
"""
Filter a captured data signal (H) using specified filter_name
* data_format should be non-null if not specified through data
* border sets the behaviour for the edges of the convolution
- 'erase': the filtered signal has the edges set to zero
- 'zero': before filtering, pad the signal with zeros
- 'edge': before filtering, pad the signal with edge values
* If plot_filter=True, shows a plot of the resulting filter
* If return_filter=True, returns the filter (K)
else, returns the filtered signal (H * K)
Expand All @@ -24,10 +30,11 @@ def filter_H(data: _Data,
* wl_mean: Mean of the Gaussian in the frequency domain
* wl_sigma: STD of the Gaussian in the frequency domain
* delta_t: Time interval, must be non-null if not specified through data
FIXME(diego): is this sentence true? vvv probably not
e.g. mean = 3, sigma = 0.5 will filter frequencies of ~2-4m
"""
from tal.reconstruct.filters import filter_H_impl
return filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **kwargs)
return filter_H_impl(data, filter_name, data_format, border, plot_filter, return_filter, **kwargs)


def get_volume_min_max_resolution(minimal_pos, maximal_pos, resolution):
Expand Down
60 changes: 8 additions & 52 deletions tal/reconstruct/bp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,28 @@
from tal.io.capture_data import NLOSCaptureData
from tal.enums import HFormat, GridFormat, VolumeFormat, CameraSystem
from typing import Union

from tal.enums import VolumeFormat, CameraSystem
from tal.reconstruct.utils import convert_to_N_3, convert_reconstruction_from_N_3
import numpy as np
_Data = Union[NLOSCaptureData, NLOSCaptureData.HType]


def solve(data: NLOSCaptureData,
volume_xyz: NLOSCaptureData.VolumeXYZType = None,
volume_format: VolumeFormat = None,
camera_system: CameraSystem = CameraSystem.STEADY) -> np.array:
camera_system: CameraSystem = CameraSystem.STEADY) -> np.array: # FIXME type
"""
NOTE: Does _NOT_ attempt to compensate effects caused by attenuation:
- cos decay i.e. {sensor|laser}_grid_normals are ignored
- 1/d^2 decay
TODO(diego): docs
"""
if data.H_format == HFormat.T_Si:
H = data.H
elif data.H_format == HFormat.T_Sx_Sy:
nt, nsx, nsy = data.H.shape
H = data.H.reshape(nt, nsx * nsy)
else:
raise AssertionError(f'H_format {data.H_format} not implemented')

if data.sensor_grid_format == GridFormat.N_3:
sensor_grid_xyz = data.sensor_grid_xyz
elif data.sensor_grid_format == GridFormat.X_Y_3:
try:
assert nsx == data.sensor_grid_xyz.shape[0] and nsy == data.sensor_grid_xyz.shape[1], \
'sensor_grid_xyz.shape does not match with H.shape'
except NameError:
# nsx, nsy not defined, OK
nsx, nsy, _ = data.sensor_grid_xyz.shape
pass
sensor_grid_xyz = data.sensor_grid_xyz.reshape(nsx * nsy, 3)
else:
raise AssertionError(
f'sensor_grid_format {data.sensor_grid_format} not implemented')

if volume_format == VolumeFormat.X_Y_Z_3:
nvx, nvy, nvz, _ = volume_xyz.shape
volume_xyz_n3 = volume_xyz.reshape((-1, 3))
elif volume_format == VolumeFormat.N_3:
volume_xyz_n3 = volume_xyz
else:
raise AssertionError('volume_format must be specified')
H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3 = \
convert_to_N_3(data, volume_xyz, volume_format)

from tal.reconstruct.bp.backprojection import backproject
reconstructed_volume_n3 = backproject(
H, data.laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3,
H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3,
camera_system, data.t_accounts_first_and_last_bounces,
data.t_start, data.delta_t,
data.laser_xyz, data.sensor_xyz,
progress=True)

if camera_system.is_transient():
assert data.H_format == HFormat.T_Si or data.H_format == HFormat.T_Sx_Sy or data.H_format == HFormat.T_Lx_Ly_Sx_Sy, \
'Cannot find time dimension given H_format'
time_dim = (data.H.shape[0],)
else:
time_dim = ()

if volume_format == VolumeFormat.X_Y_Z_3:
reconstructed_volume = reconstructed_volume_n3.reshape(
time_dim + (nvx, nvy, nvz))
elif volume_format == VolumeFormat.N_3:
reconstructed_volume = reconstructed_volume_n3
else:
raise AssertionError('volume_format must be specified')

return reconstructed_volume
return convert_reconstruction_from_N_3(data, reconstructed_volume_n3, volume_xyz, volume_format, camera_system)
24 changes: 13 additions & 11 deletions tal/reconstruct/bp/backprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
from tqdm import tqdm


def backproject(H, laser_grid_xyz, sensor_grid_xyz, volume_xyz,
def backproject(H_0, laser_grid_xyz, sensor_grid_xyz, volume_xyz,
camera_system, t_accounts_first_and_last_bounces,
t_start, delta_t,
laser_xyz=None, sensor_xyz=None, progress=False):
# TODO(diego): extend for multiple laser points
assert H.ndim == 2 and laser_grid_xyz.size == 3, \
assert H_0.ndim == 2 and laser_grid_xyz.size == 3, \
'backproject only supports one laser point'

nt, ns = H.shape
nt, ns = H_0.shape
assert sensor_grid_xyz.shape[0] == ns, 'H does not match with sensor_grid_xyz'
assert (not t_accounts_first_and_last_bounces or (laser_xyz is not None and sensor_xyz is not None)), \
't_accounts_first_and_last_bounces requires laser_xyz and sensor_xyz'
ns, _ = sensor_grid_xyz.shape
nv, _ = volume_xyz.shape

if camera_system.is_transient():
f = np.zeros((nt, nv), dtype=H.dtype)
H_1 = np.zeros((nt, nv), dtype=H_0.dtype)
else:
f = np.zeros(nv, dtype=H.dtype)
H_1 = np.zeros(nv, dtype=H_0.dtype)

# d_1: laser origin to laser illuminated point
# d_2: laser illuminated point to x_v
Expand Down Expand Up @@ -50,12 +50,14 @@ def backproject(H, laser_grid_xyz, sensor_grid_xyz, volume_xyz,
d_3 = np.linalg.norm(x_v - x_s)
t_i = int((d_1 + d_2 + d_3 + d_4 - t_start) / delta_t)
if camera_system.is_transient():
p = np.copy(H[:, s_i])
p[t_i+nt-1:] = 0.0
p[:t_i] = 0.0
f[:, i_v] += np.roll(p, -t_i)
p = np.copy(H_0[:, s_i])
if t_i > 0:
p[:t_i] = 0.0
elif t_i < 0:
p[t_i+nt-1:] = 0.0
H_1[:, i_v] += np.roll(p, -t_i)
else:
if t_i >= 0 and t_i < nt:
f[i_v] += H[t_i, s_i]
H_1[i_v] += H_0[t_i, s_i]

return f
return H_1
77 changes: 51 additions & 26 deletions tal/reconstruct/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **kwargs):
def filter_H_impl(data, filter_name, data_format, border, plot_filter, return_filter, **kwargs):
if isinstance(data, NLOSCaptureData):
assert data.H_format != HFormat.UNKNOWN or data_format is not None, \
'H format must be specified through the NLOSCaptureData object or the data_format argument'
Expand All @@ -19,15 +19,11 @@ def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **
H = data
delta_t = None

if H_format == HFormat.T_Sx_Sy:
nt, nsx, nsy = H.shape
elif H_format == HFormat.T_Lx_Ly_Sx_Sy:
nt, nlx, nly, nsx, nsy = H.shape
K_shape = (nt, 1, 1, 1, 1)
else:
raise AssertionError('Unknown H_format')
assert border in ['erase', 'zero', 'edge'], \
'border must be one of "erase", "zero" or "edge"'

padding = 0
nt = H.shape[H_format.time_dim()]
nt_pad = nt

if filter_name == 'pf':
wl_mean = kwargs.get('wl_mean', None)
Expand All @@ -39,48 +35,77 @@ def filter_H_impl(data, filter_name, data_format, plot_filter, return_filter, **
assert delta_t is not None, \
'For the "pf" filter, delta_t must be specified through an NLOSCaptureData object or the delta_t argument'

t_max = delta_t * (nt * 2 - 1)
t = np.linspace(start=0, stop=t_max, num=nt * 2)
t_6sigma = int(np.round(6 * wl_sigma / delta_t)) # used for padding
if t_6sigma % 2 == 1:
t_6sigma += 1 # its easier if padding is even
if return_filter:
nt_pad = t_6sigma
else:
nt_pad = nt + 2 * (t_6sigma - 1)
t_max = delta_t * (nt_pad - 1)
t = np.linspace(start=0, stop=t_max, num=nt_pad)

# vvv Gaussian envelope (x = t - t_max/2, mu = 0, sigma = wl_sigma)
K = (1 / (wl_sigma * np.sqrt(2 * np.pi))) * \
np.exp(-((t - t_max / 2) / wl_sigma) ** 2 / 2) * \
gaussian_envelope = np.exp(-((t - t_max / 2) / wl_sigma) ** 2 / 2)
K = gaussian_envelope / np.sum(gaussian_envelope) * \
np.exp(2j * np.pi * t / wl_mean)
# ^^^ Pulse inside the Gaussian envelope (complex exponential)

# center at zero (not in freq. domain but fftshift works)
K = np.fft.fftshift(K)
K = np.fft.ifftshift(K)
else:
raise AssertionError(
'Unknown filter_name. Check the documentation for available filters')

if plot_filter:
import matplotlib.pyplot as plt
K_show = np.fft.ifftshift(K)
plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.real(K_show), c='b')
plt.plot(t[:len(K_show)] - t[len(K_show) // 2],
np.imag(K_show), c='b', linestyle='--')
K_show = np.fft.fftshift(K)
cut = (nt_pad - t_6sigma) // 2
if cut > 0:
K_show = K_show[cut:-cut]
plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.real(K_show),
c='b')
plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.imag(K_show),
c='b', linestyle='--')
plt.plot(t[:len(K_show)] - t[len(K_show) // 2], np.abs(K_show),
c='r')
plt.show()
if return_filter:
return K

# pad with identical, inverted signal
if H_format == HFormat.T_Sx_Sy or H_format == HFormat.T_Lx_Ly_Sx_Sy:
H = np.resize(H, (nt * 2, *H.shape[1:]))
H[nt:, ...] = H[:nt, ...][::-1, ...]
K_shape = (nt * 2,) + (1,) * (H.ndim - 1)
padding = (nt_pad - nt)
assert padding % 2 == 0
padding //= 2

# pad with edge values
if H_format.time_dim() == 0:
mode = None
if border == 'edge':
mode = 'edge'
if border == 'zero' or border == 'erase':
mode = 'constant'
if mode:
H_pad = np.pad(H,
((padding, padding),) + # first dim (time)
((0, 0),) * (H.ndim - 1), # other dims
mode=mode)
K_shape = (nt_pad,) + (1,) * (H.ndim - 1)
else:
raise AssertionError('Unknown H_format')

H_fft = np.fft.fft(H, axis=0)
H_fft = np.fft.fft(H_pad, axis=0)
K_fft = np.fft.fft(K)
H_fft *= K_fft.reshape(K_shape)
del K_fft
HoK = np.fft.ifft(H_fft, axis=0)
del H_fft

# undo padding
if H_format == HFormat.T_Sx_Sy or H_format == HFormat.T_Lx_Ly_Sx_Sy:
return HoK[:nt, ...]
if H_format.time_dim() == 0:
HoK = HoK[padding:-padding, ...]
if border == 'erase':
HoK[:padding//2] = 0
HoK[-padding//2:] = 0
return HoK
else:
raise AssertionError('Unknown H_format')
32 changes: 32 additions & 0 deletions tal/reconstruct/pf_dev/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from tal.io.capture_data import NLOSCaptureData
from tal.enums import VolumeFormat, CameraSystem
from tal.reconstruct.utils import convert_to_N_3, convert_reconstruction_from_N_3
import numpy as np


def solve(data: NLOSCaptureData,
wl_mean: float,
wl_sigma: float,
edges: str = 'zero',
volume_xyz: NLOSCaptureData.VolumeXYZType = None,
volume_format: VolumeFormat = None,
camera_system: CameraSystem = CameraSystem.STEADY) -> np.array: # FIXME type
"""
NOTE: Does _NOT_ attempt to compensate effects caused by attenuation:
- cos decay i.e. {sensor|laser}_grid_normals are ignored
- 1/d^2 decay
TODO(diego): docs
"""
H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3 = \
convert_to_N_3(data, volume_xyz, volume_format)

from tal.reconstruct.pf_dev.phasor_fields import backproject_pf_multi_frequency
reconstructed_volume_n3 = backproject_pf_multi_frequency(
H, laser_grid_xyz, sensor_grid_xyz, volume_xyz_n3,
camera_system, data.t_accounts_first_and_last_bounces,
data.t_start, data.delta_t,
wl_mean, wl_sigma, edges,
data.laser_xyz, data.sensor_xyz,
progress=True)

return convert_reconstruction_from_N_3(data, reconstructed_volume_n3, volume_xyz, volume_format, camera_system)
Loading

0 comments on commit bcffe80

Please sign in to comment.