Skip to content

Commit

Permalink
fix: correct display time in tal plot txy_interactive (from #1) + gen…
Browse files Browse the repository at this point in the history
…eral cleanup
  • Loading branch information
diegoroyo committed Jan 31, 2023
1 parent db731ae commit 334c973
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 128 deletions.
59 changes: 20 additions & 39 deletions tal/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from tal.io.capture_data import NLOSCaptureData
from typing import Union, List
from enum import Enum


class ByAxis(Enum):
UNKNOWN = 0
T = 1
X = 2
Y = 3
Z = 4


_Data = Union[NLOSCaptureData, NLOSCaptureData.HType]
Expand All @@ -22,24 +13,6 @@ def xy_grid(data: Union[NLOSCaptureData, NLOSCaptureData.HType],
return plot_xy_grid(data, size_x, size_y, t_start, t_end, t_step)


def xy_interactive(data: _Data,
cmap: str = 'hot'):
from tal.plot.xy import ByAxis, plot_txy_interactive
return plot_txy_interactive(data, cmap, ByAxis.T)


def tx_interactive(data: _Data,
cmap: str = 'hot'):
from tal.plot.xy import ByAxis, plot_txy_interactive
return plot_txy_interactive(data, cmap, ByAxis.Y)


def ty_interactive(data: _Data,
cmap: str = 'hot'):
from tal.plot.xy import ByAxis, plot_txy_interactive
return plot_txy_interactive(data, cmap, ByAxis.X)


def t_comparison(data_list: _DataList,
x: int = None, y: int = None,
t_start: int = None, t_end: int = None,
Expand All @@ -49,21 +22,29 @@ def t_comparison(data_list: _DataList,
return plot_t_comparison(data_list, x, y, t_start, t_end, a_min, a_max, labels)


def txy_interactive(data: _Data,
cmap: str = 'hot', by: ByAxis = ByAxis.T):
from tal.plot.xy import plot_txy_interactive
return plot_txy_interactive(data, cmap, by)


def zxy_interactive(data: _Data,
cmap: str = 'hot', by: ByAxis = ByAxis.Z):
from tal.plot.xy import plot_zxy_interactive
return plot_zxy_interactive(data, cmap, by)


def volume(data: _Data, title: str = '', slider_title: str = 'Time',
slider_step: float = 0.1, cmap: str = 'hot',
opacity='sigmoid', backgroundcolor=None):
from tal.plot.plotter3d import plot3d
return plot3d(data, title, slider_title, slider_step, cmap, opacity,
backgroundcolor)


def xy_interactive(data: _Data, cmap: str = 'hot'):
from tal.plot.xy import plot_txy_interactive
return plot_txy_interactive(data, cmap, 't')


def tx_interactive(data: _Data, cmap: str = 'hot'):
from tal.plot.xy import plot_txy_interactive
return plot_txy_interactive(data, cmap, 'y')


def ty_interactive(data: _Data, cmap: str = 'hot'):
from tal.plot.xy import plot_txy_interactive
return plot_txy_interactive(data, cmap, 'x')


def txy_interactive(data: _Data, cmap: str = 'hot', slice_axis: str = 't'):
from tal.plot.xy import plot_txy_interactive
return plot_txy_interactive(data, cmap, slice_axis)
5 changes: 0 additions & 5 deletions tal/plot/plotter3d.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from nptyping import NDArray
from tal.io.capture_data import NLOSCaptureData
from tal.enums import HFormat
from tal.plot import ByAxis
from tal.util import SPEED_OF_LIGHT
from typing import Union
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.widgets import Slider, Button
import numpy as np
Expand Down
95 changes: 23 additions & 72 deletions tal/plot/xy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from nptyping import NDArray
from tal.io.capture_data import NLOSCaptureData
from tal.enums import HFormat
from tal.plot import ByAxis
from tal.util import SPEED_OF_LIGHT
from typing import Union
from nptyping import NDArray, Shape
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.widgets import Slider, Button
Expand Down Expand Up @@ -50,11 +49,10 @@ def t_to_time(
plt.show()


def plot_3d_interactive_axis(xyz: np.ndarray, focus_slider: np.ndarray,
def plot_3d_interactive_axis(xyz: np.ndarray, focus_slider: NDArray[Shape['T'], NLOSCaptureData.Float],
axis: int, title: str, slider_title: str,
slider_unit: str, cmap: str = 'hot',
xlabel: str = '', ylabel: str = ''):
"""More general plotting"""
assert xyz.ndim == 3, 'Unknown H_Format to plot'
assert axis < 3, f'Data only has 3 dims (given axis={axis})'
assert xyz.shape[axis] == len(focus_slider), \
Expand Down Expand Up @@ -83,14 +81,14 @@ def set_ax_common():
ax.set_xticks([])
ax.set_yticks([])

def scale_linear(event):
def scale_linear(_):
nonlocal img, scale
scale = 'linear'
ax.cla()
img = ax.imshow(xyz_p[current_idx], cmap=cmap, vmin=v_min, vmax=v_max)
set_ax_common()

def scale_log(event):
def scale_log(_):
nonlocal img, scale
scale = 'log'
ax.cla()
Expand All @@ -113,7 +111,7 @@ def scale_log(event):
orientation='horizontal')
ax_text = plt.axes([0.25, 0.05, 0.5, 0.03])
ax_text.axis('off')
text = ax_text.text(0, 0, f'{focus_slider[0]:.3f} {slider_unit}')
text = ax_text.text(0, 0, f'{focus_slider[0]:.2f} {slider_unit}')

ax_button_linear = plt.axes([0.05, 0.4, 0.1, 0.05])
button_linear = Button(ax_button_linear, 'Linear')
Expand All @@ -122,15 +120,15 @@ def scale_log(event):
button_log = Button(ax_button_log, 'Log')
button_log.on_clicked(scale_log)

def cmap_hot(event):
def cmap_hot(_):
nonlocal cmap
cmap = 'hot'
if scale == 'linear':
scale_linear(None)
else:
scale_log(None)

def cmap_nipyspectral(event):
def cmap_nipyspectral(_):
nonlocal cmap
cmap = 'nipy_spectral'
if scale == 'linear':
Expand All @@ -151,12 +149,12 @@ def update(i):
nonlocal current_idx
current_idx = int(i)
img.set_array(xyz_p[current_idx] + (1 if scale == 'log' else 0))
text.set_text(f'{focus_slider[current_idx]:.3f} {slider_unit}')
text.set_text(f'{focus_slider[current_idx]:.2f} {slider_unit}')

def update_prev(event):
def update_prev(_):
slider.set_val(0 if current_idx == 0 else current_idx - 1)

def update_next(event):
def update_next(_):
slider.set_val(max_idx if current_idx == max_idx else current_idx + 1)

ax_prev = plt.axes([0.15, 0.1, 0.03, 0.03])
Expand All @@ -171,106 +169,59 @@ def update_next(event):


def plot_txy_interactive(data: Union[NLOSCaptureData, NLOSCaptureData.HType],
cmap: str = 'hot', by: ByAxis = ByAxis.T):
cmap: str = 'hot', slice_axis: str = 't'):
if isinstance(data, NLOSCaptureData):
assert data.H_format == HFormat.T_Sx_Sy, \
'plot_txy_interactive does not support this data format'
txy = data.H
delta_t = data.delta_t
t_start = data.t_start
else:
assert data.ndim == 3, \
'plot_txy_interactive does not support this data format'
txy = data
delta_t = None
t_start = None

title = 'Impulse response '
# Plot parameters
if by == ByAxis.T:
title = 'H(t, x, y) '
if slice_axis == 't':
axis = 0
n_it = txy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by time'
title += '(T axis slices)'
slider_title = 'Bins'
slider_unit = 'index'
if delta_t is not None:
it_v *= delta_t
if delta_t is not None and t_start is not None:
it_v = (t_start + it_v * delta_t) * 1e12 / SPEED_OF_LIGHT
slider_unit = 'ps'
xlabel = 'x'
ylabel = 'y'
elif by == ByAxis.Y:
elif slice_axis == 'x':
axis = 1
n_it = txy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by y'
title += '(Y axis slices)'
slider_title = 'Planes'
slider_unit = 'index'
txy = txy.swapaxes(0, 2)
xlabel = 't'
ylabel = 'x'
elif by == ByAxis.X:
elif slice_axis == 'y':
axis = 2
n_it = txy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by x'
title += '(X axis slices)'
slider_title = 'Planes'
slider_unit = 'index'
txy = txy.swapaxes(0, 1)
xlabel = 't'
ylabel = 'y'
else:
raise 'plot_txy_interactive does not support this axis'
raise AssertionError('slice_axis must be one of ("t", "x", "y")')

# Plot the data
return plot_3d_interactive_axis(txy, it_v, axis=axis,
title=title,
slider_title=slider_title,
slider_unit=slider_unit,
cmap=cmap,
xlabel=xlabel, ylabel=ylabel)


def plot_zxy_interactive(data: NDArray, cmap: str = 'hot', by: ByAxis = ByAxis.Z):

assert data.ndim == 3, \
'plot_zxy_interactive does not support this data format'
zxy = data

title = 'Amplitude by '
# Plot parameters
if by == ByAxis.Z:
axis = 0
n_it = zxy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by z'
slider_title = 'Planes'
slider_unit = 'index'
xlabel = 'x'
ylabel = 'y'
elif by == ByAxis.Y:
axis = 1
n_it = zxy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by y'
slider_title = 'Planes'
slider_unit = 'index'
xlabel = 'x'
ylabel = 'z'
elif by == ByAxis.X:
axis = 2
n_it = zxy.shape[axis]
it_v = np.arange(n_it, dtype=np.float32)
title += 'by x'
slider_title = 'Planes'
slider_unit = 'index'
xlabel = 'y'
ylabel = 'z'
else:
raise 'plot_zxy_interactive does not support this axis'

# Plot the data
return plot_3d_interactive_axis(zxy, it_v, axis=axis,
title=title,
slider_title=slider_title,
slider_unit=slider_unit,
cmap=cmap,
xlabel=xlabel, ylabel=ylabel)
2 changes: 1 addition & 1 deletion tal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from textwrap import dedent, indent

SPEED_OF_LIGHT = 300_000_000
SPEED_OF_LIGHT = 299_792_458


def local_file_path(path):
Expand Down
14 changes: 3 additions & 11 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import tal
from tal.plot import ByAxis
print('Start read')
data = tal.io.read_capture('test/data/ZNLOS_Z_10_Single.hdf5')
print('Finish read')


tal.plot.xy_grid(data, size_x=8, size_y=4)

tal.plot.txy_interactive(data, by = ByAxis.T)
tal.plot.txy_interactive(data, by = ByAxis.X)
tal.plot.txy_interactive(data, by = ByAxis.Y)

tal.plot.zxy_interactive(data.H, by = ByAxis.Z)
tal.plot.zxy_interactive(data.H, by = ByAxis.X)
tal.plot.zxy_interactive(data.H, by = ByAxis.Y)


tal.plot.txy_interactive(data, slice_axis='t')
tal.plot.txy_interactive(data, slice_axis='x')
tal.plot.txy_interactive(data, slice_axis='y')

0 comments on commit 334c973

Please sign in to comment.