Skip to content

Commit

Permalink
Optimizations in plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
helvecioneto committed Feb 11, 2025
1 parent 3970085 commit 0994f33
Show file tree
Hide file tree
Showing 4 changed files with 105,134 additions and 35,914 deletions.
140,912 changes: 105,051 additions & 35,861 deletions examples/testes_MRGSPL/Track_Infrared_global.ipynb

Large diffs are not rendered by default.

120 changes: 76 additions & 44 deletions pyfortracc/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.patches as patches
from matplotlib import font_manager as mfonts
import matplotlib.patheffects as patheffects
import cartopy.io.img_tiles as cimgt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap
from shapely.wkt import loads
Expand All @@ -18,9 +17,6 @@
from pyfortracc.default_parameters import default_parameters


import cartopy.io.img_tiles as cimgt


def plot(name_list,
read_function=None,
timestamp='1970-01-01 00:00:00',
Expand All @@ -40,10 +36,10 @@ def plot(name_list,
scalebar_units='km',
min_val=None,
max_val=None,
nan_operation=np.less_equal,
nan_operation=None,
nan_value=0.01,
num_colors = 20,
title_fontsize=14,
title_fontsize=12,
grid_deg=None,
title='Track Plot',
time_zone='UTC',
Expand Down Expand Up @@ -76,24 +72,21 @@ def plot(name_list,
info_cols=['uid'],
save=False,
save_path='output/img/',
save_name=None,
read_data=True):
save_name=None):
"""
This function is designed to visualize tracking data on a map or a simple 2D plot.
The function reads in tracking data, filters it based on various criteria, and plots it using Matplotlib,
with optional customizations such as colorbars, boundaries, centroids, trajectories, and additional information annotations.
"""

if read_function is None and read_data:
print('Please set a read function to open the files!')
return None
# Plot by track
if 'output_path' not in name_list:
print('Please set the output name for the files!')
return None
elif name_list['output_path'] is None:
print('Please set the output name for the files!')
return None
# Get the tracking table
name_list = default_parameters(name_list, read_function)
track_files = name_list['output_path'] + 'track/trackingtable/'
# Check if trackingtable is a directory with parquet files
Expand All @@ -114,7 +107,7 @@ def plot(name_list,
tck_table = tck_table.loc[tck_table['uid'].isin(uid_list)]
if len(threshold_list) > 0:
tck_table = tck_table.loc[tck_table['threshold'].isin(threshold_list)]
#Check if tck_table is empty
#Check if tck_table is empty and plot empty plot
if len(tck_table) == 0:
fig = plt.figure(figsize=figsize)
# Add title to the figure
Expand All @@ -135,22 +128,59 @@ def plot(name_list,
tck_table['geometry'] = tck_table['geometry'].apply(loads)
tck_table['trajectory'] = tck_table['trajectory'].apply(loads)
tck_table = tck_table.set_geometry('geometry')
if read_data:
data = read_function(tck_table['file'].unique()[0])
data = np.where(nan_operation(data, nan_value), np.nan, data)
# Fit min and max values
if min_val is not None:
data = np.where(data <= min_val, min_val, data)

# Check nan_operation
if nan_operation is None:
# Get from name_list
nan_operation = name_list['operator']
# Reverse nan operator
if nan_operation == '==':
nan_operation = np.not_equal
elif nan_operation == '!=':
nan_operation = np.equal
elif nan_operation == '<':
nan_operation = np.greater
elif nan_operation == '>':
nan_operation = np.less
elif nan_operation == '<=':
nan_operation = np.greater_equal
elif nan_operation == '>=':
nan_operation = np.less_equal
else:
nan_operation = np.not_equal

# Use read_function to get the data or get data from tck_table
if read_function:
# Read the data
data = read_function(tck_table['file'].unique()[0])
# Set min and max values
if min_val is None:
min_val = np.nanmin(data)
if max_val is not None:
data = np.where(data >= max_val, max_val, data)
else:
if max_val is None:
max_val = np.nanmax(data)
# Apply nan_operation
data = np.where((data < min_val) | (data > max_val), np.nan, data)
else:
# Get array x, y and values
x = tck_table['array_x'].explode().values.astype(int)
y = tck_table['array_y'].explode().values.astype(int)
values = tck_table['array_values'].explode().values
# Create a nan matrix
data = np.full((y.max() + 1, x.max() + 1), np.nan)
# Fill the matrix with the values
data[y, x] = values

# Set min and max values
if min_val is None:
min_val = np.nanmin(data)
if max_val is None:
max_val = np.nanmax(data)
# Set of plot
cmap = plt.get_cmap(cmap)
colors = [cmap(i) for i in range(cmap.N)]
cmap = LinearSegmentedColormap.from_list('mycmap', colors, N=num_colors)

# Set of plot
cmap = plt.get_cmap(cmap)
colors = [cmap(i) for i in range(cmap.N)]
cmap = LinearSegmentedColormap.from_list('mycmap', colors, N=num_colors)

# Mount main figure
fig = plt.figure(figsize=figsize)
# Check if lon_min, lon_max, lat_min, lat_max are in name_list and if is not None
Expand Down Expand Up @@ -179,7 +209,7 @@ def plot(name_list,
ax.add_feature(cfeature.BORDERS, linestyle=':')
# Set grid
gl = ax.gridlines(crs= ccrs.PlateCarree(), draw_labels=True,
linewidth=1, color='gray', alpha=0.5, linestyle='--')
linewidth=1, color='gray', alpha=0.2, linestyle='--')
gl.top_labels = False
gl.right_labels = False
if grid_deg is not None:
Expand All @@ -194,26 +224,28 @@ def plot(name_list,
ax.xaxis.set_major_formatter(lon_formatter)
ax.yaxis.set_major_formatter(lat_formatter)
ax.tick_params(axis='both', which='major', labelsize=ticks_fontsize)

# Set plot type
if read_function:
if plot_type == 'imshow':
ax.imshow(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, aspect='auto',
zorder=10)
elif plot_type == 'contourf':
ax.contourf(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, zorder=10)
elif plot_type == 'contour':
ax.contour(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, zorder=10)
elif plot_type == 'pcolormesh':
lons = np.linspace(name_list['lon_min'], name_list['lon_max'], data.shape[1])
lats = np.linspace(name_list['lat_min'], name_list['lat_max'], data.shape[0])
ax.pcolormesh(lons, lats, data, transform= ccrs.PlateCarree(), cmap=cmap, zorder=10)
if plot_type == 'imshow':
ax.imshow(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, aspect='auto', vmax=max_val, vmin=min_val)
elif plot_type == 'contourf':
ax.contourf(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, vmax=max_val, vmin=min_val)
elif plot_type == 'contour':
ax.contour(data, cmap=cmap, origin='lower', extent=orig_extent,
interpolation=interpolation, vmax=max_val, vmin=min_val)
elif plot_type == 'pcolormesh':
lons = np.linspace(name_list['lon_min'], name_list['lon_max'], data.shape[1])
lats = np.linspace(name_list['lat_min'], name_list['lat_max'], data.shape[0])
ax.pcolormesh(lons, lats, data, transform= ccrs.PlateCarree(), cmap=cmap,
vmax=max_val, vmin=min_val)
else:
if ax is None: # Comming from animation
ax = fig.add_subplot(1, 1, 1)
ax.imshow(data, cmap=cmap, origin='lower', interpolation=interpolation, zorder=10)
ax.imshow(data, cmap=cmap, origin='lower', interpolation=interpolation,
aspect='auto', vmax=max_val, vmin=min_val)

# Add title to the figure
ax.text(0.5, 1.03, title +' ' + str(timestamp) + ' ' + time_zone,
horizontalalignment='center', fontsize=title_fontsize,
Expand All @@ -233,7 +265,7 @@ def plot(name_list,
plt.close(fig)
return fig
# Set colorbar
if cbar and read_function:
if cbar:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%", pad=pad, axes_class=plt.Axes)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=min_val,
Expand Down
14 changes: 7 additions & 7 deletions pyfortracc/plot/plot_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def plot_animation(
scalebar_units='km',
min_val=None,
max_val=None,
nan_operation=np.less_equal,
nan_operation=None,
nan_value=0.01,
num_colors = 20,
title_fontsize=14,
title_fontsize=12,
grid_deg=None,
title='Track Plot',
time_zone='UTC',
Expand Down Expand Up @@ -107,8 +107,7 @@ def plot_animation(
save=False,
save_path='output/',
save_name='plot.png',
parallel=True,
read_data=True):
parallel=True):
# Set the limit of the animation size
rcParams['animation.embed_limit'] = 2**128
# Set default parameters
Expand Down Expand Up @@ -198,8 +197,7 @@ def plot_animation(
info_cols,
save,
save_path,
save_name,
read_data))
save_name))
if parallel:
n_workers = set_nworkers(name_list)
with Pool(n_workers) as pool:
Expand All @@ -217,6 +215,8 @@ def plot_animation(

# Set up the figure for the animation
fig, ax = plt.subplots(figsize=figsize)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
fig.set_size_inches(figsize, forward=True)
img = ax.imshow(np.zeros((1, 1)), cmap=cmap, aspect='auto')
ax.axis('off')

Expand All @@ -230,7 +230,7 @@ def update(i):
ani = animation.FuncAnimation(fig, update, frames=len(frames),
interval=interval,
repeat=True,
blit=False,
blit=True,
repeat_delay=repeat_delay)
ani_html = ani.to_jshtml()
plt.close(fig)
Expand Down
2 changes: 0 additions & 2 deletions pyfortracc/spatial_operations/spatial_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,6 @@ def operations(cur_frme, prv_frme, threshold, l_edge, r_edg, nm_lst):
cur_frme, prv_frme)
cur_frme.loc[mrg_spl_idx,'status'] = 'MRG/SPL'
cur_frme.loc[mrg_spl_idx,'past_idx'] = prev_past_idx
# if mrg_spl_idx.size > 0:
# print('\n',cur_frme.loc[mrg_spl_idx][['timestamp','status','size']])

# Mount the trajectory LineString, distance and direction
# Select non null prev_idx is concat into a single array
Expand Down

0 comments on commit 0994f33

Please sign in to comment.