From ad7bf8ed0d48a1d0d457022f44c71d31f49d000b Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 28 Oct 2024 17:39:49 +0800 Subject: [PATCH 1/3] feat: add yolo cls model --- table_cls/main.py | 70 ++++++++++++++++++++++++------------ table_cls/utils.py | 90 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 23 deletions(-) diff --git a/table_cls/main.py b/table_cls/main.py index 9ad7820..d6ddb66 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -3,32 +3,46 @@ import cv2 import numpy as np -import onnxruntime from PIL import Image -from .utils import InputType, LoadImage +from .utils import InputType, LoadImage, OrtInferSession, ResizePad cur_dir = Path(__file__).resolve().parent -table_cls_model_path = cur_dir / "models" / "table_cls.onnx" +q_cls_model_path = cur_dir / "models" / "table_cls.onnx" +yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx" class TableCls: - def __init__(self, device="cpu"): - providers = ( - ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] - ) - self.table_cls = onnxruntime.InferenceSession( - table_cls_model_path, providers=providers - ) + def __init__(self, model="yolo"): + if model == "yolo": + self.table_engine = YoloCls() + else: + self.table_engine = QanythingCls() + self.load_img = LoadImage() + + def __call__(self, content: InputType): + ss = time.perf_counter() + img = self.load_img(content) + img = self.table_engine.preprocess(img) + predict_cla = self.table_engine([img]) + table_elapse = time.perf_counter() - ss + return predict_cla, table_elapse + + +class QanythingCls: + def __init__(self): + self.table_cls = OrtInferSession(q_cls_model_path) self.inp_h = 224 self.inp_w = 224 self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) self.cls = {0: "wired", 1: "wireless"} - self.load_img = LoadImage() - def _preprocess(self, image): - img = Image.fromarray(np.uint8(image)) + def preprocess(self, img): + img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB) + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = np.stack((img,) * 3, axis=-1) + img = Image.fromarray(np.uint8(img)) img = img.resize((self.inp_h, self.inp_w)) img = np.array(img, dtype=np.float32) / 255.0 img -= self.mean @@ -37,15 +51,27 @@ def _preprocess(self, image): img = np.expand_dims(img, axis=0) # Add batch dimension, only one image return img - def __call__(self, content: InputType): - ss = time.perf_counter() - img = self.load_img(content) - gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - gray_img = np.stack((gray_img,) * 3, axis=-1) - gray_img = self._preprocess(gray_img) - output = self.table_cls.run(None, {"input": gray_img}) + def __call__(self, img): + output = self.table_cls(img) predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True)) predict /= np.sum(predict, axis=1, keepdims=True) predict_cla = np.argmax(predict, axis=1)[0] - table_elapse = time.perf_counter() - ss - return self.cls[predict_cla], table_elapse + return self.cls[predict_cla] + + +class YoloCls: + def __init__(self): + self.table_cls = OrtInferSession(yolo_cls_model_path) + self.cls = {0: "wireless", 1: "wired"} + + def preprocess(self, img): + img, *_ = ResizePad(img, 640) + img = np.array(img, dtype=np.float32) / 255.0 + img = img.transpose(2, 0, 1) # HWC to CHW + img = np.expand_dims(img, axis=0) # Add batch dimension, only one image + return img + + def __call__(self, img): + output = self.table_cls(img) + predict_cla = np.argmax(output[0], axis=1)[0] + return self.cls[predict_cla] diff --git a/table_cls/utils.py b/table_cls/utils.py index 2b9288b..9df30f7 100644 --- a/table_cls/utils.py +++ b/table_cls/utils.py @@ -1,14 +1,86 @@ +import traceback from io import BytesIO from pathlib import Path -from typing import Union +from typing import Union, List import cv2 import numpy as np from PIL import Image, UnidentifiedImageError +from onnxruntime import InferenceSession +from onnxruntime.capi.onnxruntime_pybind11_state import ( + SessionOptions, + GraphOptimizationLevel, +) InputType = Union[str, np.ndarray, bytes, Path, Image.Image] +class OrtInferSession: + def __init__(self, model_path: Union[str, Path], num_threads: int = -1): + self.verify_exist(model_path) + + self.num_threads = num_threads + self._init_sess_opt() + + cpu_ep = "CPUExecutionProvider" + cpu_provider_options = { + "arena_extend_strategy": "kSameAsRequested", + } + EP_list = [(cpu_ep, cpu_provider_options)] + try: + self.session = InferenceSession( + str(model_path), sess_options=self.sess_opt, providers=EP_list + ) + except TypeError: + # 这里兼容ort 1.5.2 + self.session = InferenceSession(str(model_path), sess_options=self.sess_opt) + + def _init_sess_opt(self): + self.sess_opt = SessionOptions() + self.sess_opt.log_severity_level = 4 + self.sess_opt.enable_cpu_mem_arena = False + + if self.num_threads != -1: + self.sess_opt.intra_op_num_threads = self.num_threads + + self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: + input_dict = dict(zip(self.get_input_names(), input_content)) + try: + return self.session.run(None, input_dict) + except Exception as e: + error_info = traceback.format_exc() + raise ONNXRuntimeError(error_info) from e + + def get_input_names( + self, + ): + return [v.name for v in self.session.get_inputs()] + + def get_output_name(self, output_idx=0): + return self.session.get_outputs()[output_idx].name + + def get_metadata(self): + meta_dict = self.session.get_modelmeta().custom_metadata_map + return meta_dict + + @staticmethod + def verify_exist(model_path: Union[Path, str]): + if not isinstance(model_path, Path): + model_path = Path(model_path) + + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exist!") + + if not model_path.is_file(): + raise FileExistsError(f"{model_path} must be a file") + + +class ONNXRuntimeError(Exception): + pass + + class LoadImageError(Exception): pass @@ -106,3 +178,19 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: def verify_exist(file_path: Union[str, Path]): if not Path(file_path).exists(): raise LoadImageError(f"{file_path} does not exist.") + + +def ResizePad(img, target_size): + h, w = img.shape[:2] + m = max(h, w) + ratio = target_size / m + new_w, new_h = int(ratio * w), int(ratio * h) + img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) + top = (target_size - new_h) // 2 + bottom = (target_size - new_h) - top + left = (target_size - new_w) // 2 + right = (target_size - new_w) - left + img1 = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) + return img1, new_w, new_h, left, top From 490f32804f5de288e42b9d2c0f704d8c9b15bfd4 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 28 Oct 2024 17:47:12 +0800 Subject: [PATCH 2/3] fix: fix logic col calculate --- wired_table_rec/table_recover.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index 72f6c83..8e89672 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -93,14 +93,20 @@ def get_benchmark_cols( sorted(range_res.items(), key=lambda x: x[0], reverse=True) ) for k, v in sorted_res.items(): - if not all(v): - continue - - longest_x = np.insert(longest_x, v[1], cur_row[k]) - longest_col_points = np.insert( - longest_col_points, v[1], polygons[row_value[k]], axis=0 - ) - + # bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55 + # 最长列不包含第一列和最后一列的场景需要兼容 + if all(v) or v[1] == 0: + longest_x = np.insert(longest_x, v[1], cur_row[k]) + longest_col_points = np.insert( + longest_col_points, v[1], polygons[row_value[k]], axis=0 + ) + elif v[0] and v[0] + 1 == len(longest_x): + longest_x = np.append(longest_x, cur_row[k]) + longest_col_points = np.append( + longest_col_points, + polygons[row_value[k]][np.newaxis, :, :], + axis=0, + ) # 求出最右侧所有cell的宽,其中最小的作为最后一列宽度 rightmost_idxs = [v[-1] for v in rows.values()] rightmost_boxes = polygons[rightmost_idxs] From 9ea4a74f52de788e2e4aed264cfe3c23975c78db Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 28 Oct 2024 18:01:35 +0800 Subject: [PATCH 3/3] feat: optim param use for table cls --- table_cls/main.py | 17 +++++++++-------- tests/test_table_cls.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/table_cls/main.py b/table_cls/main.py index d6ddb66..ca7ab4c 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -13,11 +13,12 @@ class TableCls: - def __init__(self, model="yolo"): - if model == "yolo": - self.table_engine = YoloCls() + def __init__(self, model_type="yolo", model_path=yolo_cls_model_path): + if model_type == "yolo": + self.table_engine = YoloCls(model_path) else: - self.table_engine = QanythingCls() + model_path = q_cls_model_path + self.table_engine = QanythingCls(model_path) self.load_img = LoadImage() def __call__(self, content: InputType): @@ -30,8 +31,8 @@ def __call__(self, content: InputType): class QanythingCls: - def __init__(self): - self.table_cls = OrtInferSession(q_cls_model_path) + def __init__(self, model_path): + self.table_cls = OrtInferSession(model_path) self.inp_h = 224 self.inp_w = 224 self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) @@ -60,8 +61,8 @@ def __call__(self, img): class YoloCls: - def __init__(self): - self.table_cls = OrtInferSession(yolo_cls_model_path) + def __init__(self, model_path): + self.table_cls = OrtInferSession(model_path) self.cls = {0: "wireless", 1: "wired"} def preprocess(self, img): diff --git a/tests/test_table_cls.py b/tests/test_table_cls.py index d9d1813..b5c5611 100644 --- a/tests/test_table_cls.py +++ b/tests/test_table_cls.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( "img_path, expected", - [("wired_table.png", "wired"), ("lineless_table.png", "wireless")], + [("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")], ) def test_input_normal(img_path, expected): img_path = test_file_dir / img_path