Skip to content

Commit

Permalink
Merge pull request #59 from RapidAI/optim_cls_model
Browse files Browse the repository at this point in the history
Optim cls model
  • Loading branch information
SWHL authored Oct 28, 2024
2 parents 87034f1 + 9ea4a74 commit 510dd16
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 32 deletions.
71 changes: 49 additions & 22 deletions table_cls/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,47 @@

import cv2
import numpy as np
import onnxruntime
from PIL import Image

from .utils import InputType, LoadImage
from .utils import InputType, LoadImage, OrtInferSession, ResizePad

cur_dir = Path(__file__).resolve().parent
table_cls_model_path = cur_dir / "models" / "table_cls.onnx"
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"


class TableCls:
def __init__(self, device="cpu"):
providers = (
["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
)
self.table_cls = onnxruntime.InferenceSession(
table_cls_model_path, providers=providers
)
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
if model_type == "yolo":
self.table_engine = YoloCls(model_path)
else:
model_path = q_cls_model_path
self.table_engine = QanythingCls(model_path)
self.load_img = LoadImage()

def __call__(self, content: InputType):
ss = time.perf_counter()
img = self.load_img(content)
img = self.table_engine.preprocess(img)
predict_cla = self.table_engine([img])
table_elapse = time.perf_counter() - ss
return predict_cla, table_elapse


class QanythingCls:
def __init__(self, model_path):
self.table_cls = OrtInferSession(model_path)
self.inp_h = 224
self.inp_w = 224
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
self.cls = {0: "wired", 1: "wireless"}
self.load_img = LoadImage()

def _preprocess(self, image):
img = Image.fromarray(np.uint8(image))
def preprocess(self, img):
img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = np.stack((img,) * 3, axis=-1)
img = Image.fromarray(np.uint8(img))
img = img.resize((self.inp_h, self.inp_w))
img = np.array(img, dtype=np.float32) / 255.0
img -= self.mean
Expand All @@ -37,15 +52,27 @@ def _preprocess(self, image):
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
return img

def __call__(self, content: InputType):
ss = time.perf_counter()
img = self.load_img(content)
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = np.stack((gray_img,) * 3, axis=-1)
gray_img = self._preprocess(gray_img)
output = self.table_cls.run(None, {"input": gray_img})
def __call__(self, img):
output = self.table_cls(img)
predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True))
predict /= np.sum(predict, axis=1, keepdims=True)
predict_cla = np.argmax(predict, axis=1)[0]
table_elapse = time.perf_counter() - ss
return self.cls[predict_cla], table_elapse
return self.cls[predict_cla]


class YoloCls:
def __init__(self, model_path):
self.table_cls = OrtInferSession(model_path)
self.cls = {0: "wireless", 1: "wired"}

def preprocess(self, img):
img, *_ = ResizePad(img, 640)
img = np.array(img, dtype=np.float32) / 255.0
img = img.transpose(2, 0, 1) # HWC to CHW
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
return img

def __call__(self, img):
output = self.table_cls(img)
predict_cla = np.argmax(output[0], axis=1)[0]
return self.cls[predict_cla]
90 changes: 89 additions & 1 deletion table_cls/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,86 @@
import traceback
from io import BytesIO
from pathlib import Path
from typing import Union
from typing import Union, List

import cv2
import numpy as np
from PIL import Image, UnidentifiedImageError
from onnxruntime import InferenceSession
from onnxruntime.capi.onnxruntime_pybind11_state import (
SessionOptions,
GraphOptimizationLevel,
)

InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class OrtInferSession:
def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
self.verify_exist(model_path)

self.num_threads = num_threads
self._init_sess_opt()

cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = [(cpu_ep, cpu_provider_options)]
try:
self.session = InferenceSession(
str(model_path), sess_options=self.sess_opt, providers=EP_list
)
except TypeError:
# 这里兼容ort 1.5.2
self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

def _init_sess_opt(self):
self.sess_opt = SessionOptions()
self.sess_opt.log_severity_level = 4
self.sess_opt.enable_cpu_mem_arena = False

if self.num_threads != -1:
self.sess_opt.intra_op_num_threads = self.num_threads

self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(None, input_dict)
except Exception as e:
error_info = traceback.format_exc()
raise ONNXRuntimeError(error_info) from e

def get_input_names(
self,
):
return [v.name for v in self.session.get_inputs()]

def get_output_name(self, output_idx=0):
return self.session.get_outputs()[output_idx].name

def get_metadata(self):
meta_dict = self.session.get_modelmeta().custom_metadata_map
return meta_dict

@staticmethod
def verify_exist(model_path: Union[Path, str]):
if not isinstance(model_path, Path):
model_path = Path(model_path)

if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exist!")

if not model_path.is_file():
raise FileExistsError(f"{model_path} must be a file")


class ONNXRuntimeError(Exception):
pass


class LoadImageError(Exception):
pass

Expand Down Expand Up @@ -106,3 +178,19 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")


def ResizePad(img, target_size):
h, w = img.shape[:2]
m = max(h, w)
ratio = target_size / m
new_w, new_h = int(ratio * w), int(ratio * h)
img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR)
top = (target_size - new_h) // 2
bottom = (target_size - new_h) - top
left = (target_size - new_w) // 2
right = (target_size - new_w) - left
img1 = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
)
return img1, new_w, new_h, left, top
2 changes: 1 addition & 1 deletion tests/test_table_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@pytest.mark.parametrize(
"img_path, expected",
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")],
[("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")],
)
def test_input_normal(img_path, expected):
img_path = test_file_dir / img_path
Expand Down
22 changes: 14 additions & 8 deletions wired_table_rec/table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,20 @@ def get_benchmark_cols(
sorted(range_res.items(), key=lambda x: x[0], reverse=True)
)
for k, v in sorted_res.items():
if not all(v):
continue

longest_x = np.insert(longest_x, v[1], cur_row[k])
longest_col_points = np.insert(
longest_col_points, v[1], polygons[row_value[k]], axis=0
)

# bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
# 最长列不包含第一列和最后一列的场景需要兼容
if all(v) or v[1] == 0:
longest_x = np.insert(longest_x, v[1], cur_row[k])
longest_col_points = np.insert(
longest_col_points, v[1], polygons[row_value[k]], axis=0
)
elif v[0] and v[0] + 1 == len(longest_x):
longest_x = np.append(longest_x, cur_row[k])
longest_col_points = np.append(
longest_col_points,
polygons[row_value[k]][np.newaxis, :, :],
axis=0,
)
# 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
rightmost_idxs = [v[-1] for v in rows.values()]
rightmost_boxes = polygons[rightmost_idxs]
Expand Down

0 comments on commit 510dd16

Please sign in to comment.