Skip to content

Commit

Permalink
use inheritance for check_vert_flux
Browse files Browse the repository at this point in the history
  • Loading branch information
daubners committed Nov 7, 2024
1 parent 982efc5 commit 7e04615
Showing 1 changed file with 14 additions and 37 deletions.
51 changes: 14 additions & 37 deletions taufactor/taufactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def __init__(self, img, bc=(-0.5, 0.5), D_0=1, device=torch.device('cuda')):
raise ValueError(
f'Input image must only contain 0s and 1s. Your image must be segmented to use this tool. If your image has been segmented, ensure your labels are 0 for non-conductive and 1 for conductive phase. Your image has the following labels: {torch.unique(img).numpy()}. If you have more than one conductive phase, use the multi-phase solver.')

# calculate

# init conc
self.conc = self.init_conc(img)
# create nn map
Expand Down Expand Up @@ -172,7 +170,6 @@ def check_convergence(self, verbose, conv_crit):
abs(self.top_bc - self.bot_bc)).cpu()
self.tau = self.VF / \
self.D_rel if self.D_rel != 0 else torch.tensor(torch.inf)


if verbose == 'per_iter':
print(
Expand All @@ -192,15 +189,16 @@ def check_convergence(self, verbose, conv_crit):

def calc_vertical_flux(self):
'''Calculates the vertical flux through the volume'''
vert_flux = self.conc[:, 1:-1, 1:-1, 1:-1] - \
self.conc[:, :-2, 1:-1, 1:-1]
vert_flux[self.conc[:, :-2, 1:-1, 1:-1] == 0] = 0
vert_flux[self.conc[:, 1:-1, 1:-1, 1:-1] == 0] = 0
# Indexing removes boundary layers (1 layer at every boundary)
vert_flux = self.conc[:, 2:-1, 1:-1, 1:-1] - \
self.conc[:, 1:-2, 1:-1, 1:-1]
vert_flux[self.conc[:, 1:-2, 1:-1, 1:-1] == 0] = 0
vert_flux[self.conc[:, 2:-1, 1:-1, 1:-1] == 0] = 0
return vert_flux

def check_vertical_flux(self, conv_crit):
vert_flux = self.calc_vertical_flux()
fl = torch.sum(vert_flux, (0, 2, 3))[1:-1]
fl = torch.sum(vert_flux, (0, 2, 3))
err = (fl.max() - fl.min())/(fl.max())
if fl.min() == 0:
return 'zero_flux', torch.mean(fl), err
Expand Down Expand Up @@ -292,21 +290,12 @@ def solve(self, iter_limit=5000, verbose=True, conv_crit=2*10**-2, D_0=1):

def calc_vertical_flux(self):
'''Calculates the vertical flux through the volume'''
vert_flux = abs(self.conc - torch.roll(self.conc, 1, 1))
vert_flux[self.conc == 0] = 0
vert_flux[torch.roll(self.conc, 1, 1) == 0] = 0
# Indexing removes 2 boundary layers at top and bottom
vert_flux = self.conc[:, 3:-2] - self.conc[:, 2:-3]
vert_flux[self.conc[:, 3:-2] == 0] = 0
vert_flux[self.conc[:, 2:-3] == 0] = 0
return vert_flux

def check_vertical_flux(self, conv_crit):
vert_flux = self.calc_vertical_flux()
fl = torch.sum(vert_flux, (0, 2, 3))[3:-2]
err = (fl.max() - fl.min())*2/(fl.max() + fl.min())
if err < conv_crit or torch.isnan(err).item():
return True, torch.mean(fl), err
if fl.min() == 0:
return 'zero_flux', torch.mean(fl), err
return False, torch.mean(fl), err


class MultiPhaseSolver(Solver):
"""
Expand Down Expand Up @@ -348,8 +337,6 @@ def __init__(self, img, cond={1: 1}, bc=(-0.5, 0.5), device=torch.device('cuda:0
# save original image in cuda
img = torch.tensor(img, dtype=self.precision, device=self.device)

# calculate

# init conc
self.conc = self.init_conc(img)
# create nn map
Expand Down Expand Up @@ -498,21 +485,11 @@ def check_convergence(self, verbose, conv_crit):

def calc_vertical_flux(self):
'''Calculates the vertical flux through the volume'''
vert_flux = (self.conc[:, 1:-1, 1:-1, 1:-1] - self.conc[:,
:-2, 1:-1, 1:-1]) * self.pre_factors[1][:, :-2, 1:-1, 1:-1]
vert_flux[self.nn == torch.inf] = 0
vert_flux = (self.conc[:, 2:-1, 1:-1, 1:-1] - self.conc[:,
1:-2, 1:-1, 1:-1]) * self.pre_factors[1][:, 1:-2, 1:-1, 1:-1]
vert_flux[self.nn[:,1:] == torch.inf] = 0
return vert_flux

def check_vertical_flux(self, conv_crit):
vert_flux = self.calc_vertical_flux()
fl = torch.sum(vert_flux, (0, 2, 3))[2:-2]
err = (fl.max() - fl.min())*2/(fl.max() + fl.min())
if err < conv_crit or torch.isnan(err).item():
return True, torch.mean(fl), err
if fl.min() == 0:
return 'zero_flux', torch.mean(fl), err
return False, torch.mean(fl), err


class ElectrodeSolver():
"""
Expand Down

0 comments on commit 7e04615

Please sign in to comment.