diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index 2fe7325..72a72d9 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -252,7 +252,7 @@ def __init__( self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425" self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314" self.model_reading_order_machine_dir = dir_models + "/model_ens_reading_order_machine_based" - self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_earlylayout_12spaltige_2_3_5_6_7_8"#"/modelens_1_2_4_5_early_lay_1_2_spaltige"#"/model_3_eraly_layout_no_patches_1_2_spaltige" + self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_earlylayout_12spaltige_2_3_5_6_7_8"#"/modelens_earlylayout_12spaltige_2_3_5_6_7_8"#"/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"#"/modelens_1_2_4_5_early_lay_1_2_spaltige"#"/model_3_eraly_layout_no_patches_1_2_spaltige" ##self.model_region_dir_fully_new = dir_models + "/model_2_full_layout_new_trans" self.model_region_dir_fully = dir_models + "/modelens_full_layout_24_till_28"#"/model_2_full_layout_new_trans" if self.textline_light: @@ -541,6 +541,7 @@ def resize_image_with_column_classifier(self, is_image_enhanced, img_bin): img = self.imread() _, page_coord = self.early_page_for_num_of_column_classification(img) + if not self.dir_in: model_num_classifier, session_col_classifier = self.start_new_session_and_model(self.model_dir_of_col_classifier) if self.input_binary: @@ -611,6 +612,10 @@ def resize_and_enhance_image_with_column_classifier(self,light_version): width_early = img.shape[1] t1 = time.time() _, page_coord = self.early_page_for_num_of_column_classification(img_bin) + + self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :] + self.page_coord = page_coord + if not self.dir_in: model_num_classifier, session_col_classifier = self.start_new_session_and_model(self.model_dir_of_col_classifier) @@ -737,7 +742,7 @@ def get_image_and_scales(self, img_org, img_res, scale): def get_image_and_scales_after_enhancing(self, img_org, img_res): self.logger.debug("enter get_image_and_scales_after_enhancing") self.image = np.copy(img_res) - self.image = self.image.astype(np.uint8) + #self.image = self.image.astype(np.uint8) self.image_org = np.copy(img_org) self.height_org = self.image_org.shape[0] self.width_org = self.image_org.shape[1] @@ -1059,19 +1064,18 @@ def do_prediction_new_concept(self, patches, img, model, n_batch_inference=1, ma if not patches: img_h_page = img.shape[0] img_w_page = img.shape[1] - img = img / float(255.0) + img = img / 255.0 img = resize_image(img, img_height_model, img_width_model) label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0) - - #seg_not_base = label_p_pred[0,:,:,4] - - #seg_not_base[seg_not_base>0.4] =1 - #seg_not_base[seg_not_base<1] =0 - seg = np.argmax(label_p_pred, axis=3)[0] - #seg[seg_not_base==1]=4 + if thresholding_for_artificial_class_in_light_version: + seg_art = label_p_pred[0,:,:,4] + seg_art[seg_art<0.1] =0 + seg_art[seg_art>0] =1 + seg[seg_art==1]=4 + seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) prediction_true = resize_image(seg_color, img_h_page, img_w_page) @@ -2151,7 +2155,7 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier, skip_lay #print(num_col_classifier,'num_col_classifier') if num_col_classifier == 1: - img_w_new = 1000 + img_w_new = 800 img_h_new = int(img_org.shape[0] / float(img_org.shape[1]) * img_w_new) elif num_col_classifier == 2: @@ -2206,29 +2210,39 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier, skip_lay textline_mask_tot_ea = resize_image(textline_mask_tot_ea,img_height_h, img_width_h ) + + #print(self.image_org.shape) + + #plt.imshwo(self.image_page_org_size) + #plt.show() if not skip_layout_and_reading_order: #print("inside 2 ", time.time()-t_in) - #print(img_resized.shape, num_col_classifier, "num_col_classifier") if not self.dir_in: if num_col_classifier == 1 or num_col_classifier == 2: + prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np) - prediction_regions_org = self.do_prediction_new_concept(False, img_resized, model_region, n_batch_inference=1) + prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = False) + prediction_regions_org[self.page_coord[0] : self.page_coord[1], self.page_coord[2] : self.page_coord[3],:] = prediction_regions_page else: - model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens_light, n_batch_inference=3) + model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens_light) prediction_regions_org = self.do_prediction_new_concept(True, img_bin, model_region) ##model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens_light) ##prediction_regions_org = self.do_prediction(True, img_bin, model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True) else: if num_col_classifier == 1 or num_col_classifier == 2: - prediction_regions_org = self.do_prediction_new_concept(False, img_resized, self.model_region_1_2, n_batch_inference=1) + prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) + prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1, thresholding_for_artificial_class_in_light_version=False) + prediction_regions_org[self.page_coord[0] : self.page_coord[1], self.page_coord[2] : self.page_coord[3],:] = prediction_regions_page else: prediction_regions_org = self.do_prediction_new_concept(True, img_bin, self.model_region, n_batch_inference=3) ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True) #print("inside 3 ", time.time()-t_in) + #plt.imshow(prediction_regions_org[:,:,0]) #plt.show() + prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h ) @@ -3195,7 +3209,7 @@ def run_enhancement(self,light_version): scale = 1 if is_image_enhanced: if self.allow_enhancement: - img_res = img_res.astype(np.uint8) + #img_res = img_res.astype(np.uint8) self.get_image_and_scales(img_org, img_res, scale) if self.plotter: self.plotter.save_enhanced_image(img_res)