diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index bed0c03..4bf5257 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -1,7 +1,7 @@ import sys import click from ocrd_utils import initLogging, setOverrideLogLevel -from eynollah.eynollah import Eynollah +from eynollah.eynollah import Eynollah, Eynollah_ocr from eynollah.sbb_binarize import SbbBinarizer @click.group() @@ -305,6 +305,60 @@ def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, s else: pcgts = eynollah.run() eynollah.writer.write_pagexml(pcgts) + + +@main.command() +@click.option( + "--dir_in", + "-di", + help="directory of images", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--out", + "-o", + help="directory to write output xml data", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--dir_xmls", + "-dx", + help="directory of xmls", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--model", + "-m", + help="directory of models", + type=click.Path(exists=True, file_okay=False), + required=True, +) +@click.option( + "--tr_ocr", + "-trocr/-notrocr", + is_flag=True, + help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.", +) +@click.option( + "--log_level", + "-l", + type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), + help="Override log level globally to this", +) + +def ocr(dir_in, out, dir_xmls, model, tr_ocr, log_level): + if log_level: + setOverrideLogLevel(log_level) + initLogging() + eynollah_ocr = Eynollah_ocr( + dir_xmls=dir_xmls, + dir_in=dir_in, + dir_out=out, + dir_models=model, + tr_ocr=tr_ocr, + ) + eynollah_ocr.run() if __name__ == "__main__": main() diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 006bfea..95033e9 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -41,6 +41,9 @@ # use tf1 compatibility for keras backend from tensorflow.compat.v1.keras.backend import set_session from tensorflow.keras import layers +import json +import xml.etree.ElementTree as ET +from tensorflow.keras.layers import StringLookup from .utils.contour import ( filter_contours_area_of_image, @@ -2188,18 +2191,18 @@ def textline_contours(self, img, patches, scaler_h, scaler_w, num_col_classifier img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) if not self.dir_in: - prediction_textline = self.do_prediction(patches, img, model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) + ###prediction_textline = self.do_prediction(patches, img, model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) - ##prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, model_textline, n_batch_inference=3) + prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, model_textline, n_batch_inference=3) #if not thresholding_for_artificial_class_in_light_version: #if num_col_classifier==1: #prediction_textline_nopatch = self.do_prediction(False, img, model_textline) #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0 else: - prediction_textline = self.do_prediction(patches, img, self.model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3,thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) + ##prediction_textline = self.do_prediction(patches, img, self.model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3,thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) - ###prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, self.model_textline, n_batch_inference=3) + prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, self.model_textline, n_batch_inference=3) #if not thresholding_for_artificial_class_in_light_version: #if num_col_classifier==1: #prediction_textline_nopatch = self.do_prediction(False, img, model_textline) @@ -2479,17 +2482,17 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier, skip_lay if num_col_classifier == 1 or num_col_classifier == 2: model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np) if self.image_org.shape[0]/self.image_org.shape[1] > 2.5: - ##prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) - prediction_regions_org = self.do_prediction_new_concept(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) + prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) + ###prediction_regions_org = self.do_prediction_new_concept(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) else: prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) - ##prediction_regions_page = self.do_prediction_new_concept_scatter_nd(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) - prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) + prediction_regions_page = self.do_prediction_new_concept_scatter_nd(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) + ##prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) prediction_regions_org[self.page_coord[0] : self.page_coord[1], self.page_coord[2] : self.page_coord[3],:] = prediction_regions_page else: model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np) - prediction_regions_org = self.do_prediction_new_concept(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) - ###prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) + ###prediction_regions_org = self.do_prediction_new_concept(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) + prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) ##model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens_light) ##prediction_regions_org = self.do_prediction(True, img_bin, model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True) else: @@ -5610,3 +5613,394 @@ def run(self): if self.dir_in: self.logger.info("All jobs done in %.1fs", time.time() - t0_tot) print("all Job done in %.1fs", time.time() - t0_tot) + + +class Eynollah_ocr: + def __init__( + self, + dir_models, + dir_xmls=None, + dir_in=None, + dir_out=None, + tr_ocr=False, + logger=None, + ): + self.dir_in = dir_in + self.dir_out = dir_out + self.dir_xmls = dir_xmls + self.dir_models = dir_models + self.tr_ocr = tr_ocr + if tr_ocr: + self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124" + self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) + self.model_ocr.to(self.device) + + else: + self.model_ocr_dir = dir_models + "/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn" + model_ocr = load_model(self.model_ocr_dir , compile=False) + + self.prediction_model = tf.keras.models.Model( + model_ocr.get_layer(name = "image").input, + model_ocr.get_layer(name = "dense2").output) + + + with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: + characters = json.load(config_file) + + + AUTOTUNE = tf.data.AUTOTUNE + + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) + + # Mapping integers back to original characters. + self.num_to_char = StringLookup( + vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True + ) + + def decode_batch_predictions(self, pred, max_len = 128): + # input_len is the product of the batch size and the + # number of time steps. + input_len = np.ones(pred.shape[0]) * pred.shape[1] + + # Decode CTC predictions using greedy search. + # decoded is a tuple with 2 elements. + decoded = tf.keras.backend.ctc_decode(pred, + input_length = input_len, + beam_width = 100) + # The outputs are in the first element of the tuple. + # Additionally, the first element is actually a list, + # therefore we take the first element of that list as well. + #print(decoded,'decoded') + decoded = decoded[0][0][:, :max_len] + + #print(decoded, decoded.shape,'decoded') + + output = [] + for d in decoded: + # Convert the predicted indices to the corresponding chars. + d = tf.strings.reduce_join(self.num_to_char(d)) + d = d.numpy().decode("utf-8") + output.append(d) + return output + + + def distortion_free_resize(self, image, img_size): + w, h = img_size + image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True) + + # Check tha amount of padding needed to be done. + pad_height = h - tf.shape(image)[0] + pad_width = w - tf.shape(image)[1] + + # Only necessary if you want to do same amount of padding on both sides. + if pad_height % 2 != 0: + height = pad_height // 2 + pad_height_top = height + 1 + pad_height_bottom = height + else: + pad_height_top = pad_height_bottom = pad_height // 2 + + if pad_width % 2 != 0: + width = pad_width // 2 + pad_width_left = width + 1 + pad_width_right = width + else: + pad_width_left = pad_width_right = pad_width // 2 + + image = tf.pad( + image, + paddings=[ + [pad_height_top, pad_height_bottom], + [pad_width_left, pad_width_right], + [0, 0], + ], + ) + + image = tf.transpose(image, (1, 0, 2)) + image = tf.image.flip_left_right(image) + return image + + def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image): + width = np.shape(textline_image)[1] + height = np.shape(textline_image)[0] + common_window = int(0.06*width) + + width1 = int ( width/2. - common_window ) + width2 = int ( width/2. + common_window ) + + img_sum = np.sum(textline_image[:,:,0], axis=0) + sum_smoothed = gaussian_filter1d(img_sum, 3) + + peaks_real, _ = find_peaks(sum_smoothed, height=0) + + if len(peaks_real)>70: + + peaks_real = peaks_real[(peaks_realwidth1)] + + arg_max = np.argmax(sum_smoothed[peaks_real]) + + peaks_final = peaks_real[arg_max] + + return peaks_final + else: + return None + + def return_textlines_split_if_needed(self, textline_image): + + split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image) + if split_point: + image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height)) + image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height)) + return [image1, image2] + else: + return None + + def run(self): + ls_imgs = os.listdir(self.dir_in) + + if self.tr_ocr: + b_s = 2 + for ind_img in ls_imgs: + t0 = time.time() + file_name = ind_img.split('.')[0] + dir_img = os.path.join(self.dir_in, ind_img) + dir_xml = os.path.join(self.dir_xmls, file_name+'.xml') + out_file_ocr = os.path.join(self.dir_out, file_name+'.xml') + img = cv2.imread(dir_img) + + ##file_name = Path(dir_xmls).stem + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + indexer_text_region = 0 + for nn in root1.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + + if h2w_ratio > 0.05: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + else: + splited_images = self.return_textlines_split_if_needed(img_crop) + #print(splited_images) + if splited_images: + cropped_lines.append(splited_images[0]) + cropped_lines_meging_indexing.append(1) + cropped_lines.append(splited_images[1]) + cropped_lines_meging_indexing.append(-1) + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + indexer_text_region = indexer_text_region +1 + + + extracted_texts = [] + n_iterations = math.ceil(len(cropped_lines) / b_s) + + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*b_s + imgs = cropped_lines[n_start:] + else: + n_start = i*b_s + n_end = (i+1)*b_s + imgs = cropped_lines[n_start:n_end] + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + #print(extracted_texts_merged, len(extracted_texts_merged)) + + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind] + + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + #print(len(text_by_textregion) , indexer_text_region, "text_by_textregion") + + + #print(time.time() - t0 ,'elapsed time') + + + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + indexer = indexer + 1 + has_textline = True + if has_textline: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + #print("Job done in %.1fs", time.time() - t0) + else: + max_len = 512 + padding_token = 299 + image_width = max_len * 4 + image_height = 32 + b_s = 8 + + + img_size=(image_width, image_height) + + for ind_img in ls_imgs: + t0 = time.time() + file_name = ind_img.split('.')[0] + dir_img = os.path.join(self.dir_in, ind_img) + dir_xml = os.path.join(self.dir_xmls, file_name+'.xml') + out_file_ocr = os.path.join(self.dir_out, file_name+'.xml') + img = cv2.imread(dir_img) + + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + tinl = time.time() + indexer_text_region = 0 + for nn in root1.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + img_crop = tf.reverse(img_crop,axis=[-1]) + img_crop = self.distortion_free_resize(img_crop, img_size) + img_crop = tf.cast(img_crop, tf.float32) / 255.0 + cropped_lines.append(img_crop) + + indexer_text_region = indexer_text_region +1 + + + extracted_texts = [] + + n_iterations = math.ceil(len(cropped_lines) / b_s) + + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*b_s + imgs = cropped_lines[n_start:] + imgs = np.array(imgs) + imgs = imgs.reshape(imgs.shape[0], image_width, image_height, 3) + else: + n_start = i*b_s + n_end = (i+1)*b_s + imgs = cropped_lines[n_start:n_end] + imgs = np.array(imgs).reshape(b_s, image_width, image_height, 3) + + + preds = self.prediction_model.predict(imgs, verbose=0) + pred_texts = self.decode_batch_predictions(preds) + + for ib in range(imgs.shape[0]): + pred_texts_ib = pred_texts[ib].strip("[UNK]") + extracted_texts.append(pred_texts_ib) + + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + extracted_texts_merged_un = np.array(extracted_texts)[np.array(cropped_lines_region_indexer)==ind] + + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts[indexer] + indexer = indexer + 1 + has_textline = True + if has_textline: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + #print("Job done in %.1fs", time.time() - t0)