Skip to content

Commit

Permalink
Refactor Volume class to support multiple data types (#282)
Browse files Browse the repository at this point in the history
* Refactor volume to allow multiple data types

* Update code to pass tests

* Update src/gemdat/path.py

Co-authored-by: SCiarella <58949181+SCiarella@users.noreply.github.com>

* Update src/gemdat/volume.py

Co-authored-by: SCiarella <58949181+SCiarella@users.noreply.github.com>

* Simplify Volume dataclass

* Remove Volume.data

* Pass tests

* Calculate resolution from dims

* Use more accurate voxel to frac conversion

* Remove redundant frac/cart path methods, use volume methods instead

* Merge voxel_size / resolution

* Save multiple volumes to volumetric data / vesta output

* Remove cost, duplicate of total_energy

---------

Co-authored-by: SCiarella <58949181+SCiarella@users.noreply.github.com>
  • Loading branch information
stefsmeets and SCiarella authored Mar 25, 2024
1 parent 63919cb commit f73fc3e
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 211 deletions.
109 changes: 30 additions & 79 deletions src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Pathway:
List of the energy along the path
"""

sites: list[tuple[int, int, int]] | None = None
energy: list[float] | None = None
sites: list[tuple[int, int, int]]
energy: list[float]

def __repr__(self):
s = (
Expand All @@ -44,52 +44,7 @@ def total_energy(self):
"""Return total energy for path."""
return sum(self.energy)

def cartesian_path(self,
volume: Volume) -> list[tuple[float, float, float]]:
"""Convert voxel coordinates to cartesian coordinates.
Parameters
----------
volume : Volume
Volume object containing the grid information
Returns
-------
cart_sites: list[tuple]
List of cartesian coordinates of the sites defining the path
"""
cart_sites = []
if self.sites is None:
raise ValueError('Voxel coordinates of the path are required.')
for site in self.fractional_path(volume=volume):
cartesian_coords = volume.lattice.get_cartesian_coords(site)
cart_sites.append(tuple(cartesian_coords))
return cart_sites

def fractional_path(self,
volume: Volume) -> list[tuple[float, float, float]]:
"""Convert voxel coordinates to fractional coordinates.
Parameters
----------
volume : Volume
Volume object containing the grid information
Returns
-------
frac_sites: list[tuple]
List of fractional coordinates of the sites defining the path
"""
if self.sites is None:
raise ValueError('Voxel coordinates of the path are required.')
frac_sites = []
for site in self.sites:
fractional_coords = site / np.asarray(
[x // volume.resolution for x in volume.lattice.lengths])
frac_sites.append(tuple(fractional_coords))
return frac_sites

def wrap(self, F: np.ndarray):
def wrap(self, dims: tuple[int, int, int]):
"""Wrap path in periodic boundary conditions in-place.
Parameters
Expand All @@ -100,7 +55,7 @@ def wrap(self, F: np.ndarray):
if self.sites is None:
raise ValueError('Voxel coordinates of the path are required.')

X, Y, Z = F.shape
X, Y, Z = dims
self.sites = [(x % X, y % Y, z % Z) for x, y, z in self.sites]

def path_over_structure(
Expand All @@ -124,7 +79,7 @@ def path_over_structure(
nearest_structure_coord: list[np.ndarray]
List of cartesian coordinates of the closest site of the reference structure
"""
frac_sites = self.fractional_path(volume)
frac_sites = volume.voxel_to_frac_coords(np.array(self.sites))
nearest_structure_tree, nearest_structure_map = nearest_structure_reference(
structure)

Expand All @@ -144,13 +99,6 @@ def path_over_structure(

return nearest_structure_label, nearest_structure_coord

@property
def cost(self) -> float:
"""Calculate the path cost."""
if self.energy is None:
raise ValueError('Energy of the path is required.')
return np.sum(self.energy)

@property
def start_site(self) -> tuple[int, int, int]:
"""Return first site."""
Expand All @@ -166,14 +114,14 @@ def stop_site(self) -> tuple[int, int, int]:
return self.sites[-1]


def free_energy_graph(F: np.ndarray,
def free_energy_graph(F: np.ndarray | Volume,
max_energy_threshold: float = 1e20,
diagonal: bool = True) -> nx.Graph:
"""Compute the graph of the free energy for networkx functions.
Parameters
----------
F : np.ndarray
F : np.ndarray | Volume
Free energy on the 3d grid
max_energy_threshold : float, optional
Maximum energy threshold for the path to be considered valid
Expand All @@ -199,14 +147,18 @@ def free_energy_graph(F: np.ndarray,
movements = np.vstack((movements, diagonal_movements))

G = nx.Graph()
for index, Fi in np.ndenumerate(F):

data = F.data if isinstance(F, Volume) else F

for index, Fi in np.ndenumerate(data):
if 0 <= Fi < max_energy_threshold:
G.add_node(index, energy=Fi)

for node in G.nodes:
for move in movements:
neighbor = tuple((node + move) % F.shape)
neighbor = tuple((node + move) % data.shape)
if neighbor in G.nodes:
weight = 0.5 * (F[node] + F[neighbor])
weight = 0.5 * (data[node] + data[neighbor])
exp_n_energy = np.exp(weight)
if exp_n_energy < max_energy_threshold:
weight_exp = exp_n_energy
Expand Down Expand Up @@ -450,19 +402,19 @@ def _optimal_path_minmax_energy(
return optimal_path


def find_best_perc_path(F: np.ndarray,
volume: Volume,
def find_best_perc_path(F: Volume,
peaks: np.ndarray,
percolate_x: bool = True,
percolate_y: bool = False,
percolate_z: bool = False) -> Pathway:
percolate_z: bool = False) -> Pathway | None:
"""Calculate the best percolating path.
Parameters
----------
F : np.ndarray
F : Volume
Energy grid that will be used to calculate the shortest path
volume : Volume
Volume object containing the grid information
peaks : np.ndarray
List of the peaks that correspond to high probability regions
percolate_x : bool
If True, consider paths that percolate along the x dimension
percolate_y : bool
Expand All @@ -475,19 +427,18 @@ def find_best_perc_path(F: np.ndarray,
best_percolating_path: Pathway
Optimal path that percolates the graph in the specified directions
"""
xyz_real = F.shape
xyz_real = F.dims

# Find percolation using virtual images along the required dimensions
if not any([percolate_x, percolate_y, percolate_z]):
print('Warning: percolation is not defined')
return Pathway()
raise ValueError('percolation is not defined')

# Tile the grind in the percolation directions
F_periodic = np.tile(F,
(1 + percolate_x, 1 + percolate_y, 1 + percolate_z))
F_data_periodic = np.tile(
F.data, (1 + percolate_x, 1 + percolate_y, 1 + percolate_z))

# Get F on a graph
F_graph = free_energy_graph(F_periodic,
F_graph = free_energy_graph(F_data_periodic,
max_energy_threshold=1e7,
diagonal=True)

Expand All @@ -498,9 +449,8 @@ def find_best_perc_path(F: np.ndarray,

# Find the lowest cost path that percolates along the x dimension
best_cost = float('inf')
best_path = Pathway()
best_path = None

peaks = volume.find_peaks()
for start_point in peaks:

# Get the stop point which is a periodic image of the peak
Expand All @@ -516,13 +466,14 @@ def find_best_perc_path(F: np.ndarray,
except nx.NetworkXNoPath:
continue

cost = path.cost
cost = path.total_energy

if cost < best_cost:
best_cost = cost
best_path = path

# Before returning, wrap the path in the original volume
best_path.wrap(F)
if best_path:
# Before returning, wrap the path in the original volume
best_path.wrap(F.dims)

return best_path
6 changes: 3 additions & 3 deletions src/gemdat/plots/plotly/_plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def plot_volume(

for i, isoval in enumerate(isovals):
isoval = isoval * np.max(data)
verts, faces, _, _ = measure.marching_cubes(data, level=isoval)
verts, faces, *_ = measure.marching_cubes(data, level=isoval)

# Transform verts to cartesian system
verts = (verts + 0.5) / np.array(data.shape)
Expand Down Expand Up @@ -201,7 +201,7 @@ def plot_paths(
else:
optimal_path = paths

x_path, y_path, z_path = np.asarray(optimal_path.cartesian_path(volume)).T
x_path, y_path, z_path = volume.voxel_to_cart_coords(optimal_path.sites).T

fig.add_trace(
go.Scatter3d(
Expand All @@ -222,7 +222,7 @@ def plot_paths(
# If available, plot the other pathways
if isinstance(paths, list):
for idx, path in enumerate(paths[1:]):
x_path, y_path, z_path = np.asarray(path.cartesian_path(volume)).T
x_path, y_path, z_path = volume.voxel_to_cart_coords(path.sites).T

fig.add_trace(
go.Scatter3d(
Expand Down
Loading

0 comments on commit f73fc3e

Please sign in to comment.