Image orientation classification goes live; optimize maxcompute-related code

fangjiasheng 2 years ago
Commit
f593404a55
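
One of the two changes in this commit replaces the hard-coded MAX_COMPUTE = False constants in the interface modules with a single shared flag in format_convert/max_compute_config.py. A condensed sketch of the pattern now used at the top of those modules, taken from the hunks below (TF1-style API as in this repo; the platform-specific session setup is abbreviated):

    import os
    import sys

    # force CPU before TensorFlow is imported, as atc/idc do in this commit
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
    from format_convert.max_compute_config import max_compute
    import tensorflow as tf

    MAX_COMPUTE = max_compute

    if not MAX_COMPUTE:
        # local/GPU deployment: apply the usual TensorFlow memory limits
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))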

+ 5 - 7
atc/atc_interface.py

@@ -4,13 +4,11 @@ import os
 import sys
 import time
 import traceback
-from glob import glob
-import numpy as np
-import cv2
-import tensorflow as tf
-from PIL import Image
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-MAX_COMPUTE = False
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from format_convert.max_compute_config import max_compute
+import tensorflow as tf
+MAX_COMPUTE = max_compute
 
 if not MAX_COMPUTE:
     # tensorflow memory settings
@@ -31,7 +29,7 @@ if not MAX_COMPUTE:
         os.environ['CUDA_CACHE_DISABLE'] = str(0)
         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
         sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+
 from format_convert import _global
 from format_convert.utils import log, bytes2np, request_post, get_md5_from_bytes
 from flask import Flask, request

+ 134 - 115
format_convert/convert.py

@@ -37,23 +37,16 @@ logging.getLogger("pdfminer").setLevel(logging.WARNING)
 from format_convert.table_correct import *
 from format_convert.wrapt_timeout_decorator import *
 from format_convert import _global
+from format_convert.max_compute_config import max_compute
 
-MAX_COMPUTE = False
 
-port_num = [0]
-def choose_port():
-    process_num = 4
-    if port_num[0] % process_num == 0:
-        _url = local_url + ":15011"
-    elif port_num[0] % process_num == 1:
-        _url = local_url + ":15012"
-    elif port_num[0] % process_num == 2:
-        _url = local_url + ":15013"
-    elif port_num[0] % process_num == 3:
-        _url = local_url + ":15014"
+MAX_COMPUTE = max_compute
 
-    port_num[0] = port_num[0] + 1
-    return _url
+
+if get_platform() == "Windows":
+    globals().update({"time_out": 1000})
+else:
+    globals().update({"time_out": 300})
 
 
 @memory_decorator
@@ -124,72 +117,6 @@ def to_html(path, text):
         f.write("</body>")
 
 
-def resize_image(image_path, size):
-    try:
-        image_np = cv2.imread(image_path)
-        # print(image_np.shape)
-        width = image_np.shape[1]
-        height = image_np.shape[0]
-        h_w_rate = height / width
-
-        # width_standard = 900
-        # height_standard = 1400
-
-        width_standard = size[1]
-        height_standard = size[0]
-
-        width_new = int(height_standard / h_w_rate)
-        height_new = int(width_standard * h_w_rate)
-
-        if width > width_standard:
-            image_np = cv2.resize(image_np, (width_standard, height_new))
-        elif height > height_standard:
-            image_np = cv2.resize(image_np, (width_new, height_standard))
-
-        cv2.imwrite(image_path, image_np)
-        # print("resize_image", image_np.shape)
-        return
-    except Exception as e:
-        log("resize_image")
-        print("resize_image", e, global_type)
-        return
-
-
-def remove_red_seal(image_np):
-    """
-    Remove red seals
-    """
-    # get the red channel
-    blue_c, green_c, red_c = cv2.split(image_np)
-
-    # pass the extra cv2.THRESH_OTSU flag with thresh set to 0 so the algorithm finds the optimal threshold
-    thresh, ret = cv2.threshold(red_c, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-    # print("remove_red_seal thresh", thresh)
-
-    # in practice adjusting to about 95% works a bit better
-    filter_condition = int(thresh * 0.98)
-    thresh1, red_thresh = cv2.threshold(red_c, filter_condition, 255, cv2.THRESH_BINARY)
-
-    # convert the image back to 3 channels
-    image_and = np.expand_dims(red_thresh, axis=2)
-    image_and = np.concatenate((image_and, image_and, image_and), axis=-1)
-    # print(image_and.shape)
-
-    # dilation
-    gray = cv2.cvtColor(image_and, cv2.COLOR_RGB2GRAY)
-    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
-    erode = cv2.erode(gray, kernel)
-    cv2.imshow("erode", erode)
-    cv2.waitKey(0)
-
-    image_and = np.bitwise_and(cv2.bitwise_not(blue_c), cv2.bitwise_not(erode))
-    result_img = cv2.bitwise_not(image_and)
-
-    cv2.imshow("remove_red_seal", result_img)
-    cv2.waitKey(0)
-    return result_img
-
-
 def remove_underline(image_np):
     """
     Remove text underlines
@@ -224,41 +151,6 @@ def remove_underline(image_np):
     return
 
 
-def getMDFFromFile(path):
-    _length = 0
-    try:
-        _md5 = hashlib.md5()
-        with open(path, "rb") as ff:
-            while True:
-                data = ff.read(4096)
-                if not data:
-                    break
-                _length += len(data)
-                _md5.update(data)
-        return _md5.hexdigest(), _length
-    except Exception as e:
-        traceback.print_exc()
-        return None, _length
-
-
-def add_html_format(text_list):
-    new_text_list = []
-    for t in text_list:
-        html_t = "<!DOCTYPE HTML>\n"
-        html_t += '<head><meta charset="UTF-8"></head>\n'
-        html_t += "<body>\n"
-        html_t += t
-        html_t += "\n</body>\n"
-        new_text_list.append(html_t)
-    return new_text_list
-
-
-if get_platform() == "Windows":
-    globals().update({"time_out": 1000})
-else:
-    globals().update({"time_out": 300})
-
-
 # @timeout_decorator.timeout(100, timeout_exception=TimeoutError)
 # @timeout(globals().get("time_out"), timeout_exception=TimeoutError, use_signals=False)
 def unique_temp_file_process(stream, _type, _md5, time_out=300):
@@ -633,6 +525,133 @@ def _convert():
 
 
 def convert(data, ocr_model, otr_model):
+    """
+    Interface return values:
+    {[str], 1}: processed successfully
+    {[-1], 0}: logic/processing error
+    {[-2], 0}: interface call error
+    {[-3], 1}: file format error, cannot be opened
+    {[-4], 0}: timeout reading the file via third-party packages
+    {[-5], 0}: whole conversion process timed out
+    {[-6], 0}: Aliyun UDF queue timed out
+    {[-7], 1}: file requires a password, cannot be opened
+    :return: {"result_html": str([]), "result_text": str([]), "is_success": int}
+    """
+    _global._init()
+    _global.update({"md5": "1"+"0"*15})
+    set_flask_global()
+
+    log("into convert")
+    start_time = time.time()
+    _md5 = _global.get("md5")
+    _type = None
+    try:
+        _time = time.time()
+        # put the models into global variables
+        globals().update({"global_ocr_model": ocr_model})
+        globals().update({"global_otr_model": otr_model})
+
+        stream = base64.b64decode(data.get("file"))
+        _type = data.get("type")
+        _md5 = get_md5_from_bytes(stream)
+        _md5 = _md5[0]
+        _global.update({"md5": _md5})
+
+        if get_platform() == "Windows":
+            # bypass the timeout decorator and call the original function directly
+            # origin_unique_temp_file_process = unique_temp_file_process.__wrapped__
+            # text, swf_images = origin_unique_temp_file_process(stream, _type)
+            try:
+                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+            except TimeoutError:
+                log("convert time out! 300 sec")
+                text = [-5]
+                swf_images = []
+        else:
+            # on Linux the overall conversion timeout is set via the decorator
+            try:
+                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+            except TimeoutError:
+                log("convert time out! 300 sec")
+                text = [-5]
+                swf_images = []
+
+        still_success_code = [-3, -4, -7]
+        if judge_error_code(text):
+            if judge_error_code(text, still_success_code):
+                is_success = 1
+            else:
+                is_success = 0
+            log("md5: " + str(_md5)
+                + " finished result: " + str(text)
+                + " is_success: " + str(is_success) + " "
+                + str(_type) + " "
+                + " " + str(time.time() - start_time))
+            return {"result_html": [str(text[0])], "result_text": [str(text[0])],
+                    "is_success": is_success, "swf_images": str(swf_images)}
+
+        # save the result to result.html
+        text_str = ""
+        for t in text:
+            text_str += t
+        to_html(os.path.dirname(os.path.abspath(__file__)) + "/../result.html", text_str)
+
+        # extract plain text
+        only_text = []
+        for t in text:
+            new_t = BeautifulSoup(t, "lxml").get_text()
+            new_t = re.sub("\n", "", new_t)
+            only_text.append(new_t)
+
+        # classify the attachment type
+        classification = from_atc_interface(' '.join(only_text))
+        if judge_error_code(classification):
+            classification = [str(classification[0])]
+
+        # check the length and truncate if too long
+        text = cut_str(text, only_text)
+        only_text = cut_str(only_text, only_text)
+
+        if len(only_text) == 0:
+            only_text = [""]
+
+        if only_text[0] == '' and len(only_text) <= 1:
+            print({"finished result": ["", 0], "is_success": 1}, time.time() - start_time)
+            log("md5: " + str(_md5) + " "
+                + " finished result: ['', 0] is_success: 1 "
+                + str(_type) + " "
+                + str(time.time() - start_time))
+        else:
+            log("md5: " + str(_md5) +
+                " finished result: " + str(only_text)[:20] + " "
+                + str(len(str(text))) + " is_success: 1 "
+                + str(_type) + " "
+                + str(classification) + " "
+                + str(time.time() - start_time))
+
+        return {"result_html": text, "result_text": only_text,
+                "is_success": 1, "swf_images": str(swf_images),
+                "classification": classification}
+
+    except ConnectionError:
+        log("convert post has no data!" + " failed result: [-2] is_success: 0 "
+            + str(time.time() - start_time))
+        return {"result_html": ["-2"], "result_text": ["-2"],
+                "is_success": 0, "swf_images": str([]),
+                "classification": ""}
+    except Exception as e:
+        log("md5: " + str(_md5) + " failed result: [-1] is_success: 0 "
+            + str(_type) + " " +
+            str(time.time() - start_time))
+        traceback.print_exc()
+        return {"result_html": ["-1"], "result_text": ["-1"],
+                "is_success": 0, "swf_images": str([]),
+                "classification": ""}
+    finally:
+        log("finally")
+
+
+def convert_old(data, ocr_model, otr_model):
     """
     Interface return values:
     {[str], 1}: processed successfully
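
The docstring above maps each error code to an is_success flag. A hypothetical caller sketch (not part of this commit) showing how that return structure can be consumed; convert, ocr_model and otr_model are the objects defined in this module:

    import base64

    def convert_file(path, file_type, ocr_model, otr_model):
        with open(path, "rb") as f:
            payload = {"file": base64.b64encode(f.read()), "type": file_type}
        result = convert(payload, ocr_model, otr_model)
        if result["is_success"] == 1:
            # result_html/result_text hold either the extracted content or,
            # for "still success" codes such as -3/-7, the code as a string
            return result["result_html"], result.get("classification")
        # is_success == 0 covers codes such as -1, -2, -5 per the docstring
        raise RuntimeError("convert failed: " + str(result["result_text"]))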

+ 135 - 95
format_convert/convert_image.py

@@ -15,7 +15,7 @@ import traceback
 import cv2
 from isr.pre_process import count_red_pixel
 from format_convert.utils import judge_error_code, add_div, LineTable, get_table_html, get_logger, log, \
-    memory_decorator, pil_resize
+    memory_decorator, pil_resize, np2bytes, ocr_cant_read
 from format_convert.convert_need_interface import from_otr_interface, from_ocr_interface, from_gpu_interface_redis, \
     from_idc_interface, from_isr_interface
 from format_convert.table_correct import get_rotated_image
@@ -84,130 +84,116 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
                 textbox_list.remove(_obj)
         return textbox_list
 
-    log("into image_preprocess")
-    try:
-        if image_np is None:
-            return []
-
-        # overall resolution limit
-        threshold = 2000
-        if image_np.shape[0] > threshold or image_np.shape[1] > threshold:
-            h, w = get_best_predict_size2(image_np, threshold=threshold)
-            log("global image resize " + str(image_np.shape[:2]) + " -> " + str(h) + "," + str(w))
-            image_np = pil_resize(image_np, h, w)
-
+    def idc_process(_image_np):
         # correct image skew and write back to the original image path
         # print("image_process", image_path)
-        g_r_i = get_rotated_image(image_np, image_path)
-        if judge_error_code(g_r_i):
-            if is_from_docx:
-                return []
-            else:
-                return g_r_i
-        image_np = cv2.imread(image_path)
-        image_np_copy = copy.deepcopy(image_np)
-        if image_np is None:
-            return []
+        # g_r_i = get_rotated_image(_image_np, image_path)
+        # if judge_error_code(g_r_i):
+        #     if is_from_docx:
+        #         return []
+        #     else:
+        #         return g_r_i
+        # _image_np = cv2.imread(image_path)
+        # if _image_np is None:
+        #     return []
+        # return _image_np
 
-        # if image_np is None:
+        # if _image_np is None:
         #     return []
-        #
-        # # idc model corrects image skew
-        # image_resize = pil_resize(image_np, 640, 640)
+
+        # idc model corrects image skew/orientation
+        h, w = get_best_predict_size2(_image_np, 1080)
+        image_resize = pil_resize(_image_np, h, w)
         # image_resize_path = image_path.split(".")[0] + "_resize_idc." + image_path.split(".")[-1]
         # cv2.imwrite(image_resize_path, image_resize)
-        #
+
         # with open(image_resize_path, "rb") as f:
         #     image_bytes = f.read()
-        # angle = from_idc_interface(image_bytes)
-        # if judge_error_code(angle):
-        #     if is_from_docx:
-        #         return []
-        #     else:
-        #         return angle
-        # # rotate by the predicted angle
-        # image_pil = Image.fromarray(image_np)
-        # image_np = np.array(image_pil.rotate(angle, expand=1))
-        # # write to disk
+        image_bytes = np2bytes(image_resize)
+        angle = from_idc_interface(image_bytes)
+        if judge_error_code(angle):
+            if is_from_docx:
+                return []
+            else:
+                return angle
+        # rotate by the predicted angle
+        image_pil = Image.fromarray(_image_np)
+        _image_np = np.array(image_pil.rotate(angle, expand=1))
+        # write to disk
         # idc_path = image_path.split(".")[0] + "_idc." + image_path.split(".")[-1]
         # cv2.imwrite(idc_path, image_np)
+        return _image_np
 
+    def isr_process(_image_np):
+        image_np_copy = copy.deepcopy(_image_np)
         # isr model removes seals
         _isr_time = time.time()
-        if count_red_pixel(image_np):
+        if count_red_pixel(_image_np):
             # only run the model when the red pixel count is high enough
             with open(image_path, "rb") as f:
                 image_bytes = f.read()
-            image_np = from_isr_interface(image_bytes)
-            if judge_error_code(image_np):
+            _image_np = from_isr_interface(image_bytes)
+            if judge_error_code(_image_np):
                 if is_from_docx:
                     return []
                 else:
-                    return image_np
+                    return _image_np
             # [1] means no seal was detected, return directly
-            if isinstance(image_np, list) and image_np == [1]:
+            if isinstance(_image_np, list) and _image_np == [1]:
                 log("no seals detected!")
-                image_np = image_np_copy
+                _image_np = image_np_copy
             else:
                 isr_path = image_path.split(".")[0] + "_isr." + image_path.split(".")[-1]
-                cv2.imwrite(isr_path, image_np)
+                cv2.imwrite(isr_path, _image_np)
         log("isr total time "+str(time.time()-_isr_time))
+        return _image_np
 
+    def ocr_process(_image_np):
+        # oversized images cause ocr to run out of memory, resize needed
+        start_time = time.time()
+        # call the ocr model interface
+        image_bytes = np2bytes(_image_np)
+        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=True)
+        if judge_error_code(text_list):
+            return text_list, text_list
+
+        for i in range(len(bbox_list)):
+            point = bbox_list[i]
+            bbox_list[i] = [[int(point[0][0]), int(point[0][1])],
+                            [int(point[1][0]), int(point[1][1])],
+                            [int(point[2][0]), int(point[2][1])],
+                            [int(point[3][0]), int(point[3][1])]]
+        return text_list, bbox_list
+
+    def otr_process(_image_np):
         # otr model detects tables; resize the image to the model input size and write it to another path
-        best_h, best_w = get_best_predict_size(image_np)
-        # image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
-        image_resize = pil_resize(image_np, best_h, best_w)
-        image_resize_path = image_path.split(".")[0] + "_resize_otr." + image_path.split(".")[-1]
-        cv2.imwrite(image_resize_path, image_resize)
+        best_h, best_w = get_best_predict_size(_image_np)
+        image_resize = pil_resize(_image_np, best_h, best_w)
+        # image_resize_path = image_path.split(".")[0] + "_resize_otr." + image_path.split(".")[-1]
+        # cv2.imwrite(image_resize_path, image_resize)
 
         # call the otr model interface
-        with open(image_resize_path, "rb") as f:
-            image_bytes = f.read()
+        # with open(image_resize_path, "rb") as f:
+        #     image_bytes = f.read()
+        image_bytes = np2bytes(image_resize)
         list_line = from_otr_interface(image_bytes, is_from_pdf)
         if judge_error_code(list_line):
-            return list_line
+            if is_from_docx:
+                return []
+            else:
+                return list_line
 
         # scale the otr bboxes back to the original resolution
         start_time = time.time()
-        ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
+        ratio = (_image_np.shape[0]/best_h, _image_np.shape[1]/best_w)
         for i in range(len(list_line)):
             point = list_line[i]
             list_line[i] = [int(point[0]*ratio[1]), int(point[1]*ratio[0]),
                             int(point[2]*ratio[1]), int(point[3]*ratio[0])]
         log("otr resize bbox recover " + str(time.time()-start_time))
+        return list_line
 
-        # oversized images cause ocr to run out of memory, resize needed
-        start_time = time.time()
-        threshold = 3000
-        ocr_resize_flag = 0
-        if image_np.shape[0] >= threshold or image_np.shape[1] >= threshold:
-            ocr_resize_flag = 1
-            best_h, best_w = get_best_predict_size2(image_np, threshold)
-            # image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
-            image_resize = pil_resize(image_np, best_h, best_w)
-            image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
-            cv2.imwrite(image_resize_path, image_resize)
-        log("ocr resize before " + str(time.time()-start_time))
-
-        # call the ocr model interface
-        with open(image_resize_path, "rb") as f:
-            image_bytes = f.read()
-        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=True)
-        if judge_error_code(text_list):
-            return text_list
-
-        # restore the ocr bboxes after resize
-        if ocr_resize_flag:
-            ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
-        else:
-            ratio = (1, 1)
-        for i in range(len(bbox_list)):
-            point = bbox_list[i]
-            bbox_list[i] = [[int(point[0][0]*ratio[1]), int(point[0][1]*ratio[0])],
-                            [int(point[1][0]*ratio[1]), int(point[1][1]*ratio[0])],
-                            [int(point[2][0]*ratio[1]), int(point[2][1]*ratio[0])],
-                            [int(point[3][0]*ratio[1]), int(point[3][1]*ratio[0])]]
-
+    def table_process(list_line, text_list, bbox_list):
         # use the existing routine to build tables
         try:
             from format_convert.convert_tree import TableLine
@@ -229,17 +215,71 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
 
             # merge textboxes on the same line
             list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)
-
-            obj_list = []
-            for table in tables:
-                obj_list.append(_Table(table["table"], table["bbox"]))
-            for text_box in list_text_boxes:
-                if text_box not in obj_in_table:
-                    obj_list.append(_Sentence(text_box.get_text(), text_box.bbox))
-            return obj_list
+            return list_text_boxes, tables, obj_in_table
         except:
             traceback.print_exc()
-            return [-8]
+            return [-8], [-8], [-8]
+
+    log("into image_preprocess")
+    try:
+        if image_np is None:
+            return []
+
+        # overall resolution limit
+        threshold = 2000
+        if image_np.shape[0] > threshold or image_np.shape[1] > threshold:
+            h, w = get_best_predict_size2(image_np, threshold=threshold)
+            log("global image resize " + str(image_np.shape[:2]) + " -> " + str(h) + "," + str(w))
+            image_np = pil_resize(image_np, h, w)
+
+        # seal removal
+        image_np = isr_process(image_np)
+        if isinstance(image_np, list):
+            return image_np
+
+        # text recognition
+        text_list, box_list = ocr_process(image_np)
+        if judge_error_code(text_list):
+            return text_list
+
+        # check whether the ocr result is usable
+        if ocr_cant_read(text_list, box_list):
+            # orientation classification
+            image_np = idc_process(image_np)
+            # cv2.imshow("idc_process", image_np)
+            # cv2.waitKey(0)
+            if isinstance(image_np, list):
+                return image_np
+
+            # text recognition (second pass)
+            text_list1, box_list_1 = ocr_process(image_np)
+            if judge_error_code(text_list1):
+                return text_list1
+
+            # compare character counts
+            # print("ocr process", len("".join(text_list)), len("".join(text_list1)))
+            if len("".join(text_list)) < len("".join(text_list1)):
+                text_list = text_list1
+                box_list = box_list_1
+
+        # table detection
+        line_list = otr_process(image_np)
+        if judge_error_code(line_list):
+            return line_list
+
+        # table generation
+        text_box_list, table_list, obj_in_table_list = table_process(line_list, text_list, box_list)
+        if judge_error_code(table_list):
+            return table_list
+
+        # object generation
+        obj_list = []
+        for table in table_list:
+            obj_list.append(_Table(table["table"], table["bbox"]))
+        for text_box in text_box_list:
+            if text_box not in obj_in_table_list:
+                obj_list.append(_Sentence(text_box.get_text(), text_box.bbox))
+        return obj_list
 
     except Exception as e:
         log("image_preprocess error")

+ 2 - 1
format_convert/convert_need_interface.py

@@ -26,9 +26,10 @@ from ocr.ocr_interface import ocr, OcrModels
 from otr.otr_interface import otr, OtrModels
 from format_convert.libreoffice_interface import office_convert
 import numpy as np
+from format_convert.max_compute_config import max_compute
 
 
-MAX_COMPUTE = False
+MAX_COMPUTE = max_compute
 
 if get_platform() == "Windows":
     FROM_REMOTE = False

+ 2 - 2
format_convert/interface.yml

@@ -25,8 +25,8 @@ MASTER:
     port_no: [1]
 
   IDC:
-    port_start:
-    port_no:
+    port_start: [18020]
+    port_no: [1]
 
   ISR:
     port_start: [18040]

+ 1 - 0
format_convert/max_compute_config.py

@@ -0,0 +1 @@
+max_compute = False

+ 10 - 1
format_convert/test_walk.py

@@ -5,4 +5,13 @@ for root, dirs, files in os.walk("./", topdown=False):
         file_list.append(os.path.join(root, name) + os.sep)
     for name in files:
         file_list.append(os.path.join(root, name))
-print(file_list)
+print(file_list)
+
+
+s = set()
+s.update("1231asdb我深大")
+s.update("g6712")
+
+print(len(s))
+
+print(len("".join(["sdas", "我是觉得", "111"])))

+ 41 - 0
format_convert/utils.py

@@ -2122,6 +2122,47 @@ def np2bytes(image_np):
     return img_bytes
 
 
+def ocr_cant_read(text_list, box_list):
+    """
+    判断ocr因为图片方向无法识别情况
+
+    :param text_list: 文字list
+    :param box_list: 文字框list
+    :return: bool
+    """
+    # no text or boxes at all
+    if not text_list or not box_list:
+        return True
+
+    # judge by bbox aspect ratio
+    box_cnt = 0
+    box_flag = 0
+    for box in box_list:
+        if abs(box[0][1] - box[2][1]) > abs(box[0][0] - box[2][0]):
+            box_cnt += 1
+    if box_cnt >= int(len(box_list) / 2):
+        box_flag = 1
+
+    # judge by the number of distinct recognised characters
+    charac_flag = 0
+    charac_set = set()
+    for text in text_list:
+        charac_set.update(text)
+    if len(charac_set) < 40:
+        charac_flag = 1
+
+    # too few characters
+    if charac_flag:
+        result = True
+    # enough characters, but the boxes are tall and narrow
+    elif box_flag:
+        result = True
+    else:
+        result = False
+    log(result)
+    return result
+
+
 if __name__ == "__main__":
     # strs = r"D:\Project\temp\04384fcc9e8911ecbd2844f971944973\043876ca9e8911eca5e144f971944973_rar\1624114035529.jpeg"
     # print(slash_replace(strs))
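
ocr_cant_read treats a page as unreadable when there is no text at all, when most boxes are taller than they are wide, or when fewer than 40 distinct characters were recognised. A hypothetical illustration with made-up values (boxes are 4-point polygons as returned by from_ocr_interface; log() inside the helper assumes the project's logging globals are initialised):

    from format_convert.utils import ocr_cant_read

    # two tall, narrow boxes plus a tiny character set -> expected True
    text_list = ["abc", "ab"]
    box_list = [[[0, 0], [10, 0], [10, 100], [0, 100]],
                [[20, 0], [30, 0], [30, 120], [20, 120]]]
    print(ocr_cant_read(text_list, box_list))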

+ 13 - 7
idc/idc_interface.py

@@ -6,9 +6,12 @@ import time
 import sys
 import traceback
 from glob import glob
-# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from format_convert.max_compute_config import max_compute
 import tensorflow as tf
-MAX_COMPUTE = False
+
+MAX_COMPUTE = max_compute
 
 if not MAX_COMPUTE:
     # tensorflow memory settings
@@ -29,7 +32,7 @@ if not MAX_COMPUTE:
         os.environ['CUDA_CACHE_DISABLE'] = str(0)
         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
         sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+
 from format_convert import _global
 import cv2
 import numpy as np
@@ -56,7 +59,9 @@ def adjust_direction(image_np, model, if_return_angle=False):
     # image_np = pil_resize(image_np, image_shape[0], image_shape[1])
 
     # get a suitable text region
-    result_list, image_np = get_text_region(image_np, image_shape, 1)
+    result_list, image_np = get_text_region(image_np, image_shape)
+    # cv2.imshow("get_text_region", image_np)
+    # cv2.waitKey(0)
     if not result_list:
         return None
     if len(image_np.shape) < 3:
@@ -96,6 +101,7 @@ def idc(data, model):
         if angle is None:
             angle = 0
         # print(time.time()-start_time)
+        log("idc angle " + str(angle))
         return {"angle": angle}
     except TimeoutError:
         return {"angle": [-5]}
@@ -193,9 +199,9 @@ def test_idc_model(from_remote=False):
                 img = np.array(img.rotate(angle, expand=1))
                 print("angle", angle)
                 print(img.shape)
-                cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
-                cv2.imshow("img", img)
-                cv2.waitKey(0)
+                # cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+                # cv2.imshow("img", img)
+                # cv2.waitKey(0)
             # print(result)
 
 

+ 1 - 1
idc/model.py

@@ -19,7 +19,7 @@ def direction_model(input_shape, output_shape):
     print(input_shape, output_shape)
     # model = mobile_net_v3_tiny(input_shape, output_shape)
     # model = fpn(input_shape, output_shape)
-    model.summary(line_length=100)
+    # model.summary(line_length=100)
     return model
 
 

BIN
idc/models/cnn.h5


+ 58 - 38
idc/pre_process.py

@@ -35,7 +35,7 @@ def get_img_label(img_np, size, cls_num=4):
     # cv2.imshow("origin", img_np)
 
     # get a suitable text region
-    result_list, img_np = get_text_region(img_np, size, 1)
+    result_list, img_np = get_text_region(img_np, size)
     # print(len(result_list), img_np.shape)
     if not result_list:
         return []
@@ -59,7 +59,7 @@ def get_img_label(img_np, size, cls_num=4):
     return img_label_list
 
 
-def get_text_region(img_np, size, top_n=1):
+def get_text_region(img_np, size):
     img_np = remove_black_border(img_np)
     origin_h, origin_w = img_np.shape[:2]
     gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
@@ -103,41 +103,54 @@ def get_text_region(img_np, size, top_n=1):
         box = cv2.boxPoints(rect)
         box = np.int0(box)
         box = box.tolist()
-        box.sort(key=lambda x: (x[0], x[1]))
-        if box[0][0] > box[3][0]:
-            temp = box[0][0]
-            box[0][0] = box[3][0]
-            box[3][0] = temp
-        if box[0][1] > box[3][1]:
-            temp = box[0][1]
-            box[0][1] = box[3][1]
-            box[3][1] = temp
-        region_list.append(box)
-
-    # dilation = np.expand_dims(dilation, axis=-1)
-    # dilation = np.concatenate([dilation, dilation, dilation], axis=-1)
+        new_box = [[10000, 10000], [0, 0]]
+        for p in box:
+            if p[0] < new_box[0][0]:
+                new_box[0][0] = p[0]
+            elif p[0] > new_box[1][0]:
+                new_box[1][0] = p[0]
+
+            if p[1] < new_box[0][1]:
+                new_box[0][1] = p[1]
+            elif p[1] > new_box[1][1]:
+                new_box[1][1] = p[1]
+
+        # box.sort(key=lambda x: (x[0], x[1]))
+        # if box[0][0] > box[3][0]:
+        #     temp = box[0][0]
+        #     box[0][0] = box[3][0]
+        #     box[3][0] = temp
+        # if box[0][1] > box[3][1]:
+        #     temp = box[0][1]
+        #     box[0][1] = box[3][1]
+        #     box[3][1] = temp
+        region_list.append(new_box)
+
+    dilation = np.expand_dims(dilation, axis=-1)
+    dilation = np.concatenate([dilation, dilation, dilation], axis=-1)
     # for box in region_list:
     # #     cv2.drawContours(dilation, [box], 0, (0, 255, 0), 2)
     #     cv2.rectangle(dilation, (box[0][0], box[0][1]), (box[3][0], box[3][1]), (0, 255, 0), 2)
 
-    region_list.sort(key=lambda x: abs((x[3][0] - x[0][0])*(x[3][1] - x[0][1])), reverse=True)
+    region_list.sort(key=lambda x: abs((x[1][0] - x[0][0])*(x[1][1] - x[0][1])), reverse=True)
+    # print("len(region_list)", len(region_list))
 
     # filter text regions
     result_list = []
     h_scale = origin_h / h
     w_scale = origin_w / w
-    i = 0
     for box in region_list:
-        if i >= top_n:
-            break
+        # if i >= 20:
+        #     break
 
         p1 = box[0]
-        p2 = box[3]
+        p2 = box[1]
         # print(p1, p2, abs((p2[1]-p1[1])*(p2[1]-p1[1])), h*w)
 
         # ignore rotated boxes
         if p1[0] >= p2[0] or p1[1] >= p2[1]:
-            # print(p1[0], ">=", p2[0], p1[1], ">=", p2[1])
+            # print(box)
+            # print(p1[0], ">", p2[0], p1[1], ">", p2[1])
             continue
 
         # ignore boxes that are too large
@@ -146,16 +159,19 @@ def get_text_region(img_np, size, top_n=1):
             continue
 
         # ignore regions without enough black pixels
-        if not count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]):
-            # print("black not enough")
-            continue
+        cnt_black = count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]) / (abs(p2[0] - p1[0])*abs(p2[1] - p1[1]))
+        # if cnt_black < abs(p2[0] - p1[0])*abs(p2[1] - p1[1])*0.1:
+        #     print("black not enough")
+        #     continue
+        # if not count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]):
+        #     # print("black not enough")
+        #     continue
 
         p1[1] = int(p1[1] * h_scale)
         p1[0] = int(p1[0] * w_scale)
         p2[1] = int(p2[1] * h_scale)
         p2[0] = int(p2[0] * w_scale)
-        result_list.append([p1, p2])
-        i += 1
+        result_list.append([p1, p2, cnt_black])
 
     # cv2.imshow("result", dilation)
     # cv2.waitKey(0)
@@ -163,14 +179,17 @@ def get_text_region(img_np, size, top_n=1):
     if not result_list:
         return [], None
 
+    result_list.sort(key=lambda x: x[2], reverse=True)
+    # for r in result_list:
+    #     print(r)
     # crop
-    if top_n > 1:
-        result = random.sample(result_list, 1)[0]
-        height_start = result[0][1]
-        width_start = result[0][0]
-    else:
-        height_start = result_list[0][0][1]
-        width_start = result_list[0][0][0]
+    # if top_n > 1:
+    #     result = random.sample(result_list, 1)[0]
+    #     height_start = result[0][1]
+    #     width_start = result[0][0]
+    # else:
+    height_start = result_list[0][0][1]
+    width_start = result_list[0][0][0]
     height, width = size
     gray = gray[height_start:height_start+height, width_start:width_start+width]
 
@@ -334,13 +353,14 @@ def count_black(image_np):
 
         mask = cv2.inRange(image_np, lower_black, upper_black)
         cnt = np.sum(mask != 0)
+        return cnt
         # print("black count", cnt, image_np.shape[0]*image_np.shape[1])
-        if cnt >= image_np.shape[0]*image_np.shape[1]*0.3:
-            return True
-        else:
-            return False
+        # if cnt >= image_np.shape[0]*image_np.shape[1]*0.3:
+        #     return True
+        # else:
+        #     return False
     except:
-        return False
+        return 0
 
 
 def remove_black_border(img_np):
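
For reference, the new_box loop added above collects the axis-aligned bounds of the rotated rectangle returned by cv2.boxPoints. A hypothetical, more compact way to compute the same bounds (not part of the commit):

    import cv2

    def axis_aligned_box(rect):
        # rect is a cv2.minAreaRect result: ((cx, cy), (w, h), angle)
        box = cv2.boxPoints(rect)               # 4 corner points, float32 array
        xs, ys = box[:, 0], box[:, 1]
        # [[x_min, y_min], [x_max, y_max]], the same layout as new_box above
        return [[int(xs.min()), int(ys.min())], [int(xs.max()), int(ys.max())]]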

+ 14 - 9
isr/isr_interface.py

@@ -5,9 +5,11 @@ import time
 import sys
 import traceback
 # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from format_convert.max_compute_config import max_compute
 import tensorflow as tf
 tf.compat.v1.disable_eager_execution()
-MAX_COMPUTE = False
+MAX_COMPUTE = max_compute
 
 if not MAX_COMPUTE:
     # tensorflow memory settings
@@ -29,7 +31,6 @@ if not MAX_COMPUTE:
         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
         sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
 
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert import _global
 import cv2
 import numpy as np
@@ -193,13 +194,17 @@ def _isr():
             globals().update({"global_isr_model": isr_model})
 
         # detect + remove
-        result = isr(data, isr_yolo_model, isr_model)
-        result = result.get("image")
-        if judge_error_code(result):
-            return json.dumps({"image": result})
-        if isinstance(result, list) and result == [1]:
-            return json.dumps({"image": result})
-        img_replace = result
+        # result = isr(data, isr_yolo_model, isr_model)
+        # result = result.get("image")
+        # if judge_error_code(result):
+        #     return json.dumps({"image": result})
+        # if isinstance(result, list) and result == [1]:
+        #     return json.dumps({"image": result})
+        # img_replace = result
+
+        img_replace = isr(data, isr_yolo_model, isr_model).get("image")
+        if isinstance(img_replace, list):
+            return json.dumps({"image": img_replace})
 
         # numpy转为可序列化的string
         success, img_encode = cv2.imencode(".jpg", img_replace)

+ 7 - 28
ocr/ocr_interface.py

@@ -13,7 +13,7 @@ import requests
 import logging
 import numpy as np
 os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
-from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
+from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes, bytes2np
 from flask import Flask, request
 from format_convert import _global
 
@@ -34,8 +34,6 @@ def _ocr():
             log("ocr no data!")
             return json.dumps({"text": str([-9]), "bbox": str([-9])})
         data = request.form.get("data")
-        # img_data = base64.b64decode(data)
-        # _md5 = get_md5_from_bytes(img_data)[0]
         _md5 = request.form.get("md5")
         _global.update({"md5": _md5})
         ocr_model = globals().get("global_ocr_model")
@@ -43,7 +41,6 @@ def _ocr():
             log("----------- init ocr_model ------------")
             ocr_model = OcrModels().get_model()
             globals().update({"global_ocr_model": ocr_model})
-        # text = picture2text(img_data, ocr_model)
         text = ocr(data, ocr_model)
         return json.dumps(text)
     except TimeoutError:
@@ -65,44 +62,24 @@ def ocr(data, ocr_model):
         return {"text": str([-5]), "bbox": str([-5])}
 
 
-flag = 0
 def picture2text(img_data, ocr_model):
     log("into ocr_interface picture2text")
     try:
-        start_time = time.time()
         # convert the binary stream to np.ndarray [np.uint8: 8-bit pixels]
-        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
-        # convert bgr to rgb
-        try:
-            np_images = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-        except cv2.error as e:
-            if "src.empty()" in str(e):
-                log("ocr_interface picture2text image is empty!")
-                return {"text": str([]), "bbox": str([])}
-        # resize
-        # cv2.imshow("before resize", np_images)
-        # print("np_images.shape", np_images.shape)
-
-        # best_h, best_w = get_best_predict_size(np_images)
-        # np_images = cv2.resize(np_images, (best_w, best_h), interpolation=cv2.INTER_AREA)
-
-        # cv2.imshow("after resize", np_images)
-        # print("np_images.shape", np_images.shape)
-        # cv2.waitKey(0)
+        img = bytes2np(img_data)
 
         # predict
-        results = ocr_model.ocr(np_images, det=True, rec=True, cls=False)
+        results = ocr_model.ocr(img, det=True, rec=True, cls=False)
 
         # iterate over the recognition results for each image
         text_list = []
         bbox_list = []
         for line in results:
-            # print("ocr_interface line", line)
             text_list.append(line[-1][0])
             bbox_list.append(line[0])
 
         # inspect the bboxes
-        # img = np.zeros((np_images.shape[1], np_images.shape[0]), np.uint8)
+        # img = np.zeros((img.shape[1], img.shape[0]), np.uint8)
         # img.fill(255)
         # for box in bbox_list:
         #     print(box)
@@ -203,7 +180,7 @@ if __name__ == '__main__':
     # test_ocr_model(False)
     #
     # log("OCR running")
-    file_path = "C:/Users/Administrator/Desktop/test_image/error19.png"
+    file_path = "C:/Users/Administrator/Desktop/test_image/error3.png"
 
     with open(file_path, "rb") as f:
         file_bytes = f.read()
@@ -214,6 +191,8 @@ if __name__ == '__main__':
 
     text_list = eval(result.get("text"))
     box_list = eval(result.get("bbox"))
+    from format_convert.utils import ocr_cant_read
+    print(ocr_cant_read(text_list, box_list))
     print(text_list)
     #
     # new_list = []

+ 12 - 14
otr/otr_interface.py

@@ -2,12 +2,15 @@ import base64
 import json
 import multiprocessing as mp
 import os
+import sys
 import traceback
 # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
 # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from format_convert.max_compute_config import max_compute
 import tensorflow as tf
 
-MAX_COMPUTE = False
+MAX_COMPUTE = max_compute
 
 if not MAX_COMPUTE:
     # tensorflow memory settings
@@ -28,11 +31,9 @@ if not MAX_COMPUTE:
         os.environ['CUDA_CACHE_DISABLE'] = str(0)
         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
         sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
-import sys
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+
 import time
 import logging
-# from table_line import *
 import cv2
 import numpy as np
 from flask import Flask, request
@@ -43,6 +44,7 @@ from otr.table_line import get_points, get_split_line, get_points_row, \
     fix_inner, merge_line, fix_corner, delete_outline, table_net, table_line
 from format_convert import _global
 
+
 # interface configuration
 app = Flask(__name__)
 
@@ -66,14 +68,14 @@ def _otr():
 
         data = request.form.get("data")
         is_from_pdf = request.form.get("is_from_pdf")
-        img_data = base64.b64decode(data)
-        # _md5 = get_md5_from_bytes(img_data)[0]
         _md5 = request.form.get("md5")
         _global.update({"md5": _md5})
-        if is_from_pdf:
-            list_lines = line_detect(img_data, otr_model, prob=0.2)
-        else:
-            list_lines = line_detect(img_data, otr_model, prob=0.5)
+
+        list_lines = otr(data, otr_model, is_from_pdf)
+        # if is_from_pdf:
+        #     list_lines = line_detect(img_data, otr_model, prob=0.2)
+        # else:
+        #     list_lines = line_detect(img_data, otr_model, prob=0.5)
         return json.dumps(list_lines)
     except TimeoutError:
         return json.dumps({"list_line": str([-5])})
@@ -98,8 +100,6 @@ def otr(data, otr_model, is_from_pdf):
         raise TimeoutError
 
 
-flag = 0
-# model_path = "models/table-line.h5"
 def table_detect2(img_data, otr_model):
     log("into otr_interface table_detect")
     start_time = time.time()
@@ -402,8 +402,6 @@ def test_otr_model(from_remote=True):
     print(r)
 
 
-# otr_model = table_net((None, None, 3), 2)
-# otr_model.load_weights(model_path)
 if __name__ == '__main__':
     if len(sys.argv) == 2:
         port = int(sys.argv[1])