Skip to content

Commit

Permalink
move self.model_* to EynollahDirs
Browse files Browse the repository at this point in the history
  • Loading branch information
kba committed Aug 24, 2024
1 parent 59dbffe commit b954a55
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 60 deletions.
109 changes: 50 additions & 59 deletions qurator/eynollah/eynollah.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,6 @@ def __init__(
textline_light = self.textline_light,
pcgts=pcgts)


self.model_dir_of_enhancement = dirs.dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_binarization = dirs.dir_models + "/eynollah-binarization_20210425"
self.model_dir_of_col_classifier = dirs.dir_models + "/eynollah-column-classifier_20210425"
# FIXME: unused
# self.model_region_dir_p = dirs.dir_models + "/eynollah-main-regions-aug-scaling_20210425"
self.model_region_dir_p2 = dirs.dir_models + "/eynollah-main-regions-aug-rotation_20210425"
self.model_region_dir_fully_np = dirs.dir_models + "/eynollah-full-regions-1column_20210425"
self.model_region_dir_fully = dirs.dir_models + "/eynollah-full-regions-3+column_20210425"
self.model_page_dir = dirs.dir_models + "/eynollah-page-extraction_20210425"
self.model_region_dir_p_ens = dirs.dir_models + "/eynollah-main-regions-ensembled_20210425"
self.model_region_dir_p_ens_light = dirs.dir_models + "/eynollah-main-regions_20220314"
if self.textline_light:
self.model_textline_dir = dirs.dir_models + "/eynollah-textline_light_20210425"
else:
self.model_textline_dir = dirs.dir_models + "/eynollah-textline_20210425"
self.model_tables = dirs.dir_models + "/eynollah-tables_20210319"

self.models : dict[str, tf.keras.Model] = {}

if self.batch_processing_mode and light_version:
Expand All @@ -169,13 +151,13 @@ def __init__(
session = tf.compat.v1.Session(config=config)
set_session(session)

self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_page = self.our_load_model(self.dirs.model_page_dir)
self.model_classifier = self.our_load_model(self.dirs.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.dirs.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.dirs.model_textline_dir)
self.model_region = self.our_load_model(self.dirs.model_region_dir_p_ens_light)
self.model_region_fl_np = self.our_load_model(self.dirs.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.dirs.model_region_dir_fully)

self.ls_imgs = listdir(self.dirs.dir_in)

Expand All @@ -185,15 +167,15 @@ def __init__(
session = tf.compat.v1.Session(config=config)
set_session(session)

self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_page = self.our_load_model(self.dirs.model_page_dir)
self.model_classifier = self.our_load_model(self.dirs.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.dirs.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.dirs.model_textline_dir)
self.model_region = self.our_load_model(self.dirs.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.dirs.model_region_dir_p2)
self.model_region_fl_np = self.our_load_model(self.dirs.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.dirs.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.dirs.model_dir_of_enhancement)

self.ls_imgs = listdir(self.dirs.dir_in)

Expand Down Expand Up @@ -237,7 +219,7 @@ def isNaN(self, num):

def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
model_enhancement = self.load_model(self.model_dir_of_enhancement)
model_enhancement = self.load_model(self.dirs.model_dir_of_enhancement)

img_height_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[1]
img_width_model = model_enhancement.layers[len(model_enhancement.layers) - 1].output_shape[2]
Expand Down Expand Up @@ -398,7 +380,7 @@ def resize_image_with_column_classifier(self, is_image_enhanced, img_bin):

_, page_coord = self.early_page_for_num_of_column_classification(img)
if not self.batch_processing_mode:
model_num_classifier = self.load_model(self.model_dir_of_col_classifier)
model_num_classifier = self.load_model(self.dirs.model_dir_of_col_classifier)
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
Expand Down Expand Up @@ -454,7 +436,7 @@ def resize_and_enhance_image_with_column_classifier(self,light_version):
prediction_bin = self.do_prediction(True, img, self.model_bin)
else:

model_bin = self.load_model(self.model_dir_of_binarization)
model_bin = self.load_model(self.dirs.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img, model_bin)

prediction_bin=prediction_bin[:,:,0]
Expand All @@ -473,7 +455,7 @@ def resize_and_enhance_image_with_column_classifier(self,light_version):
t1 = time.time()
_, page_coord = self.early_page_for_num_of_column_classification(img_bin)
if not self.batch_processing_mode:
model_num_classifier = self.load_model(self.model_dir_of_col_classifier)
model_num_classifier = self.load_model(self.dirs.model_dir_of_col_classifier)

if self.input_binary:
img_in = np.copy(img)
Expand Down Expand Up @@ -574,7 +556,7 @@ def get_image_and_scales_after_enhancing(self, img_org, img_res):
self.writer.width_org = self.width_org

def load_model(self, model_dir) -> tf.keras.Model:
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
self.logger.debug("enter load_model (model_dir=%s)", model_dir)
physical_devices = tf.config.list_physical_devices('GPU')
try:
for device in physical_devices:
Expand All @@ -597,6 +579,14 @@ def load_model(self, model_dir) -> tf.keras.Model:

return model

def our_load_model(self, model_file):
    """Load the model stored at ``model_file`` without compiling it.

    A plain ``load_model`` call is tried first. If it raises, the load is
    retried with ``PatchEncoder`` and ``Patches`` passed as
    ``custom_objects``.

    Args:
        model_file: Location of the saved model, handed to ``load_model``
            unchanged.

    Returns:
        Whatever ``load_model`` returns (presumably a Keras model; confirm
        against the import at the top of the file).

    Raises:
        Exception: Whatever the second ``load_model`` call raises if the
            retry fails as well.
    """
    try:
        return load_model(model_file, compile=False)
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the retry.
        custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}
        return load_model(model_file, compile=False, custom_objects=custom_objects)


def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
self.logger.debug("enter do_prediction")

Expand Down Expand Up @@ -892,7 +882,7 @@ def extract_page(self):
img = cv2.GaussianBlur(self.image, (5, 5), 0)

if not self.batch_processing_mode:
model_page = self.load_model(self.model_page_dir)
model_page = self.load_model(self.dirs.model_page_dir)

if not self.batch_processing_mode:
img_page_prediction = self.do_prediction(False, img, model_page)
Expand Down Expand Up @@ -940,7 +930,7 @@ def early_page_for_num_of_column_classification(self,img_bin):
else:
img = self.imread()
if not self.batch_processing_mode:
model_page = self.load_model(self.model_page_dir)
model_page = self.load_model(self.dirs.model_page_dir)
img = cv2.GaussianBlur(img, (5, 5), 0)

if self.batch_processing_mode:
Expand Down Expand Up @@ -973,7 +963,7 @@ def extract_text_regions(self, img, patches, cols):
img_height_h = img.shape[0]
img_width_h = img.shape[1]
if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_fully if patches else self.model_region_dir_fully_np)
model_region = self.load_model(self.dirs.model_region_dir_fully if patches else self.model_region_dir_fully_np)
else:
model_region = self.model_region_fl if patches else self.model_region_fl_np

