Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Through connectivity #106

Merged
merged 8 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ matplotlib>3.6
pytest-runner==5.2
numpy>=1.24.2
tifffile==2023.2.3
myst-parser==0.18.1
myst-parser==0.18.1
scikit-image>=0.20.0
scipy>=1.3
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
with open('HISTORY.md') as history_file:
history = history_file.read()

requirements = ['Click>=7.0', 'numpy>=1.0', 'matplotlib>=3.4', 'tifffile>=2023.2.3']
requirements = [
'Click>=7.0',
'numpy>=1.0',
'matplotlib>=3.4',
'tifffile>=2023.2.3',
'scikit-image>=0.20.0',
'scipy>=1.3'
]

setup_requirements = ['pytest-runner', ]

Expand Down
137 changes: 131 additions & 6 deletions taufactor/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
import torch.nn.functional as F

from scipy.ndimage import label, generate_binary_structure

def volume_fraction(img, phases={}):
"""
Calculates volume fractions of phases in an image
Expand All @@ -14,12 +16,14 @@ def volume_fraction(img, phases={}):
img = torch.tensor(img)

if phases=={}:
phases = torch.unique(img)
vf_out = []
for p in phases:
vf_out.append((img==p).to(torch.float).mean().item())
if len(vf_out)==1:
vf_out=vf_out[0]
volume = torch.numel(img)
labels, counts = torch.unique(img, return_counts=True)
labels = labels.int()
counts = counts.float()
counts /= volume
vf_out = {}
for i, label in enumerate(labels):
vf_out[str(label.item())] = counts[i].item()
else:
vf_out={}
for p in phases:
Expand Down Expand Up @@ -142,3 +146,124 @@ def triple_phase_boundary(img):
tpb += torch.sum(tpb_map)

return tpb/total_edges

def label_periodic(field, grayscale_value, neighbour_structure, periodic, debug=False):
# Initialize phi field whith enlarged dimensions in periodic directions. Boundary values of
# array are copied into ghost cells which are necessary to impose boundary conditions.
padx = int(periodic[0])
pady = int(periodic[1])
padz = int(periodic[2])
mask = np.pad(field, ((padx, padx), (pady, pady), (padz, padz)), mode='wrap')
labeled_mask, num_labels = label(mask==grayscale_value, structure=neighbour_structure)
count = 1
for k in range(100):
# Find indices where labels are different at the boundaries and create swaplist
swap_list = np.zeros((1,2))
if periodic[0]:
# right x
indices = np.where((labeled_mask[0,:,:]!=labeled_mask[-2,:,:]) & (labeled_mask[0,:,:]!=0) & (labeled_mask[-2,:,:]!=0))
additional_swaps = np.column_stack((labeled_mask[0,:,:][indices], labeled_mask[-2,:,:][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
# left x
indices = np.where((labeled_mask[1,:,:]!=labeled_mask[-1,:,:]) & (labeled_mask[1,:,:]!=0) & (labeled_mask[-1,:,:]!=0))
additional_swaps = np.column_stack((labeled_mask[1,:,:][indices], labeled_mask[-1,:,:][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
if periodic[1]:
# top y
indices = np.where((labeled_mask[:,0,:]!=labeled_mask[:,-2,:]) & (labeled_mask[:,0,:]!=0) & (labeled_mask[:,-2,:]!=0))
additional_swaps = np.column_stack((labeled_mask[:,0,:][indices], labeled_mask[:,-2,:][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
# bottom y
indices = np.where((labeled_mask[:,1,:]!=labeled_mask[:,-1,:]) & (labeled_mask[:,1,:]!=0) & (labeled_mask[:,-1,:]!=0))
additional_swaps = np.column_stack((labeled_mask[:,1,:][indices], labeled_mask[:,-1,:][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
if periodic[2]:
# front z
indices = np.where((labeled_mask[:,:,0]!=labeled_mask[:,:,-2]) & (labeled_mask[:,:,0]!=0) & (labeled_mask[:,:,-2]!=0))
additional_swaps = np.column_stack((labeled_mask[:,:,0][indices], labeled_mask[:,:,-2][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
# back z
indices = np.where((labeled_mask[:,:,1]!=labeled_mask[:,:,-1]) & (labeled_mask[:,:,1]!=0) & (labeled_mask[:,:,-1]!=0))
additional_swaps = np.column_stack((labeled_mask[:,:,1][indices], labeled_mask[:,:,-1][indices]))
swap_list = np.row_stack((swap_list,additional_swaps))
swap_list = swap_list[1:,:]
# Sort swap list columns to ensure consistent ordering
swap_list = np.sort(swap_list, axis=1)

# Remove duplicates from swap_list
swap_list = np.unique(swap_list, axis=0)
# print(f"swap_list contains {swap_list.shape[0]} elements.")
if (swap_list.shape[0]==0):
break
for i in range(swap_list.shape[0]):
index = swap_list.shape[0] - i -1
labeled_mask[labeled_mask == swap_list[index][1]] = swap_list[index][0]
count += 1
if(debug):
print(f"Did {count} iterations for periodic labelling.")
dim = labeled_mask.shape
return labeled_mask[padx:dim[0]-padx,pady:dim[1]-pady,padz:dim[2]-padz], np.unique(labeled_mask).size-1

def find_spanning_labels(labelled_array, axis):
"""
Find labels that appear on both ends along given axis

