diff --git a/plugins/rknn/package.json b/plugins/rknn/package.json
index d1f07515ce..519774d102 100644
--- a/plugins/rknn/package.json
+++ b/plugins/rknn/package.json
@@ -39,11 +39,12 @@
         "type": "API",
         "interfaces": [
             "ObjectDetection",
-            "ObjectDetectionPreview"
+            "ObjectDetectionPreview",
+            "DeviceProvider"
         ]
     },
     "devDependencies": {
         "@scrypted/sdk": "file:../../sdk"
     },
-    "version": "0.0.4"
+    "version": "0.1.0"
 }
diff --git a/plugins/rknn/src/det_utils/__init__.py b/plugins/rknn/src/det_utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/rknn/src/det_utils/db_postprocess.py b/plugins/rknn/src/det_utils/db_postprocess.py
new file mode 100644
index 0000000000..ac50634666
--- /dev/null
+++ b/plugins/rknn/src/det_utils/db_postprocess.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+# import paddle
+from shapely.geometry import Polygon
+import pyclipper
+
+
+class DBPostProcess(object):
+    """
+    The post process for Differentiable Binarization (DB).
+ """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.astype(np.int16)) + scores.append(score) + return np.array(boxes, dtype=np.int16), scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def box_score_slow(self, bitmap, contour): + ''' + box_score_slow: use polyon mean score as the mean score + ''' + h, w = bitmap.shape[:2] + 
contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + # if isinstance(pred, paddle.Tensor): + # pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, + src_w, src_h) + + boxes_batch.append({'points': boxes}) + return boxes_batch + + +class DistillationDBPostProcess(object): + def __init__(self, + model_name=["student"], + key=None, + thresh=0.3, + box_thresh=0.6, + max_candidates=1000, + unclip_ratio=1.5, + use_dilation=False, + score_mode="fast", + **kwargs): + self.model_name = model_name + self.key = key + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) + + def __call__(self, predicts, shape_list): + results = {} + for k in self.model_name: + results[k] = self.post_process(predicts[k], shape_list=shape_list) + return results + + +class DetPostProcess(object): + def __init__(self) -> None: + pass + + def order_points_clockwise(self, pts): + """ + reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py + # sort the points based on their x-coordinates + """ + xSorted = pts[np.argsort(pts[:, 0]), :] + + # grab the left-most and right-most points from the sorted + # x-roodinate points + leftMost = xSorted[:2, :] + rightMost = xSorted[2:, :] + + # now, sort the left-most coordinates according to their + # y-coordinates so we can grab the top-left and bottom-left + # points, respectively + leftMost = leftMost[np.argsort(leftMost[:, 1]), :] + (tl, bl) = leftMost + + rightMost = rightMost[np.argsort(rightMost[:, 1]), :] + (tr, br) = rightMost + + rect = np.array([tl, tr, br, bl], dtype="float32") + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes diff --git a/plugins/rknn/src/det_utils/operators.py b/plugins/rknn/src/det_utils/operators.py new file mode 100644 index 0000000000..ebf6f70c48 
--- /dev/null
+++ b/plugins/rknn/src/det_utils/operators.py
@@ -0,0 +1,373 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
+class DecodeImage(object):
+    """ decode image """
+
+    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+        self.img_mode = img_mode
+        self.channel_first = channel_first
+
+    def __call__(self, data):
+        img = data['image']
+        if six.PY2:
+            assert type(img) is str and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        else:
+            assert type(img) is bytes and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        img = np.frombuffer(img, dtype='uint8')
+        img = cv2.imdecode(img, 1)
+        if img is None:
+            return None
+        if self.img_mode == 'GRAY':
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif self.img_mode == 'RGB':
+            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+            img = img[:, :, ::-1]
+
+        if self.channel_first:
+            img = img.transpose((2, 0, 1))
+
+        data['image'] = img
+        return data
+
+
+class NRTRDecodeImage(object):
+    """ decode image """
+
+    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+        self.img_mode = img_mode
+        self.channel_first = channel_first
+
+    def __call__(self, data):
+        img = data['image']
+        if six.PY2:
+            assert type(img) is str and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        else:
+            assert type(img) is bytes and len(
+                img) > 0, "invalid input 'img' in DecodeImage"
+        img = np.frombuffer(img, dtype='uint8')
+
+        img = cv2.imdecode(img, 1)
+
+        if img is None:
+            return None
+        if self.img_mode == 'GRAY':
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif self.img_mode == 'RGB':
+            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+            img = img[:, :, ::-1]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        if self.channel_first:
+            img = img.transpose((2, 0, 1))
+        data['image'] = img
+        return data
+
+class NormalizeImage(object):
+    """ normalize image, e.g. subtract the mean and divide by the std
+    """
+
+    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype('float32')
+        self.std = np.array(std).reshape(shape).astype('float32')
+
+    def __call__(self, data):
+        img = data['image']
+        from PIL import Image
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+
+        assert isinstance(img,
+                          np.ndarray), "invalid input 'img' in NormalizeImage"
+        data['image'] = (
+            img.astype('float32') * self.scale - self.mean) / self.std
+        return data
+
+
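# Reviewer sketch, illustrative only and not part of this patch: each operator
# in this file takes and returns a dict keyed by 'image' (the resize ops also
# record 'shape'), so a preprocessing pipeline is simply a fold over the
# configured ops. rknn/text_recognition.py builds its preprocess_funcs the same
# way from RKNN_DET_PREPROCESS_CONFIG. The import path and the dummy input
# below are assumptions; DetResizeForTest is defined further down in this file.
import numpy as np
from det_utils.operators import DetResizeForTest, NormalizeImage

ops = [
    DetResizeForTest(image_shape=(480, 480)),      # resize and record h/w ratios
    NormalizeImage(std=[1., 1., 1.], mean=[0., 0., 0.],
                   scale='1.', order='hwc'),       # (img * scale - mean) / std
]

data = {'image': np.zeros((720, 1280, 3), dtype=np.uint8)}
for op in ops:
    data = op(data)
# data['image'] is now float32 with shape (480, 480, 3); data['shape'] holds
# [[src_h, src_w, ratio_h, ratio_w]] for mapping boxes back to source pixels.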
+class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.square_input = True + self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + + + + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + if len(data['shape'].shape) == 1: + data['shape'] = np.expand_dims(data['shape'], axis=0) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'min': + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. 
+ elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h,w) + else: + raise Exception('not support limit type, image ') + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + + +class Pad_to_max_len(object): + def __init__(self, **kwargs): + super(Pad_to_max_len, self).__init__() + self.max_h = kwargs['max_h'] + self.max_w = kwargs['max_w'] + + def __call__(self, data): + img = data['image'] + if img.shape[-1] == 3: + # hwc + 
if img.shape[0] != self.max_h:
+                # pad both height and width up to (max_h, max_w)
+                pad_h = self.max_h - img.shape[0]
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+            if img.shape[1] < self.max_w:
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, 0), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+
+        elif img.shape[0] == 3:
+            # chw: after the transpose the layout is hwc (shape[0]=h, shape[1]=w)
+            img = img.transpose((1, 2, 0))
+            if img.shape[0] != self.max_h:
+                # TODO: support height padding here as well
+                assert False, "not supported"
+            if img.shape[1] < self.max_w:
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, 0), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+
+        else:
+            assert False, "not supported"
+
+        data['image'] = img
+
+        return data
\ No newline at end of file
diff --git a/plugins/rknn/src/rec_utils/__init__.py b/plugins/rknn/src/rec_utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/rknn/src/rec_utils/operators.py b/plugins/rknn/src/rec_utils/operators.py
new file mode 100644
index 0000000000..847055710d
--- /dev/null
+++ b/plugins/rknn/src/rec_utils/operators.py
@@ -0,0 +1,376 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +import cv2 +import numpy as np + + +class DecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data['image'] = img + return data + + +class NRTRDecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + + img = cv2.imdecode(img, 1) + + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data['image'] = img + return data + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + + assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" + + i = img.astype('float32') + i = i * self.scale + i = i - self.mean + i = i / self.std + data['image'] = i + return data + + +class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.square_input = True + 
self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + + + + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + if len(data['shape'].shape) == 1: + data['shape'] = np.expand_dims(data['shape'], axis=0) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'min': + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. 
+ elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h,w) + else: + raise Exception('not support limit type, image ') + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + + +class Pad_to_max_len(object): + def __init__(self, **kwargs): + super(Pad_to_max_len, self).__init__() + self.max_h = kwargs['max_h'] + self.max_w = kwargs['max_w'] + + def __call__(self, data): + img = data['image'] + if img.shape[-1] == 3: + # hwc + 
if img.shape[0] != self.max_h:
+                # pad both height and width up to (max_h, max_w)
+                pad_h = self.max_h - img.shape[0]
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+            if img.shape[1] < self.max_w:
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, 0), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+
+        elif img.shape[0] == 3:
+            # chw: after the transpose the layout is hwc (shape[0]=h, shape[1]=w)
+            img = img.transpose((1, 2, 0))
+            if img.shape[0] != self.max_h:
+                # TODO: support height padding here as well
+                assert False, "not supported"
+            if img.shape[1] < self.max_w:
+                pad_w = self.max_w - img.shape[1]
+                img = np.pad(img, ((0, 0), (0, pad_w), (0, 0)), 'constant', constant_values=0)
+
+        else:
+            assert False, "not supported"
+
+        data['image'] = img
+
+        return data
\ No newline at end of file
diff --git a/plugins/rknn/src/rec_utils/rec_postprocess.py b/plugins/rknn/src/rec_utils/rec_postprocess.py
new file mode 100644
index 0000000000..3aa35853cc
--- /dev/null
+++ b/plugins/rknn/src/rec_utils/rec_postprocess.py
@@ -0,0 +1,814 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+# import paddle
+# from paddle.nn import functional as F
+import re
+
+
+class BaseRecLabelDecode(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+
+        self.character_str = []
+        if character_dict_path is None:
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            dict_character = list(self.character_str)
+        else:
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode('utf-8').strip("\n").strip("\r\n")
+                    self.character_str.append(line)
+            if use_space_char:
+                self.character_str.append(" ")
+            dict_character = list(self.character_str)
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+        # guard against a None character_dict_path before the substring check
+        if character_dict_path is not None and 'arabic' in character_dict_path:
+            self.reverse = True
+        else:
+            self.reverse = False
+
+    def pred_reverse(self, pred):
+        pred_re = []
+        c_current = ''
+        for c in pred:
+            if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
+                if c_current != '':
+                    pred_re.append(c_current)
+                pred_re.append(c)
+                c_current = ''
+            else:
+                c_current += c
+        if c_current != '':
+            pred_re.append(c_current)
+
+        return ''.join(pred_re[::-1])
+
+    def add_special_char(self, dict_character):
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[ + batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id] + for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + + text = ''.join(char_list) + + if self.reverse: # for arabic rec + text = self.pred_reverse(text) + + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple) or isinstance(preds, list): + preds = preds[-1] + # if isinstance(preds, paddle.Tensor): + # preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + if label is None: + return text + label = self.decode(label) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + + +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs): + super(DistillationCTCLabelDecode, self).__init__(character_dict_path, + use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['ctc'] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class AttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(AttnLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + # if isinstance(preds, paddle.Tensor): + # preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + +class SEEDLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SEEDLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.padding_str = "padding" + self.end_str = "eos" + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding_str, self.unknown + ] + return dict_character + + def get_ignored_tokens(self): + end_idx = self.get_beg_end_flag_idx("eos") + return [end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "sos": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "eos": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + [end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + preds_idx = preds["rec_pred"] + # if isinstance(preds_idx, paddle.Tensor): + # preds_idx = preds_idx.numpy() + if "rec_pred_scores" in preds: + preds_idx = preds["rec_pred"] + preds_prob = preds["rec_pred_scores"] + else: + preds_idx = preds["rec_pred"].argmax(axis=2) + preds_prob = preds["rec_pred"].max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + +class SRNLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SRNLabelDecode, self).__init__(character_dict_path, + use_space_char) + self.max_text_length = kwargs.get('max_text_length', 25) + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds['predict'] + char_num = len(self.character_str) + 2 + # if isinstance(pred, paddle.Tensor): + # pred = pred.numpy() + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) + + preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) + + text = self.decode(preds_idx, preds_prob) + + if label is None: + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + return text + label = self.decode(label) + return text, label + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + +class SARLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SARLabelDecode, self).__init__(character_dict_path, + use_space_char) + + self.rm_symbol = kwargs.get('rm_symbol', False) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(self.end_idx): + if text_prob is None and idx == 0: + continue + else: + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + if self.rm_symbol: + comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') + text = text.lower() + text = comp.sub('', text) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + # if isinstance(preds, paddle.Tensor): + # preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + return [self.padding_idx] + + +class DistillationSARLabelDecode(SARLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs): + super(DistillationSARLabelDecode, self).__init__(character_dict_path, + use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['sar'] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class PRENLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(PRENLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + padding_str = '' # 0 + end_str = '' # 1 + unknown_str = '' # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def decode(self, text_index, text_prob=None): + """ convert text-index into text-label. 
""" + result_list = [] + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == self.end_idx: + break + if text_index[batch_idx][idx] in \ + [self.padding_idx, self.unknown_idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + if len(text) > 0: + result_list.append((text, np.mean(conf_list).tolist())) + else: + # here confidence of empty recog result is 1 + result_list.append(('', 1)) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label) + return text, label + + +class NRTRLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=True, **kwargs): + super(NRTRLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + + if len(preds) == 2: + preds_id = preds[0] + preds_prob = preds[1] + # if isinstance(preds_id, paddle.Tensor): + # preds_id = preds_id.numpy() + # if isinstance(preds_prob, paddle.Tensor): + # preds_prob = preds_prob.numpy() + if preds_id[0][0] == 2: + preds_idx = preds_id[:, 1:] + preds_prob = preds_prob[:, 1:] + else: + preds_idx = preds_id + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + else: + # if isinstance(preds, paddle.Tensor): + # preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank', '', '', ''] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + try: + char_idx = self.character[int(text_index[batch_idx][idx])] + except: + continue + if char_idx == '': # end + break + char_list.append(char_idx) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list).tolist())) + return result_list + + +class ViTSTRLabelDecode(NRTRLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(ViTSTRLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + # if isinstance(preds, paddle.Tensor): + # preds = preds[:, 1:].numpy() + # else: + # preds = preds[:, 1:] + preds = preds[:, 1:].numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['', ''] + dict_character + return dict_character + + +class ABINetLabelDecode(NRTRLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(ABINetLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, dict): + preds = preds['align'][-1].numpy() + # elif isinstance(preds, paddle.Tensor): + # preds = preds.numpy() + # else: + # preds = preds + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label) + return text, label + + def add_special_char(self, dict_character): + dict_character = [''] + dict_character + return dict_character + + +class SPINLabelDecode(AttnLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SPINLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + [self.end_str] + dict_character + return dict_character + + +# class VLLabelDecode(BaseRecLabelDecode): +# """ Convert between text-label and text-index """ + +# def __init__(self, character_dict_path=None, use_space_char=False, +# **kwargs): +# super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) +# self.max_text_length = kwargs.get('max_text_length', 25) +# self.nclass = len(self.character) + 1 +# self.character = self.character[10:] + self.character[ +# 1:10] + [self.character[0]] + +# def decode(self, text_index, text_prob=None, is_remove_duplicate=False): +# """ convert text-index into text-label. 
""" +# result_list = [] +# ignored_tokens = self.get_ignored_tokens() +# batch_size = len(text_index) +# for batch_idx in range(batch_size): +# selection = np.ones(len(text_index[batch_idx]), dtype=bool) +# if is_remove_duplicate: +# selection[1:] = text_index[batch_idx][1:] != text_index[ +# batch_idx][:-1] +# for ignored_token in ignored_tokens: +# selection &= text_index[batch_idx] != ignored_token + +# char_list = [ +# self.character[text_id - 1] +# for text_id in text_index[batch_idx][selection] +# ] +# if text_prob is not None: +# conf_list = text_prob[batch_idx][selection] +# else: +# conf_list = [1] * len(selection) +# if len(conf_list) == 0: +# conf_list = [0] + +# text = ''.join(char_list) +# result_list.append((text, np.mean(conf_list).tolist())) +# return result_list + +# def __call__(self, preds, label=None, length=None, *args, **kwargs): +# if len(preds) == 2: # eval mode +# text_pre, x = preds +# b = text_pre.shape[1] +# lenText = self.max_text_length +# nsteps = self.max_text_length + +# if not isinstance(text_pre, paddle.Tensor): +# text_pre = paddle.to_tensor(text_pre, dtype='float32') + +# out_res = paddle.zeros( +# shape=[lenText, b, self.nclass], dtype=x.dtype) +# out_length = paddle.zeros(shape=[b], dtype=x.dtype) +# now_step = 0 +# for _ in range(nsteps): +# if 0 in out_length and now_step < nsteps: +# tmp_result = text_pre[now_step, :, :] +# out_res[now_step] = tmp_result +# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1) +# for j in range(b): +# if out_length[j] == 0 and tmp_result[j] == 0: +# out_length[j] = now_step + 1 +# now_step += 1 +# for j in range(0, b): +# if int(out_length[j]) == 0: +# out_length[j] = nsteps +# start = 0 +# output = paddle.zeros( +# shape=[int(out_length.sum()), self.nclass], dtype=x.dtype) +# for i in range(0, b): +# cur_length = int(out_length[i]) +# output[start:start + cur_length] = out_res[0:cur_length, i, :] +# start += cur_length +# net_out = output +# length = out_length + +# else: # train mode +# net_out = preds[0] +# length = length +# net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)]) +# text = [] +# if not isinstance(net_out, paddle.Tensor): +# net_out = paddle.to_tensor(net_out, dtype='float32') +# net_out = F.softmax(net_out, axis=1) +# for i in range(0, length.shape[0]): +# preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum( +# ) + length[i])].topk(1)[1][:, 0].tolist() +# preds_text = ''.join([ +# self.character[idx - 1] +# if idx > 0 and idx <= len(self.character) else '' +# for idx in preds_idx +# ]) +# preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum( +# ) + length[i])].topk(1)[0][:, 0] +# preds_prob = paddle.exp( +# paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)) +# text.append((preds_text, preds_prob.numpy()[0])) +# if label is None: +# return text +# label = self.decode(label) +# return text, label + diff --git a/plugins/rknn/src/requirements.txt b/plugins/rknn/src/requirements.txt index 80633c4d46..b033d93fd3 100644 --- a/plugins/rknn/src/requirements.txt +++ b/plugins/rknn/src/requirements.txt @@ -1,2 +1,6 @@ https://github.com/airockchip/rknn-toolkit2/raw/v2.0.0-beta0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.0.0b0-cp310-cp310-linux_aarch64.whl -pillow==10.3.0 \ No newline at end of file +pillow==10.3.0 +six==1.16.0 +shapely== 2.0.4 +pyclipper==1.3.0.post5 +opencv-python-headless==4.9.0.80 \ No newline at end of file diff --git a/plugins/rknn/src/rknn/plugin.py b/plugins/rknn/src/rknn/plugin.py index f6ec2fadbc..05b2e5a10b 100644 --- 
a/plugins/rknn/src/rknn/plugin.py +++ b/plugins/rknn/src/rknn/plugin.py @@ -2,8 +2,8 @@ import concurrent.futures import os import platform -import queue import threading +import traceback from typing import Any, Coroutine, List, Tuple import urllib.request @@ -14,9 +14,14 @@ from predict import PredictPlugin, Prediction from predict.rectangle import Rectangle +import scrypted_sdk +from scrypted_sdk import DeviceProvider, ScryptedDeviceType, ScryptedInterface + # for Rockchip-optimized models, the postprocessing is slightly different from the original models from .optimized.yolo import post_process, IMG_SIZE, CLASSES +from .text_recognition import TEXT_RECOGNITION_NATIVE_ID, TextRecognition + rknn_verbose = False lib_download = 'https://github.com/airockchip/rknn-toolkit2/raw/v2.0.0-beta0/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so' @@ -53,13 +58,16 @@ def ensure_compatibility_and_get_cpu(): raise -class RKNNPlugin(PredictPlugin): +class RKNNPlugin(PredictPlugin, DeviceProvider): labels = {i: CLASSES[i] for i in range(len(CLASSES))} rknn_runtimes: dict + executor: concurrent.futures.ThreadPoolExecutor + text_recognition: TextRecognition = None + cpu: str def __init__(self, nativeId=None): super().__init__(nativeId) - cpu = ensure_compatibility_and_get_cpu() + self.cpu = ensure_compatibility_and_get_cpu() model = 'yolov6n' self.rknn_runtimes = {} @@ -72,7 +80,7 @@ def __init__(self, nativeId=None): else: raise RuntimeError('librknnrt.so not found. Please download it from {} and place it at {}'.format(lib_download, lib_path)) - model_download = model_download_tmpl.format(model, cpu) + model_download = model_download_tmpl.format(model, self.cpu) model_file = os.path.basename(model_download) model_path = self.downloadFile(model_download, model_file) print('Using model {}'.format(model_path)) @@ -101,7 +109,33 @@ def executor_initializer(): self.rknn_runtimes[thread_name] = rknn print('RKNNLite runtime initialized on thread {}'.format(thread_name)) - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3, initializer=executor_initializer) + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix=type(self).__name__, initializer=executor_initializer) + + asyncio.create_task(self.discoverRecognitionModels()) + + async def discoverRecognitionModels(self) -> None: + devices = [ + { + "nativeId": TEXT_RECOGNITION_NATIVE_ID, + "name": "Rockchip NPU Text Recognition", + "type": ScryptedDeviceType.API.value, + "interfaces": [ + ScryptedInterface.ObjectDetection.value, + ], + } + ] + await scrypted_sdk.deviceManager.onDevicesChanged({ + "devices": devices, + }) + + async def getDevice(self, nativeId: str) -> TextRecognition: + try: + if nativeId == TEXT_RECOGNITION_NATIVE_ID: + self.text_recognition = self.text_recognition or TextRecognition(nativeId, self.cpu) + return self.text_recognition + except: + traceback.print_exc() + raise def get_input_details(self) -> Tuple[int]: return (IMG_SIZE[0], IMG_SIZE[1], 3) diff --git a/plugins/rknn/src/rknn/text_recognition.py b/plugins/rknn/src/rknn/text_recognition.py new file mode 100644 index 0000000000..22672c9cf4 --- /dev/null +++ b/plugins/rknn/src/rknn/text_recognition.py @@ -0,0 +1,264 @@ +import asyncio +import concurrent.futures +import math +import os +import threading +import traceback +from typing import Any, Callable, List + +import numpy as np +from PIL import Image, ImageOps +from rknnlite.api import RKNNLite + +from common.text import skew_image, crop_text, calculate_y_change +from predict 
diff --git a/plugins/rknn/src/rknn/text_recognition.py b/plugins/rknn/src/rknn/text_recognition.py
new file mode 100644
index 0000000000..22672c9cf4
--- /dev/null
+++ b/plugins/rknn/src/rknn/text_recognition.py
@@ -0,0 +1,264 @@
+import asyncio
+import concurrent.futures
+import math
+import os
+import threading
+import traceback
+from typing import Any, Callable, List
+
+import numpy as np
+from PIL import Image, ImageOps
+from rknnlite.api import RKNNLite
+
+from common.text import skew_image, crop_text, calculate_y_change
+from predict import Prediction
+from predict.rectangle import Rectangle
+from predict.text_recognize import TextRecognition
+import scrypted_sdk
+from scrypted_sdk.types import ObjectsDetected, ObjectDetectionResult
+import det_utils.operators
+import det_utils.db_postprocess
+import rec_utils.operators
+import rec_utils.rec_postprocess
+
+
+TEXT_RECOGNITION_NATIVE_ID = "rknntextrecognition"
+DET_IMG_SIZE = (480, 480)
+
+RKNN_DET_PREPROCESS_CONFIG = [
+    {
+        'DetResizeForTest': {
+            'image_shape': DET_IMG_SIZE
+        }
+    },
+    {
+        'NormalizeImage': {
+            'std': [1., 1., 1.],
+            'mean': [0., 0., 0.],
+            'scale': '1.',
+            'order': 'hwc'
+        }
+    }
+]
+
+RKNN_DET_POSTPROCESS_CONFIG = {
+    'DBPostProcess': {
+        'thresh': 0.3,
+        'box_thresh': 0.6,
+        'max_candidates': 1000,
+        'unclip_ratio': 1.5,
+        'use_dilation': False,
+        'score_mode': 'fast',
+    }
+}
+
+RKNN_REC_PREPROCESS_CONFIG = [
+    {
+        'NormalizeImage': {
+            'std': [1, 1, 1],
+            'mean': [0, 0, 0],
+            'scale': '1./255.',
+            'order': 'hwc'
+        }
+    }
+]
+
+RKNN_REC_POSTPROCESS_CONFIG = {
+    'CTCLabelDecode': {
+        "character_dict_path": None,  # will be replaced by TextRecognition.__init__()
+        "use_space_char": True
+    }
+}
+
+rknn_verbose = False
+model_download_tmpl = 'https://github.com/bjia56/scrypted-rknn/raw/main/models/{}_{}.rknn'
+chardict_link = 'https://github.com/bjia56/scrypted-rknn/raw/main/models/ppocr_keys_v1.txt'
+
+
+class RKNNText:
+    model_path: str
+    rknn_runtimes: dict
+    executor: concurrent.futures.ThreadPoolExecutor
+    preprocess_funcs: List[Callable]
+    postprocess_func: Callable
+    print: Callable
+
+    def __init__(self, model_path, print) -> None:
+        self.model_path = model_path
+        self.rknn_runtimes = {}
+        self.print = print
+
+        if not self.model_path:
+            raise ValueError('model_path is not set')
+
+        def executor_initializer():
+            thread_name = threading.current_thread().name
+            rknn = RKNNLite(verbose=rknn_verbose)
+            ret = rknn.load_rknn(self.model_path)
+            if ret != 0:
+                raise RuntimeError('Failed to load model: {}'.format(ret))
+
+            ret = rknn.init_runtime()
+            if ret != 0:
+                raise RuntimeError('Failed to init runtime: {}'.format(ret))
+
+            self.rknn_runtimes[thread_name] = rknn
+            self.print('RKNNLite runtime initialized on thread {}'.format(thread_name))
+
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix=type(self).__name__, initializer=executor_initializer)
+
+    def detect(self, img):
+        def do_detect(img):
+            model_input = img
+            for p in self.preprocess_funcs:
+                model_input = p(model_input)
+
+            rknn = self.rknn_runtimes[threading.current_thread().name]
+            output = rknn.inference(inputs=[np.expand_dims(model_input['image'], axis=0)])
+
+            return self.postprocess_func(output, model_input['shape'], model_input['image'].shape)
+
+        future = self.executor.submit(do_detect, {'image': img, 'shape': img.shape})
+        return future
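RKNNText.detect returns a concurrent.futures.Future that resolves on a pool thread owning its own RKNNLite runtime. A minimal sketch of bridging that future into asyncio, mirroring what detect_once does later in this file:

    import asyncio
    import numpy as np

    async def run(model: RKNNText, frame: np.ndarray):
        # wrap_future adapts the executor's Future so it can be awaited.
        return await asyncio.wrap_future(model.detect(frame))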
+
+
+class RKNNDetection(RKNNText):
+    db_postprocess = None
+    det_postprocess = None
+
+    def __init__(self, model_path, print):
+        super().__init__(model_path, print)
+
+        self.preprocess_funcs = []
+        for item in RKNN_DET_PREPROCESS_CONFIG:
+            for key in item:
+                pclass = getattr(det_utils.operators, key)
+                p = pclass(**item[key])
+                self.preprocess_funcs.append(p)
+
+        self.db_postprocess = det_utils.db_postprocess.DBPostProcess(**RKNN_DET_POSTPROCESS_CONFIG['DBPostProcess'])
+        self.det_postprocess = det_utils.db_postprocess.DetPostProcess()
+
+        def postprocess(output, model_shape, img_shape):
+            preds = {'maps': output[0].astype(np.float32)}
+            result = self.db_postprocess(preds, model_shape)
+            return self.det_postprocess.filter_tag_det_res(result[0]['points'], img_shape)
+        self.postprocess_func = postprocess
+
+
+class RKNNRecognition(RKNNText):
+    ctc_postprocess = None
+
+    def __init__(self, model_path, print):
+        super().__init__(model_path, print)
+
+        self.preprocess_funcs = []
+        for item in RKNN_REC_PREPROCESS_CONFIG:
+            for key in item:
+                pclass = getattr(rec_utils.operators, key)
+                p = pclass(**item[key])
+                self.preprocess_funcs.append(p)
+
+        self.ctc_postprocess = rec_utils.rec_postprocess.CTCLabelDecode(**RKNN_REC_POSTPROCESS_CONFIG['CTCLabelDecode'])
+
+        def postprocess(output, model_shape, img_shape):
+            preds = output[0].astype(np.float32)
+            return self.ctc_postprocess(preds)
+        self.postprocess_func = postprocess
+
+
+async def prepare_text_result(d: ObjectDetectionResult, image: scrypted_sdk.Image, skew_angle: float):
+    textImage = await crop_text(d, image)
+
+    skew_height_change = calculate_y_change(d["boundingBox"][3], skew_angle)
+    skew_height_change = math.floor(skew_height_change)
+    textImage = skew_image(textImage, skew_angle)
+    # crop skew_height_change from the bottom (positive) or top (negative)
+    if skew_height_change > 0:
+        textImage = textImage.crop((0, 0, textImage.width, textImage.height - skew_height_change))
+    elif skew_height_change < 0:
+        textImage = textImage.crop((0, -skew_height_change, textImage.width, textImage.height))
+
+    new_height = 48
+    new_width = int(textImage.width * new_height / textImage.height)
+    textImage = textImage.resize((new_width, new_height), resample=Image.LANCZOS).convert("L")
+
+    new_width = 320
+    # calculate padding dimensions
+    padding = (0, 0, new_width - textImage.width, 0)
+    # todo: clamp entire edge rather than just center
+    edge_color = textImage.getpixel((textImage.width - 1, textImage.height // 2))
+    # pad image
+    textImage = ImageOps.expand(textImage, padding, fill=edge_color)
+    # pil to numpy
+    image_array = np.array(textImage)
+    image_array = image_array.reshape(textImage.height, textImage.width, 1)
+    image_tensor = image_array  # .transpose((2, 0, 1)) / 255
+
+    # test normalize contrast
+    # image_tensor = (image_tensor - np.min(image_tensor)) / (np.max(image_tensor) - np.min(image_tensor))
+
+    image_tensor = (image_tensor - 0.5) / 0.5
+
+    return image_tensor
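A worked example of the recognizer input geometry above, with illustrative sizes: a 120x30 text crop scales to the fixed height of 48 (width becomes 192), converts to grayscale, and is right-padded to the fixed 320 width expected by the recognition model:

    from PIL import Image, ImageOps

    crop = Image.new("RGB", (120, 30))                 # stand-in for a text crop
    w = int(crop.width * 48 / crop.height)             # 120 * 48 / 30 = 192
    crop = crop.resize((w, 48), resample=Image.LANCZOS).convert("L")
    crop = ImageOps.expand(crop, (0, 0, 320 - crop.width, 0), fill=0)
    assert crop.size == (320, 48)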
+
+
+# Subclasses (and intentionally shadows) predict.text_recognize.TextRecognition.
+class TextRecognition(TextRecognition):
+    detection: RKNNDetection
+    recognition: RKNNRecognition
+
+    def __init__(self, nativeId=None, cpu=""):
+        super().__init__(nativeId)
+
+        model_download = model_download_tmpl.format("ppocrv4_det", cpu)
+        model_file = os.path.basename(model_download)
+        det_model_path = self.downloadFile(model_download, model_file)
+
+        model_download = model_download_tmpl.format("ppocrv4_rec", cpu)
+        model_file = os.path.basename(model_download)
+        rec_model_path = self.downloadFile(model_download, model_file)
+
+        chardict_file = os.path.basename(chardict_link)
+        chardict_path = self.downloadFile(chardict_link, chardict_file)
+        RKNN_REC_POSTPROCESS_CONFIG['CTCLabelDecode']['character_dict_path'] = chardict_path
+
+        self.detection = RKNNDetection(det_model_path, lambda *args, **kwargs: self.print(*args, **kwargs))
+        self.recognition = RKNNRecognition(rec_model_path, lambda *args, **kwargs: self.print(*args, **kwargs))
+        self.inputheight = DET_IMG_SIZE[0]
+        self.inputwidth = DET_IMG_SIZE[1]
+
+    async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> ObjectsDetected:
+        detections = await asyncio.wrap_future(
+            self.detection.detect(np.array(input)), loop=asyncio.get_event_loop()
+        )
+
+        predictions: List[Prediction] = []
+        for box in detections:
+            # each box is a quadrilateral: top-left, top-right, bottom-right, bottom-left
+            tl, tr, br, bl = box
+            l = min(tl[0], bl[0])
+            t = min(tl[1], tr[1])
+            r = max(tr[0], br[0])
+            b = max(bl[1], br[1])
+
+            pred = Prediction(0, 1, Rectangle(l, t, r, b))
+            predictions.append(pred)
+
+        return self.create_detection_result(predictions, src_size, cvss)
+
+    async def setLabel(
+        self, d: ObjectDetectionResult, image: scrypted_sdk.Image, skew_angle: float
+    ):
+        try:
+            image_tensor = await prepare_text_result(d, image, skew_angle)
+            preds = await asyncio.wrap_future(
+                self.recognition.detect(image_tensor), loop=asyncio.get_event_loop()
+            )
+            # CTCLabelDecode yields (text, confidence) pairs; take the first text.
+            d["label"] = preds[0][0]
+        except Exception:
+            traceback.print_exc()
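A worked example of the quadrilateral reduction in detect_once, with illustrative coordinates: the detector returns (tl, tr, br, bl) corners, and the plugin reports their axis-aligned bounds as the detection rectangle:

    tl, tr, br, bl = (12, 20), (98, 18), (99, 44), (11, 46)
    l, t = min(tl[0], bl[0]), min(tl[1], tr[1])   # 11, 18
    r, b = max(tr[0], br[0]), max(bl[1], br[1])   # 99, 46
    assert (l, t, r, b) == (11, 18, 99, 46)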