Expand Down Expand Up @@ -1439,8 +1429,16 @@ def do_work_of_slopes_new(self, queue_of_all_params, boxes_text, textline_mask_t

def textline_contours(self, img, patches, scaler_h, scaler_w):
self.logger.debug('enter textline_contours')
# FIXME: If called in non-batch-processing-mode, model_textline will be unbound
if not self.batch_processing_mode:
model_textline = self.load_model(self.model_textline_dir if patches else self.model_textline_dir_np)
# FIXME: model_textline_dir_np is not defined anywhere
if self.light_version:
# FIXME: What to use for light_version + patches?
model_dir = self.dirs.model_textline_dir_light
else:
# FIXME: What to use for non-light_version + patches?
model_dir = self.dirs.model_textline_dir
model_textline = self.load_model(model_dir)
img = img.astype(np.uint8)
img_org = np.copy(img)
img_h = img_org.shape[0]
Expand Down Expand Up @@ -1499,7 +1497,7 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier):
img_resized = resize_image(img,img_h_new, img_w_new )

if not self.batch_processing_mode:
model_bin = self.load_model(self.model_dir_of_binarization)
model_bin = self.load_model(self.dirs.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_resized, model_bin)
else:
prediction_bin = self.do_prediction(True, img_resized, self.model_bin)
Expand All @@ -1518,7 +1516,7 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier):
textline_mask_tot_ea = self.run_textline(img_bin)

if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_p_ens_light)
model_region = self.load_model(self.dirs.model_region_dir_p_ens_light)
prediction_regions_org = self.do_prediction_new_concept(True, img_bin, model_region)
else:
prediction_regions_org = self.do_prediction_new_concept(True, img_bin, self.model_region)
Expand Down Expand Up @@ -1563,7 +1561,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
img_width_h = img_org.shape[1]

if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_p_ens)
model_region = self.load_model(self.dirs.model_region_dir_p_ens)

ratio_y=1.3
ratio_x=1
Expand Down Expand Up @@ -1602,7 +1600,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):


if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_p2)
model_region = self.load_model(self.dirs.model_region_dir_p2)

img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))

Expand Down Expand Up @@ -1641,7 +1639,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
prediction_bin = np.copy(img_org)
else:
if not self.batch_processing_mode:
model_bin = self.load_model(self.model_dir_of_binarization)
model_bin = self.load_model(self.dirs.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, model_bin)
else:
prediction_bin = self.do_prediction(True, img_org, self.model_bin)
Expand All @@ -1654,7 +1652,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
prediction_bin =np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)

if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_p_ens)
model_region = self.load_model(self.dirs.model_region_dir_p_ens)
ratio_y=1
ratio_x=1

Expand Down Expand Up @@ -1694,7 +1692,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
prediction_bin = np.copy(img_org)

if not self.batch_processing_mode:
model_bin = self.load_model(self.model_dir_of_binarization)
model_bin = self.load_model(self.dirs.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, model_bin)
else:
prediction_bin = self.do_prediction(True, img_org, self.model_bin)
Expand All @@ -1709,7 +1707,7 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):