Returns:
set: Labels that appear on both ends of the first axis.
"""
if axis == "x":
front = np.s_[0,:,:]
end = np.s_[-1,:,:]
elif axis == "y":
front = np.s_[:,0,:]
end = np.s_[:,-1,:]
elif axis == "z":
front = np.s_[:,:,0]
end = np.s_[:,:,-1]
else:
raise ValueError("Axis should be x, y or z!")

first_slice_labels = np.unique(labelled_array[front])
last_slice_labels = np.unique(labelled_array[end])
spanning_labels = set(first_slice_labels) & set(last_slice_labels)
spanning_labels.discard(0) # Remove the background label if it exists
return spanning_labels

def extract_through_feature(array, grayscale_value, axis, periodic=[False,False,False], connectivity=1, debug=False):
if array.ndim != 3:
print(f"Expected a 3D array, but got an array with {array.ndim} dimension(s).")
return None

# Compute volume fraction of given grayscale value
vol_phase = volume_fraction(array, phases={'1': grayscale_value})['1']
if vol_phase == 0:
return 0, 0

# Define a list of connectivities to loop over
connectivities_to_loop_over = [connectivity] if connectivity else range(1, 4)
through_feature = []
through_feature_fraction = np.zeros(len(connectivities_to_loop_over))

# Compute the largest interconnected features depending on given connectivity
count = 0
for conn in connectivities_to_loop_over:
# connectivity 1 = cells connected by sides (6 neighbours)
# connectivity 2 = cells connected by sides & edges (14 neighbours)
# connectivity 3 = cells connected by sides & edges & corners (26 neighbours)
neighbour_structure = generate_binary_structure(3,conn)
# Label connected components in the mask with given neighbour structure
if any(periodic):
labeled_mask, num_labels = label_periodic(array, grayscale_value, neighbour_structure, periodic, debug=debug)
else:
labeled_mask, num_labels = label(array == grayscale_value, structure=neighbour_structure)
if(debug):
print(f"Found {num_labels} labelled regions. For connectivity {conn} and grayscale {grayscale_value}.")

through_labels = find_spanning_labels(labeled_mask,axis)
spanning_network = np.isin(labeled_mask, list(through_labels))

through_feature.append(spanning_network)
through_feature_fraction[count] = volume_fraction(spanning_network, phases={'1': 1})['1']/vol_phase
count += 1

return through_feature, through_feature_fraction
5 changes: 4 additions & 1 deletion taufactor/taufactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
except ImportError:
raise ImportError("Pytorch is required to use this package. Please install pytorch and try again. More information about TauFactor's requirements can be found at https://taufactor.readthedocs.io/en/latest/")
import warnings
from .metrics import extract_through_feature

class BaseSolver:
def __init__(self, img, bc=(-0.5, 0.5), device=torch.device('cuda')):
Expand Down Expand Up @@ -89,7 +90,9 @@ def check_vertical_flux(self, conv_crit):
fl = torch.sum(vert_flux, (0, 2, 3))
err = (fl.max() - fl.min())/(fl.max())
if fl.min() == 0:
return 'zero_flux', torch.mean(fl), err
_ , frac = extract_through_feature(self.cpu_img[0], 1, 'x')
if frac == 0:
return 'zero_flux', torch.mean(fl), err
if err < conv_crit or torch.isnan(err).item():
return True, torch.mean(fl), err
return False, torch.mean(fl), err
Expand Down
8 changes: 4 additions & 4 deletions tests/test_taufactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_volume_fraction_on_uniform_block():
"""Run volume fraction on uniform block"""
l = 20
img = np.ones([l, l, l]).reshape(1, l, l, l)
vf = volume_fraction(img)
vf = volume_fraction(img)['1']

assert np.around(vf, decimals=5) == 1.0

Expand All @@ -134,7 +134,7 @@ def test_volume_fraction_on_empty_block():
"""Run volume fraction on empty block"""
l = 20
img = np.zeros([l, l, l]).reshape(1, l, l, l)
vf = volume_fraction(img)
vf = volume_fraction(img)['0']

assert np.around(vf, decimals=5) == 1.0

Expand All @@ -143,9 +143,9 @@ def test_volume_fraction_on_checkerboard():
"""Run volume fraction on checkerboard block"""
l = 20
img = generate_checkerboard(l)
vf = volume_fraction(img)
vf = volume_fraction(img, phases={'zeros': 0, 'ones': 1})

assert vf == [0.5, 0.5]
assert (vf['zeros'], vf['ones']) == (0.5, 0.5)


def test_volume_fraction_on_strip_of_ones():
Expand Down
Loading