if not self.batch_processing_mode:
model_region = self.load_model(self.model_region_dir_p_ens)
model_region = self.load_model(self.dirs.model_region_dir_p_ens)

else:
prediction_bin = np.copy(img_org)
Expand Down Expand Up @@ -2231,7 +2229,7 @@ def get_tables_from_model(self, img, num_col_classifier):
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]

model_region = self.load_model(self.model_tables)
model_region = self.load_model(self.dirs.model_tables)

patches = False

Expand Down Expand Up @@ -2702,13 +2700,6 @@ def run_boxes_full_layout(self, image_page, textline_mask_tot, text_regions_p, s
self.logger.debug('exit run_boxes_full_layout')
return polygons_of_images, img_revised_tab, text_regions_p_1_n, textline_mask_tot_d, regions_without_separators_d, regions_fully, regions_without_separators, polygons_of_marginals, contours_tables

def our_load_model(self, model_file):
    """Load the model stored at ``model_file`` without compiling it.

    A plain ``load_model`` call is tried first. If it raises, the load is
    retried with ``PatchEncoder`` and ``Patches`` passed as
    ``custom_objects``.

    Args:
        model_file: Location of the saved model, handed to ``load_model``
            unchanged.

    Returns:
        Whatever ``load_model`` returns (presumably a Keras model; confirm
        against the import at the top of the file).

    Raises:
        Exception: Whatever the second ``load_model`` call raises if the
            retry fails as well.
    """
    try:
        return load_model(model_file, compile=False)
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the retry.
        custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}
        return load_model(model_file, compile=False, custom_objects=custom_objects)

def run(self):
"""
Get image and scales, then extract the page of scanned image
Expand Down
55 changes: 54 additions & 1 deletion qurator/eynollah/utils/dirs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Optional


@dataclass()
class EynollahDirs():
"""
Expand All @@ -16,4 +15,58 @@ class EynollahDirs():
dir_of_all : Optional[str] = None
dir_save_page : Optional[str] = None

@property
def model_dir_of_enhancement(self) -> str:
    """Return ``dir_models`` followed by the enhancement model sub-path."""
    subdir = "/eynollah-enhancement_20210425"
    return self.dir_models + subdir

@property
def model_dir_of_binarization(self) -> str:
    """Return ``dir_models`` followed by the binarization model sub-path."""
    subdir = "/eynollah-binarization_20210425"
    return self.dir_models + subdir

@property
def model_dir_of_col_classifier(self) -> str:
    """Return ``dir_models`` followed by the column-classifier model sub-path."""
    subdir = "/eynollah-column-classifier_20210425"
    return self.dir_models + subdir

@property
def model_region_dir_p2(self) -> str:
    """Return ``dir_models`` followed by the rotation-augmented main-regions sub-path."""
    subdir = "/eynollah-main-regions-aug-rotation_20210425"
    return self.dir_models + subdir

@property
def model_region_dir_fully_np(self) -> str:
    """Return ``dir_models`` followed by the 1-column full-regions model sub-path."""
    subdir = "/eynollah-full-regions-1column_20210425"
    return self.dir_models + subdir

@property
def model_region_dir_fully(self) -> str:
    """Return ``dir_models`` followed by the 3+-column full-regions model sub-path."""
    subdir = "/eynollah-full-regions-3+column_20210425"
    return self.dir_models + subdir

@property
def model_page_dir(self) -> str:
    """Return ``dir_models`` followed by the page-extraction model sub-path."""
    subdir = "/eynollah-page-extraction_20210425"
    return self.dir_models + subdir

@property
def model_region_dir_p_ens(self) -> str:
    """Return ``dir_models`` followed by the ensembled main-regions model sub-path."""
    subdir = "/eynollah-main-regions-ensembled_20210425"
    return self.dir_models + subdir

@property
def model_region_dir_p_ens_light(self) -> str:
    """Return ``dir_models`` followed by the light main-regions model sub-path."""
    subdir = "/eynollah-main-regions_20220314"
    return self.dir_models + subdir

@property
def model_textline_dir(self) -> str:
    """Return ``dir_models`` followed by the (non-light) textline model sub-path."""
    subdir = "/eynollah-textline_20210425"
    return self.dir_models + subdir

@property
def model_textline_dir_light(self) -> str:
    """Return ``dir_models`` followed by the light textline model sub-path."""
    subdir = "/eynollah-textline_light_20210425"
    return self.dir_models + subdir

# FIXME: should have 'dir' in the name as well
@property
def model_tables(self) -> str:
    """Return ``dir_models`` followed by the tables model sub-path."""
    subdir = "/eynollah-tables_20210319"
    return self.dir_models + subdir

# FIXME: unused
@property
def model_region_dir_p(self) -> str:
    """Return ``dir_models`` followed by the scaling-augmented main-regions sub-path.

    Annotated ``-> str`` to match the other path properties on this class.
    """
    subdir = "/eynollah-main-regions-aug-scaling_20210425"
    return self.dir_models + subdir


0 comments on commit b954a55

Please sign in to comment.