소스 검색

1.主接口增加超时时间参数
2.配置文件增加多GPU设置
3.pdf表格连接规则优化
4.pdf删除重复出现页眉页脚
5.pdf提取表格线优化
6.docx判断是否是网页格式
7.ocr增加只识别参数

fangjiasheng 1 년 전
부모
커밋
d435211e3c

+ 12 - 6
format_convert/convert.py

@@ -46,7 +46,7 @@ MAX_COMPUTE = max_compute
 if get_platform() == "Windows":
     globals().update({"time_out": 1000})
 else:
-    globals().update({"time_out": 6000})
+    globals().update({"time_out": 300})
 
 
 @memory_decorator
@@ -408,11 +408,17 @@ def _convert():
         _md5 = get_md5_from_bytes(stream)
         _md5 = _md5[0]
         _global.update({"md5": _md5})
+
         # 指定页码范围
         _page_no = data.get('page_no')
         # if _type not in ['pdf']:
         #     _page_no = None
 
+        # 指定timeout
+        _timeout = data.get('timeout')
+        if _timeout is not None:
+            globals().update({"time_out": _timeout})
+
         # 最终结果截取的最大字节数
         max_bytes = data.get("max_bytes")
 
@@ -569,6 +575,8 @@ def convert(data):
         _type = data.get("type")
         _md5 = get_md5_from_bytes(stream)
         _md5 = _md5[0]
+        _page_no = data.get('page_no')
+        max_bytes = data.get("max_bytes")
         _global.update({"md5": _md5})
 
         if get_platform() == "Windows":
@@ -576,7 +584,7 @@ def convert(data):
             # origin_unique_temp_file_process = unique_temp_file_process.__wrapped__
             # text, swf_images = origin_unique_temp_file_process(stream, _type)
             try:
-                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+                text, swf_images = unique_temp_file_process(stream, _type, _md5, _page_no, time_out=globals().get('time_out'))
             except TimeoutError:
                 log("convert time out! 300 sec")
                 text = [-5]
@@ -584,7 +592,7 @@ def convert(data):
         else:
             # Linux 通过装饰器设置整个转换超时时间
             try:
-                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+                text, swf_images = unique_temp_file_process(stream, _type, _md5, _page_no, time_out=globals().get('time_out'))
             except TimeoutError:
                 log("convert time out! 300 sec")
                 text = [-5]
@@ -624,7 +632,7 @@ def convert(data):
             classification = [str(classification[0])]
 
         # 判断长度,过长截取
-        text = cut_str(text, only_text)
+        text = cut_str(text, only_text, max_bytes)
         only_text = cut_str(only_text, only_text)
 
         if len(only_text) == 0:
@@ -873,8 +881,6 @@ if __name__ == '__main__':
     # log("my ip"+str(ip))
     # ip = "http://" + ip
     ip_port_dict = get_ip_port()
-    ip = "http://127.0.0.1"
-    processes = ip_port_dict.get(ip).get("convert_processes")
 
     set_flask_global()
 

+ 330 - 137
format_convert/convert_docx.py

@@ -1,128 +1,286 @@
-import inspect
 import os
 import sys
 sys.path.append(os.path.dirname(__file__) + "/../")
 from format_convert.convert_tree import _Document, _Sentence, _Page, _Image, _Table
-import logging
 import re
 import traceback
 import xml
 import zipfile
 import docx
-from format_convert.convert_image import picture2text
+from bs4 import BeautifulSoup
 from format_convert.utils import judge_error_code, add_div, get_logger, log, memory_decorator, get_garble_code
 from format_convert.wrapt_timeout_decorator import timeout
+from format_convert.convert_image import ImageConvert
 
 
 def docx2text():
     return
 
 
+def read_rel_image(document_xml_rels):
+    if not document_xml_rels:
+        return {}
+
+    # 获取映射文件里的关系 Id-Target
+    image_rel_dict = {}
+    for rel in document_xml_rels:
+        if 'Relationship' in str(rel):
+            _id = rel.get("Id")
+            _target = rel.get("Target")
+            _type = rel.get("Type")
+            if 'image' in _type:
+                image_rel_dict[_id] = _target
+    return image_rel_dict
+
+
+def read_no_start(numbering_xml):
+    """
+    读取编号组的起始值
+
+    :return:
+    """
+    if not numbering_xml:
+        return {}
+
+    # 获取虚拟-真实id映射关系
+    w_num_list = numbering_xml.getElementsByTagName("w:num")
+    abstract_real_id_dict = {}
+    for w_num in w_num_list:
+        w_num_id = w_num.getAttribute("w:numId")
+        w_abstract_num_id = w_num.getElementsByTagName('w:abstractNumId')[0].getAttribute("w:val")
+        abstract_real_id_dict[w_abstract_num_id] = w_num_id
+
+    # 获取虚拟id的开始编号
+    w_abstract_num_list = numbering_xml.getElementsByTagName("w:abstractNum")
+    abstract_id_level_dict = {}
+    for w_abstract_num in w_abstract_num_list:
+        w_abstract_num_id = w_abstract_num.getAttribute("w:abstractNumId")
+        w_lvl_list = w_abstract_num.getElementsByTagName("w:lvl")
+        level_start_dict = {}
+        for w_lvl in w_lvl_list:
+            w_ilvl_value = w_lvl.getAttribute('w:ilvl')
+            if w_lvl.getElementsByTagName("w:start"):
+                w_ilvl_start_num = w_lvl.getElementsByTagName("w:start")[0].getAttribute("w:val")
+                level_start_dict[int(w_ilvl_value)] = int(w_ilvl_start_num)
+        abstract_id_level_dict[w_abstract_num_id] = level_start_dict
+
+    # 映射回真实id
+    real_id_level_start_dict = {}
+    for abstract_id in abstract_real_id_dict.keys():
+        real_id = abstract_real_id_dict.get(abstract_id)
+        level_start_dict = abstract_id_level_dict.get(abstract_id)
+        if level_start_dict:
+            real_id_level_start_dict[int(real_id)] = level_start_dict
+
+    return real_id_level_start_dict
+
+
+def read_p_text(unique_type_dir, p_node, _last_node_level, _num_pr_dict, numbering_xml, document_xml_rels,
+                is_sdt=False):
+    """
+    读取w:p下的文本,包括编号
+
+    :param unique_type_dir:
+    :param p_node:
+    :param _last_node_level:
+    :param _num_pr_dict:
+    :param numbering_xml:
+    :param document_xml_rels:
+    :param is_sdt:
+    :return:
+    """
+    _text_list = []
+    _order_list = []
+
+    # 文本的编号(如果有编号的话)
+    text_no = ''
+
+    # 获取编号组的起始值
+    id_level_start_dict = read_no_start(numbering_xml)
+    # print('_num_pr_dict', _num_pr_dict)
+
+    # 提取编号 组-层级-序号
+    num_pr = p_node.getElementsByTagName("w:numPr")
+    if num_pr:
+        num_pr = num_pr[0]
+        if num_pr.getElementsByTagName("w:numId"):
+            group_id = int(num_pr.getElementsByTagName("w:numId")[0].getAttribute("w:val"))
+            if group_id >= 1:
+                node_level = num_pr.getElementsByTagName("w:ilvl")
+                if node_level:
+                    node_level = int(node_level[0].getAttribute("w:val"))
+                    # print('group_id', group_id, 'node_level', node_level, 'last_node_level', _last_node_level)
+                    if group_id in _num_pr_dict.keys():
+                        if node_level == 0 and node_level not in _num_pr_dict[group_id].keys():
+                            _num_pr_dict[group_id][node_level] = 1
+                        if _last_node_level != 0 and node_level < _last_node_level:
+                            # print('重置', 'group_id', group_id, 'last_node_level', last_node_level)
+                            # 需循环重置node_level到last_node_level之间的level
+                            for l in range(node_level+1, _last_node_level+1):
+                                _num_pr_dict[group_id][l] = 0
+                            if _num_pr_dict[group_id].get(node_level):
+                                _num_pr_dict[group_id][node_level] += 1
+                            else:
+                                pass
+                                # print('group_id, node_level', group_id, node_level)
+                        elif node_level in _num_pr_dict[group_id].keys():
+                            _num_pr_dict[group_id][node_level] += 1
+                        else:
+                            _num_pr_dict[group_id][node_level] = 1
+                    else:
+                        _num_pr_dict[group_id] = {node_level: 1}
+                    # print(num_pr_dict[group_id])
+                    for level in range(node_level+1):
+                        # 当前level下有多少个node
+                        if level not in _num_pr_dict[group_id]:
+                            if level not in id_level_start_dict[group_id]:
+                                continue
+                            else:
+                                level_node_cnt = id_level_start_dict[group_id][level]
+                        else:
+                            level_node_cnt = _num_pr_dict[group_id][level]
+
+                        if id_level_start_dict.get(group_id) and id_level_start_dict.get(group_id).get(level) and _num_pr_dict.get(group_id).get(level):
+                            start_no = id_level_start_dict.get(group_id).get(level)
+                            level_node_cnt += start_no - 1
+                        # print('level_node_cnt', level_node_cnt)
+                        text_no += str(level_node_cnt) + '.'
+                        # print('text_no', text_no)
+                    _last_node_level = node_level
+
+    # text = p_node.getElementsByTagName("w:t")
+    # picture = p_node.getElementsByTagName("wp:docPr")
+    # if text:
+    #     _order_list.append("w:t")
+    #     temp_text = ""
+    #     if is_sdt and len(text) == 2:
+    #         if len(text[0].childNodes) > 0 and len(text[1].childNodes) > 0:
+    #             temp_text += text[0].childNodes[0].nodeValue + '.'*20 + text[1].childNodes[0].nodeValue
+    #     else:
+    #         for t in text:
+    #             if len(t.childNodes) > 0:
+    #                 temp_text += t.childNodes[0].nodeValue
+    #             else:
+    #                 continue
+    #     if text_no:
+    #         temp_text = text_no + ' ' + temp_text
+    #     _text_list.append(temp_text)
+    # # 只有序号
+    # elif len(text_no) >= 2:
+    #     _text_list.append(text_no[:-1])
+    #
+    # if picture:
+    #     _order_list.append("wp:docPr")
+    #
+    # for line1 in p_node.childNodes:
+    #     if "w:r" in str(line1):
+    #         picture1 = line1.getElementsByTagName("w:pict")
+    #         if picture1:
+    #             _order_list.append("wp:docPr")
+
+    p_node_text = ''
+    has_html = False
+    # 编号先加上
+    if text_no:
+        p_node_text += text_no
+    text = p_node.getElementsByTagName("w:t")
+    # 目录页单特殊生成
+    if is_sdt and len(text) == 2:
+        p_node_text += text[0].childNodes[0].nodeValue + '.'*20 + text[1].childNodes[0].nodeValue
+    # 正常页面
+    else:
+        image_rel_dict = read_rel_image(document_xml_rels)
+        p_node_all = p_node.getElementsByTagName("*")
+        for node in p_node_all:
+            # 文本
+            if "w:t" in str(node).split(' '):
+                if node.childNodes:
+                    p_node_text += node.childNodes[0].nodeValue
+
+            # 图片,提前识别,不做成Image对象放入Page了
+            elif "a:blip" in str(node).split(' '):
+                _id = node.getAttribute("r:embed")
+                image_path = image_rel_dict.get(_id)
+                if image_path:
+                    image_path = unique_type_dir + 'word/' + image_path
+                    image_convert = ImageConvert(image_path, '')
+                    image_html = image_convert.get_html()[0]
+                    if isinstance(image_html, int):
+                        image_html = ''
+                    p_node_text += image_html
+                    has_html = True
+
+    # 只有编号
+    if len(p_node_text) > 0 and p_node_text == text_no:
+        p_node_text = p_node_text[:-1]
+
+    _text_list.append(p_node_text)
+    if has_html:
+        _order_list.append('w:t html')
+    else:
+        _order_list.append('w:t')
+    return _text_list, _order_list, _num_pr_dict, _last_node_level
+
+
 @timeout(50, timeout_exception=TimeoutError)
-def read_xml_order(path, save_path):
+def read_xml_order(unique_type_dir, document_xml, numbering_xml, document_xml_rels):
     log("into read_xml_order")
     try:
-        try:
-            f = zipfile.ZipFile(path)
-            for file in f.namelist():
-                if "word/document.xml" == str(file):
-                    f.extract(file, save_path)
-            f.close()
-        except Exception as e:
-            log("docx format error!")
-            return [-3]
-
-        try:
-            collection = xml_analyze(save_path + "word/document.xml")
-        except TimeoutError:
-            log("xml_analyze timeout")
-            return [-4]
-
-        body = collection.getElementsByTagName("w:body")[0]
+        body = document_xml.getElementsByTagName("w:body")[0]
         order_list = []
         text_list = []
         # 编号组记录
         num_pr_dict = {}
         last_node_level = 0
         for line in body.childNodes:
-            # print(str(line))
+            # 普通文本
             if "w:p" in str(line):
-                # 文本的编号(如果有编号的话)
-                text_no = ''
-                # 提取编号 组-层级-序号
-                num_pr = line.getElementsByTagName("w:numPr")
-                if num_pr:
-                    num_pr = num_pr[0]
-                    group_id = int(num_pr.getElementsByTagName("w:numId")[0].getAttribute("w:val"))
-                    if group_id >= 1:
-                        node_level = num_pr.getElementsByTagName("w:ilvl")
-                        if node_level:
-                            node_level = int(node_level[0].getAttribute("w:val"))
-                            # print('node_level', node_level, 'last_node_level', last_node_level)
-                            if group_id in num_pr_dict.keys():
-                                if last_node_level != 0 and node_level < last_node_level:
-                                    # print('重置', 'group_id', group_id, 'last_node_level', last_node_level)
-                                    # 需循环重置node_level到last_node_level之间的level
-                                    for l in range(node_level+1, last_node_level+1):
-                                        num_pr_dict[group_id][l] = 0
-                                    num_pr_dict[group_id][node_level] += 1
-                                elif node_level in num_pr_dict[group_id].keys():
-                                    num_pr_dict[group_id][node_level] += 1
-                                else:
-                                    num_pr_dict[group_id][node_level] = 1
-                            else:
-                                num_pr_dict[group_id] = {node_level: 1}
-                            # print(num_pr_dict[group_id])
-                            for level in range(node_level+1):
-                                # 当前level下有多少个node
-                                if level not in num_pr_dict[group_id]:
-                                    continue
-                                level_node_cnt = num_pr_dict[group_id][level]
-                                # print('level_node_cnt', level_node_cnt)
-                                text_no += str(level_node_cnt) + '.'
-                            last_node_level = node_level
-                            # print('read_xml_order text_no', text_no)
-
-                text = line.getElementsByTagName("w:t")
-                picture = line.getElementsByTagName("wp:docPr")
-                if text:
-                    order_list.append("w:t")
-                    temp_text = ""
-                    for t in text:
-                        if len(t.childNodes) > 0:
-                            temp_text += t.childNodes[0].nodeValue
-                        else:
-                            continue
-                    if text_no:
-                        temp_text = text_no + ' ' + temp_text
-                    text_list.append(temp_text)
-                if picture:
-                    order_list.append("wp:docPr")
-
-                for line1 in line.childNodes:
-                    if "w:r" in str(line1):
-                        # print("read_xml_order", "w:r")
-                        picture1 = line1.getElementsByTagName("w:pict")
-                        if picture1:
-                            order_list.append("wp:docPr")
-
-            if "w:tbl" in str(line):
+                t_list, o_list, num_pr_dict, last_node_level = read_p_text(unique_type_dir,
+                                                                           line,
+                                                                           last_node_level,
+                                                                           num_pr_dict,
+                                                                           numbering_xml,
+                                                                           document_xml_rels)
+                text_list += t_list
+                order_list += o_list
+
+            # 目录索引
+            elif "w:sdt" in str(line):
+                sdt = line
+                for sdt_child in sdt.childNodes:
+                    if "w:sdtContent" in str(sdt_child):
+                        sdt_content = sdt_child
+                        for sdt_content_child in sdt_content.childNodes:
+                            if 'w:p' in str(sdt_content_child):
+                                t_list, o_list, num_pr_dict, last_node_level = read_p_text(unique_type_dir,
+                                                                                           sdt_content_child,
+                                                                                           last_node_level,
+                                                                                           num_pr_dict,
+                                                                                           numbering_xml,
+                                                                                           document_xml_rels,
+                                                                                           is_sdt=True)
+                                text_list += t_list
+                                order_list += o_list
+
+            elif "w:tbl" in str(line):
                 order_list.append("w:tbl")
         # read_xml_table(path, save_path)
         return [order_list, text_list]
     except Exception as e:
         log("read_xml_order error!")
-        print("read_xml_order", traceback.print_exc())
-        # log_traceback("read_xml_order")
+        traceback.print_exc()
         return [-1]
 
 
 @timeout(50, timeout_exception=TimeoutError)
-def read_xml_table(path, save_path):
+def read_xml_table(unique_type_dir, document_xml, numbering_xml, document_xml_rels):
     def recursion_read_table(table):
         table_text = '<table border="1">'
         tr_index = 0
         tr_text_list = []
+        last_node_level = 0
+        num_pr_dict = {}
+
         # 直接子节点用child表示,所有子节点用all表示
         for table_child in table.childNodes:
             if 'w:tr' in str(table_child):
@@ -164,10 +322,18 @@ def read_xml_table(path, save_path):
                                 tc_text += recursion_read_table(tc_child)
                             if 'w:p' in str(tc_child).split(' '):
                                 tc_p_all_nodes = tc_child.getElementsByTagName("*")
-                                for tc_p_all in tc_p_all_nodes:
-                                    if 'w:t' in str(tc_p_all).split(' '):
-                                        # w:t必须加childNodes[0]才能读文本
-                                        tc_text += tc_p_all.childNodes[0].nodeValue
+                                _t_list, _, num_pr_dict, last_node_level = read_p_text(unique_type_dir,
+                                                                                       tc_child,
+                                                                                       last_node_level,
+                                                                                       num_pr_dict,
+                                                                                       numbering_xml,
+                                                                                       document_xml_rels)
+                                # print('_t_list', _t_list)
+                                tc_text += ''.join(_t_list)
+                                # for tc_p_all in tc_p_all_nodes:
+                                #     if 'w:t' in str(tc_p_all).split(' '):
+                                #         # w:t必须加childNodes[0]才能读文本
+                                #         tc_text += tc_p_all.childNodes[0].nodeValue
                         # 结束该tc
                         table_text = table_text + tc_text + "</td>"
                         tc_index += 1
@@ -182,26 +348,7 @@ def read_xml_table(path, save_path):
 
     log("into read_xml_table")
     try:
-        try:
-            f = zipfile.ZipFile(path)
-            for file in f.namelist():
-                if "word/document.xml" == str(file):
-                    f.extract(file, save_path)
-            f.close()
-        except Exception as e:
-            # print("docx format error!", e)
-            log("docx format error!")
-            return [-3]
-
-        log("xml_analyze%s"%(save_path))
-        try:
-            collection = xml_analyze(save_path + "word/document.xml")
-        except TimeoutError:
-            log("xml_analyze timeout")
-            return [-4]
-
-        log("xml_analyze done")
-        body = collection.getElementsByTagName("w:body")[0]
+        body = document_xml.getElementsByTagName("w:body")[0]
         table_text_list = []
         body_nodes = body.childNodes
         for node in body_nodes:
@@ -218,27 +365,19 @@ def read_xml_table(path, save_path):
 
 
 @timeout(25, timeout_exception=TimeoutError)
-def xml_analyze(path):
+def parse_xml(path):
     # 解析xml
     DOMTree = xml.dom.minidom.parse(path)
     collection = DOMTree.documentElement
     return collection
 
 
-def read_docx_table(document):
-    table_text_list = []
-    for table in document.tables:
-        table_text = "<table>"
-        # print("==================")
-        for row in table.rows:
-            table_text += "<tr>"
-            for cell in row.cells:
-                table_text += "<td>" + re.sub("\s","",str(cell.text)) + "</td>"
-            table_text += "</tr>"
-        table_text += "</table>"
-        # print(table_text)
-        table_text_list.append(table_text)
-    return table_text_list
+@timeout(25, timeout_exception=TimeoutError)
+def parse_xml2(path):
+    # 解析xml
+    tree = xml.etree.ElementTree.parse(path)
+    root = tree.getroot()
+    return root
 
 
 class DocxConvert:
@@ -247,6 +386,39 @@ class DocxConvert:
         self.path = path
         self.unique_type_dir = unique_type_dir
 
+        # 解压docx
+        try:
+            f = zipfile.ZipFile(path)
+            for file in f.namelist():
+                if "word/" in str(file):
+                    f.extract(file, self.unique_type_dir)
+
+            f.close()
+        except Exception as e:
+            log("docx format error!")
+            self._doc.error_code = [-3]
+
+        # 读取内容
+        try:
+            self.document_xml = parse_xml(self.unique_type_dir + "word/document.xml")
+
+            if os.path.exists(self.unique_type_dir + "word/numbering.xml"):
+                self.numbering_xml = parse_xml(self.unique_type_dir + "word/numbering.xml")
+            else:
+                self.numbering_xml = []
+
+            if os.path.exists(self.unique_type_dir + "word/_rels/document.xml.rels"):
+                self.document_xml_rels = parse_xml2(self.unique_type_dir + "word/_rels/document.xml.rels")
+            else:
+                self.document_xml_rels = []
+        except FileNotFoundError:
+            # 找不到解压文件,就用html格式读
+            log('FileNotFoundError')
+            self._doc.error_code = None
+        except TimeoutError:
+            log("parse_xml timeout")
+            self._doc.error_code = [-4]
+
     @memory_decorator
     def init_package(self):
         # 各个包初始化
@@ -259,6 +431,26 @@ class DocxConvert:
             self._doc.error_code = [-3]
 
     def convert(self):
+        self._page = _Page(None, 0)
+
+        # 先判断特殊doc文件,可能是html文本
+        is_html_doc = False
+        try:
+            with open(self.path, 'r') as f:
+                html_str = f.read()
+            if re.search('<div|<html|<body|<head|<tr|<br|<table|<td', html_str):
+                soup = BeautifulSoup(html_str, 'lxml')
+                text = soup.text
+                is_html_doc = True
+        except:
+            pass
+
+        if is_html_doc:
+            _sen = _Sentence(text, (0, 0, 0, 0))
+            self._page.add_child(_sen)
+            self._doc.add_child(self._page)
+            return
+
         self.init_package()
         if self._doc.error_code is not None:
             return
@@ -269,8 +461,6 @@ class DocxConvert:
             return
         order_list, text_list = order_and_text_list
 
-        self._page = _Page(None, 0)
-
         # 乱码返回文件格式错误
         match1 = re.findall(get_garble_code(), ''.join(text_list))
         if len(match1) > 10:
@@ -298,12 +488,21 @@ class DocxConvert:
         doc_pr_cnt = 0
         for tag in order_list:
             bbox = (0, order_y, 0, 0)
+            if tag == "w:t html":
+                if len(text_list) > 0:
+                    _para = text_list.pop(0)
+                    _sen = _Sentence(_para, bbox)
+                    _sen.combine = False
+                    _sen.is_html = True
+                    self._page.add_child(_sen)
+
             if tag == "w:t":
                 if len(text_list) > 0:
                     _para = text_list.pop(0)
                     _sen = _Sentence(_para, bbox)
-                    _sen.combine=False
+                    _sen.combine = False
                     self._page.add_child(_sen)
+
             if tag == "wp:docPr":
                 if len(image_list) > 0:
                     temp_image_path = self.unique_type_dir + "docpr" + str(doc_pr_cnt) + ".png"
@@ -327,18 +526,10 @@ class DocxConvert:
             self._doc.error_code = self._page.error_code
         self._doc.add_child(self._page)
 
-    def get_paragraphs(self):
-        # 遍历段落
-        paragraph_list = []
-        for paragraph in self.docx.paragraphs:
-            if paragraph.text != "":
-                paragraph_list.append(paragraph.text)
-        return paragraph_list
-
     @memory_decorator
     def get_tables(self):
         # 遍历表
-        table_list = read_xml_table(self.path, self.unique_type_dir)
+        table_list = read_xml_table(self.unique_type_dir, self.document_xml, self.numbering_xml, self.document_xml_rels)
         return table_list
 
     def get_images(self):
@@ -367,13 +558,15 @@ class DocxConvert:
     @memory_decorator
     def get_orders(self):
         # 解析document.xml,获取文字顺序
-        order_and_text_list = read_xml_order(self.path, self.unique_type_dir)
+        order_and_text_list = read_xml_order(self.unique_type_dir, self.document_xml, self.numbering_xml, self.document_xml_rels)
         return order_and_text_list
 
     def get_doc_object(self):
         return self._doc
 
     def get_html(self):
+        if self._doc.error_code is not None:
+            return self._doc.error_code
         try:
             self.convert()
         except:

+ 12 - 7
format_convert/convert_image.py

@@ -169,7 +169,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
 
         # 调用ocr模型接口
         image_bytes = np2bytes(_image_np)
-        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=True)
+        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=1)
         if judge_error_code(text_list):
             return text_list, text_list
 
@@ -394,7 +394,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
 
             # n条以上分割线,有问题
             if len(split_index_list) == 0 or len(split_index_list) >= 2:
-                print('len(split_index_list)', len(split_index_list), split_index_list)
+                # print('len(split_index_list)', len(split_index_list), split_index_list)
                 continue
             else:
                 # 根据index拆开图片,重新ocr
@@ -411,13 +411,16 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
 
                     # ocr
                     split_image_bytes = np2bytes(split_image_np)
-                    text_list2, bbox_list2 = from_ocr_interface(split_image_bytes, is_table=True, only_rec=True)
-                    print('text_list2', text_list2)
-                    print('bbox_list2', split_bbox_list)
+                    text_list2, bbox_list2 = from_ocr_interface(split_image_bytes, is_table=1, only_rec=1)
+                    # print('text_list2', text_list2)
+                    # print('bbox_list2', split_bbox_list)
                     if judge_error_code(text_list2):
                         text2 = ''
                     else:
-                        text2 = text_list2[0]
+                        if text_list2:
+                            text2 = text_list2[0]
+                        else:
+                            text2 = ''
                     split_text_list.append(text2)
                 splited_textbox_list.append(textbox)
 
@@ -433,8 +436,10 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
     log("into image_preprocess")
     try:
         if image_np is None:
+            log("image_preprocess image_np is None")
             return []
         if image_np.shape[0] <= 20 or image_np.shape[1] <= 20:
+            log('image_np.shape[0] <= 20 or image_np.shape[1] <= 20')
             return []
 
         if not b_table_from_text:
@@ -1083,7 +1088,7 @@ def image_process_old(image_np, image_path, is_from_pdf=False, is_from_docx=Fals
         # 调用ocr模型接口
         with open(image_resize_path, "rb") as f:
             image_bytes = f.read()
-        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=True)
+        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=1)
         if judge_error_code(text_list):
             return text_list
 

+ 41 - 500
format_convert/convert_need_interface.py

@@ -54,11 +54,6 @@ else:
 if MAX_COMPUTE:
     FROM_REMOTE = False
 
-# ip_port_dict = get_ip_port()
-# ip = 'http://127.0.0.1'
-# ocr_port_list = ip_port_dict.get(ip).get("ocr")
-# otr_port_list = ip_port_dict.get(ip).get("otr")
-
 lock = multiprocessing.RLock()
 
 # 连接redis数据库
@@ -67,38 +62,6 @@ lock = multiprocessing.RLock()
 redis_db = None
 
 
-def _interface(_dict, time_out=60, retry_times=3):
-    try:
-        # 重试
-        model_type = _dict.get("model_type")
-        while retry_times:
-            ip_port = interface_pool(model_type)
-            if judge_error_code(ip_port):
-                return ip_port
-            _url = ip_port + "/" + model_type
-            # base64_stream = base64.b64encode(pickle.dumps(_dict))
-            r = json.loads(request_post(_url, {"data": json.dumps(_dict),
-                                               "model_type": model_type}, time_out=time_out))
-            log("get _interface return")
-            if type(r) == list:
-                # 接口连不上换个端口重试
-                if retry_times <= 1:
-                    return r
-                else:
-                    retry_times -= 1
-                    log("retry post _interface... left times " + str(retry_times) + " " + model_type)
-                    continue
-            if judge_error_code(r):
-                return r
-            return r
-            break
-
-    except TimeoutError:
-        return [-5]
-    except requests.exceptions.ConnectionError as e:
-        return [-2]
-
-
 def from_office_interface(src_path, dest_path, target_format, retry_times=1, from_remote=FROM_REMOTE):
     try:
         # Win10跳出超时装饰器
@@ -133,6 +96,7 @@ def from_office_interface(src_path, dest_path, target_format, retry_times=1, fro
                     file_bytes = f.read()
                 base64_stream = base64.b64encode(file_bytes)
                 start_time = time.time()
+                log('office _url ' + str(_url))
                 r = json.loads(request_post(_url, {"src_path": src_path,
                                                    "dest_path": dest_path,
                                                    "file": base64_stream,
@@ -178,7 +142,7 @@ def from_office_interface(src_path, dest_path, target_format, retry_times=1, fro
         return [-1]
 
 
-def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote=FROM_REMOTE):
+def from_ocr_interface(image_stream, is_table=0, only_rec=0, from_remote=FROM_REMOTE):
     log("into from_ocr_interface")
     try:
         base64_stream = base64.b64encode(image_stream)
@@ -189,13 +153,6 @@ def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote
                 retry_times_1 = 3
                 # 重试
                 while retry_times_1:
-                    # _ip = ip_pool("ocr", _random=True)
-                    # _port = port_pool("ocr", _random=True)
-                    # if _ip == interface_ip_list[1]:
-                    #     _port = ocr_port_list[0]
-                    # _ip, _port = interface_pool("ocr")
-                    # ip_port = _ip + ":" + _port
-                    # ip_port = from_schedule_interface("ocr")
                     ip_port = interface_pool_gunicorn("ocr")
                     if judge_error_code(ip_port):
                         return ip_port
@@ -205,14 +162,14 @@ def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote
                                                        "only_rec": only_rec
                                                        },
                                                 time_out=60))
-                    log("get interface return")
+                    log("get ocr interface return")
                     if type(r) == list:
                         # 接口连不上换个端口重试
                         if retry_times_1 <= 1:
-                            # if is_table:
-                            return r, r
-                            # else:
-                            #     return r
+                            if is_table:
+                                return r, r
+                            else:
+                                return r
                         else:
                             retry_times_1 -= 1
                             log("retry post ocr_interface... left times " + str(retry_times_1))
@@ -226,15 +183,15 @@ def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote
                     globals().update({"global_ocr_model": OcrModels().get_model()})
                 r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"), only_rec=only_rec)
         except TimeoutError:
-            # if is_table:
-            return [-5], [-5]
-            # else:
-            #     return [-5]
+            if is_table:
+                return [-5], [-5]
+            else:
+                return [-5]
         except requests.exceptions.ConnectionError as e:
-            # if is_table:
-            return [-2], [-2]
-            # else:
-            #     return [-2]
+            if is_table:
+                return [-2], [-2]
+            else:
+                return [-2]
 
         _dict = r
         text_list = eval(_dict.get("text"))
@@ -256,6 +213,8 @@ def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote
             return text
     except Exception as e:
         log("from_ocr_interface error!")
+        log(str(traceback.print_exc()))
+        traceback.print_exc()
         # print("from_ocr_interface", e, global_type)
         if is_table:
             return [-1], [-1]
@@ -263,26 +222,6 @@ def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote
             return [-1]
 
 
-def from_gpu_interface_flask(_dict, model_type, predictor_type):
-    log("into from_gpu_interface")
-    start_time = time.time()
-    try:
-        # 调用接口
-        _dict.update({"predictor_type": predictor_type, "model_type": model_type})
-        if model_type == "ocr":
-            use_zlib = True
-        else:
-            use_zlib = False
-        result = _interface(_dict, time_out=30, retry_times=2, use_zlib=use_zlib)
-        log("from_gpu_interface finish size " + str(sys.getsizeof(_dict)) + " time " + str(time.time()-start_time))
-        return result
-    except Exception as e:
-        log("from_gpu_interface error!")
-        log("from_gpu_interface failed " + str(time.time()-start_time))
-        traceback.print_exc()
-        return [-2]
-
-
 def from_gpu_interface_redis(_dict, model_type, predictor_type):
     log("into from_gpu_interface")
     start_time = time.time()
@@ -319,114 +258,6 @@ def from_gpu_interface_redis(_dict, model_type, predictor_type):
         return [-2]
 
 
-# def from_gpu_flask_sm(_dict, model_type, predictor_type):
-#     log("into from_gpu_share_memory")
-#     start_time = time.time()
-#     shm = None
-#     try:
-#         # 放入共享内存
-#         _time = time.time()
-#         np_data = _dict.get("inputs")
-#         shm = to_share_memory(np_data)
-#         log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
-#
-#         # 调用接口
-#         _time = time.time()
-#         _dict.pop("inputs")
-#         _dict.update({"predictor_type": predictor_type, "model_type": model_type,
-#                       "sm_name": shm.name, "sm_shape": np_data.shape,
-#                       "sm_dtype": str(np_data.dtype)})
-#         result = _interface(_dict, time_out=30, retry_times=2)
-#         log("_interface cost " + str(time.time()-_time))
-#
-#         # 读取共享内存
-#         _time = time.time()
-#         sm_name = result.get("sm_name")
-#         sm_shape = result.get("sm_shape")
-#         sm_dtype = result.get("sm_dtype")
-#         sm_dtype = get_np_type(sm_dtype)
-#         if sm_name:
-#             outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
-#         else:
-#             log("from_share_memory failed!")
-#             raise Exception
-#         log("data from share memory " + sm_name + " " + str(time.time()-_time))
-#
-#         log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
-#         return {"preds": outputs, "gpu_time": result.get("gpu_time")}
-#     except Exception as e:
-#         log("from_gpu_interface failed " + str(time.time()-start_time))
-#         traceback.print_exc()
-#         return [-2]
-#     finally:
-#         # del b  # Unnecessary; merely emphasizing the array is no longer used
-#         if shm:
-#             try:
-#                 shm.close()
-#                 shm.unlink()
-#             except FileNotFoundError:
-#                 log("share memory " + shm.name + " not exists!")
-#             except Exception:
-#                 traceback.print_exc()
-#
-#
-# def from_gpu_share_memory(_dict, model_type, predictor_type):
-#     log("into from_gpu_share_memory")
-#     start_time = time.time()
-#     try:
-#         _dict.update({"model_type": model_type, "predictor_type": predictor_type})
-#         outputs, gpu_time = share_memory_pool(_dict)
-#         log("from_gpu_share_memory finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
-#         return {"preds": outputs, "gpu_time": float(gpu_time)}
-#     except Exception as e:
-#         log("from_gpu_interface failed " + str(time.time()-start_time))
-#         traceback.print_exc()
-#         return [-2]
-
-
-def from_otr_interface2(image_stream):
-    log("into from_otr_interface")
-    try:
-        base64_stream = base64.b64encode(image_stream)
-
-        # 调用接口
-        try:
-            if globals().get("global_otr_model") is None:
-                globals().update({"global_otr_model": OtrModels().get_model()})
-                print("=========== init otr model ===========")
-            r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"))
-        except TimeoutError:
-            return [-5], [-5], [-5], [-5], [-5]
-        except requests.exceptions.ConnectionError as e:
-            log("from_otr_interface")
-            print("from_otr_interface", traceback.print_exc())
-            return [-2], [-2], [-2], [-2], [-2]
-
-        # 处理结果
-        _dict = r
-        points = eval(_dict.get("points"))
-        split_lines = eval(_dict.get("split_lines"))
-        bboxes = eval(_dict.get("bboxes"))
-        outline_points = eval(_dict.get("outline_points"))
-        lines = eval(_dict.get("lines"))
-        # print("from_otr_interface len(bboxes)", len(bboxes))
-        if points is None:
-            points = []
-        if split_lines is None:
-            split_lines = []
-        if bboxes is None:
-            bboxes = []
-        if outline_points is None:
-            outline_points = []
-        if lines is None:
-            lines = []
-        return points, split_lines, bboxes, outline_points, lines
-    except Exception as e:
-        log("from_otr_interface error!")
-        print("from_otr_interface", traceback.print_exc())
-        return [-1], [-1], [-1], [-1], [-1]
-
-
 def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE):
     log("into from_otr_interface")
     try:
@@ -723,81 +554,13 @@ def from_yolo_interface(image_stream, from_remote=FROM_REMOTE):
         return [-11]
 
 
-# def from_schedule_interface(interface_type):
-#     try:
-#         _ip = "http://" + get_intranet_ip()
-#         _port = ip_port_dict.get(_ip).get("schedule")[0]
-#         _url = _ip + ":" + _port + "/schedule"
-#         data = {"interface_type": interface_type}
-#         result = json.loads(request_post(_url, data, time_out=10)).get("data")
-#         if judge_error_code(result):
-#             return result
-#         _ip, _port = result
-#         log("from_schedule_interface " + _ip + " " + _port)
-#         return _ip + ":" + _port
-#     except requests.exceptions.ConnectionError as e:
-#         log("from_schedule_interface ConnectionError")
-#         return [-2]
-#     except:
-#         log("from_schedule_interface error!")
-#         traceback.print_exc()
-#         return [-1]
-
-
-def interface_pool(interface_type, use_gunicorn=True):
-    ip_port_flag = _global.get("ip_port_flag")
-    ip_port_dict = _global.get("ip_port")
-    try:
-        if use_gunicorn:
-            _ip = "http://127.0.0.1"
-            _port = ip_port_dict.get(_ip).get(interface_type)[0]
-            ip_port = _ip + ":" + str(_port)
-            log(ip_port)
-            return ip_port
-
-        # 负载均衡, 选取ip
-        interface_load_list = []
-        for _ip in ip_port_flag.keys():
-            if ip_port_dict.get(_ip).get(interface_type):
-                load_scale = ip_port_flag.get(_ip).get(interface_type) / len(ip_port_dict.get(_ip).get(interface_type))
-                interface_load_list.append([_ip, load_scale])
-
-        if not interface_load_list:
-            raise NotFound
-        interface_load_list.sort(key=lambda x: x[-1])
-        _ip = interface_load_list[0][0]
-
-        # 负载均衡, 选取port
-        ip_type_cnt = ip_port_flag.get(_ip).get(interface_type)
-        ip_type_total = len(ip_port_dict.get(_ip).get(interface_type))
-        if ip_type_cnt == 0:
-            ip_type_cnt = random.randint(0, ip_type_total-1)
-        port_index = ip_type_cnt % ip_type_total
-        _port = ip_port_dict.get(_ip).get(interface_type)[port_index]
-
-        # 更新flag
-        current_flag = ip_type_cnt
-        if current_flag >= 10000:
-            ip_port_flag[_ip][interface_type] = 0
-        else:
-            ip_port_flag[_ip][interface_type] = current_flag + 1
-        _global.update({"ip_port_flag": ip_port_flag})
-        # log(str(_global.get("ip_port_flag")))
-
-        ip_port = _ip + ":" + str(_port)
-        log(ip_port)
-        return ip_port
-    except NotFound:
-        log("cannot read ip from config! checkout config")
-        return [-2]
-    except:
-        traceback.print_exc()
-        return [-1]
-
-
 def interface_pool_gunicorn(interface_type):
+    # if get_platform() == 'Windows':
+    #     set_flask_global()
+
     ip_port_flag_dict = _global.get("ip_port_flag")
     ip_port_dict = _global.get("ip_port")
+
     try:
         if ip_port_dict is None or ip_port_flag_dict is None:
             print('_global', _global.get_dict())
@@ -810,19 +573,34 @@ def interface_pool_gunicorn(interface_type):
         port_list = []
         for key in ip_port_flag_dict.keys():
             temp_port_list = get_args_from_config(ip_port_dict, key, interface_type)
+            # print('temp_port_list', temp_port_list)
             if not temp_port_list:
                 continue
+
+            # 该ip下的该接口总数量(可能有多gpu接口)
+            _port_list, _port_num_list, _ = temp_port_list[0]
+            # print('_port_num_list', _port_num_list)
+            total_port_num = sum(_port_num_list)
+            if total_port_num == 0:
+                continue
+
             interface_cnt = ip_port_flag_dict.get(key).get(interface_type)
-            if interface_cnt is not None and interface_cnt / len(temp_port_list[0]) < min_cnt:
+            if interface_cnt is not None and interface_cnt / total_port_num < min_cnt:
                 _ip = key
                 min_cnt = interface_cnt / len(temp_port_list[0])
-                port_list = temp_port_list[0]
+
+                # 选定ip,设置gpu的接口候选比例
+                gpu_port_list = []
+                for k in range(len(_port_list)):
+                    gpu_port_list += [_port_list[k]] * _port_num_list[k]
+                port_list = gpu_port_list
+                # port_list = temp_port_list[0]
 
         # 选取端口
         if interface_type == "office":
             if len(port_list) == 0:
                 raise ConnectionError
-
+            port_list = [str(port_list[k] + k) for k in range(len(port_list))]
             # 刚开始随机,后续求余
             if min_cnt == 0:
                 _port = port_list[random.randint(0, len(port_list)-1)]
@@ -830,8 +608,8 @@ def interface_pool_gunicorn(interface_type):
             else:
                 _port = port_list[interface_cnt % len(port_list)]
         else:
-            # 使用gunicorn则直接选第一个
-            _port = port_list[0]
+            # 使用gunicorn则随机选
+            _port = random.choice(port_list)
 
         # 更新flag
         if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
@@ -855,243 +633,6 @@ def interface_pool_gunicorn(interface_type):
         return [-1]
 
 
-def interface_pool_gunicorn_old(interface_type):
-    ip_flag_list = _global.get("ip_flag")
-    ip_port_flag_dict = _global.get("ip_port_flag")
-    ip_port_dict = _global.get("ip_port")
-    try:
-        if ip_flag_list is None or ip_port_dict is None or ip_port_flag_dict is None:
-            raise NotFound
-
-        if interface_type == "office":
-            # _ip = "http://127.0.0.1"
-            _ip = get_using_ip()
-            # 选取端口
-            port_list = ip_port_dict.get(_ip).get("MASTER").get(interface_type)
-            ip_type_cnt = ip_port_flag_dict.get(_ip).get(interface_type)
-            if ip_type_cnt == 0:
-                _port = port_list[random.randint(0, len(port_list)-1)]
-            else:
-                _port = port_list[ip_type_cnt % len(port_list)]
-            # 更新flag
-            if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
-                ip_port_flag_dict[_ip][interface_type] = 0
-            else:
-                ip_port_flag_dict[_ip][interface_type] += 1
-            _global.update({"ip_port_flag": ip_port_flag_dict})
-
-        else:
-            # 负载均衡, 选取ip
-            ip_flag_list.sort(key=lambda x: x[1])
-            if ip_flag_list[-1][1] == 0:
-                ip_index = random.randint(0, len(ip_flag_list)-1)
-            else:
-                ip_index = 0
-            _ip = ip_flag_list[ip_index][0]
-            if "master" in _ip:
-                port_index = 1
-            else:
-                port_index = 0
-            _ip = _ip.split("_")[0]
-            # 选取端口, 使用gunicorn则直接选第一个
-            # _port = ip_port_dict.get(_ip).get("MASTER").get(interface_type)[0]
-            log("_ip " + _ip)
-            log("interface_type " + interface_type)
-            port_list = get_args_from_config(ip_port_dict, _ip, interface_type)
-            log("port_list" + str(port_list))
-            if port_index >= len(port_list):
-                port_index = 0
-            _port = port_list[port_index][0]
-
-            # # 选取端口, 使用gunicorn则直接选第一个
-            # _ip = _ip.split("_")[0]
-            # port_list = get_args_from_config(ip_port_dict, _ip, interface_type)
-            # if
-            # print(port_list)
-            # _port = port_list[0][0]
-
-            # 更新flag
-            if ip_flag_list[ip_index][1] >= 10000:
-                ip_flag_list[ip_index][1] = 0
-            else:
-                ip_flag_list[ip_index][1] += + 1
-            _global.update({"ip_flag": ip_flag_list})
-
-        ip_port = _ip + ":" + str(_port)
-        log(ip_port)
-        return ip_port
-    except NotFound:
-        log("ip_flag or ip_port_dict is None! checkout config")
-        return [-2]
-    except:
-        traceback.print_exc()
-        return [-1]
-
-
-# def share_memory_pool(args_dict):
-#     np_data = args_dict.get("inputs")
-#     _type = args_dict.get("model_type")
-#     args_dict.update({"sm_shape": np_data.shape, "sm_dtype": str(np_data.dtype)})
-#
-#     if _type == 'ocr':
-#         port_list = ocr_port_list
-#     elif _type == 'otr':
-#         port_list = otr_port_list
-#     else:
-#         log("type error! only support ocr otr")
-#         raise Exception
-#
-#     # 循环判断是否有空的share memory
-#     empty_sm_list = None
-#     sm_list_name = ""
-#     while empty_sm_list is None:
-#         for p in port_list:
-#             sm_list_name = "sml_"+_type+"_"+str(p)
-#             sm_list = get_share_memory_list(sm_list_name)
-#             if sm_list[0] == "0":
-#                 lock.acquire(timeout=0.1)
-#                 if sm_list[0] == "0":
-#                     sm_list[0] = "1"
-#                     sm_list[-1] = "0"
-#                     empty_sm_list = sm_list
-#                     break
-#                 else:
-#                     continue
-#                 lock.release()
-#
-#     log(str(os.getppid()) + " empty_sm_list " + sm_list_name)
-#
-#     # numpy放入共享内存
-#     _time = time.time()
-#     release_share_memory(get_share_memory("psm_" + str(os.getpid())))
-#     shm = to_share_memory(np_data)
-#     log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
-#
-#     # 参数放入共享内存列表
-#     empty_sm_list[1] = args_dict.get("md5")
-#     empty_sm_list[2] = args_dict.get("model_type")
-#     empty_sm_list[3] = args_dict.get("predictor_type")
-#     empty_sm_list[4] = args_dict.get("args")
-#     empty_sm_list[5] = str(shm.name)
-#     empty_sm_list[6] = str(args_dict.get("sm_shape"))
-#     empty_sm_list[7] = args_dict.get("sm_dtype")
-#     empty_sm_list[-1] = "1"
-#     # log("empty_sm_list[7] " + empty_sm_list[7])
-#     close_share_memory_list(empty_sm_list)
-#
-#     # 循环判断是否完成
-#     finish_sm_list = get_share_memory_list(sm_list_name)
-#     while True:
-#         if finish_sm_list[-1] == "0":
-#             break
-#
-#     # 读取共享内存
-#     _time = time.time()
-#     sm_name = finish_sm_list[5]
-#     sm_shape = finish_sm_list[6]
-#     sm_shape = eval(sm_shape)
-#     sm_dtype = finish_sm_list[7]
-#     gpu_time = finish_sm_list[8]
-#     sm_dtype = get_np_type(sm_dtype)
-#     outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
-#     log(args_dict.get("model_type") + " " + args_dict.get("predictor_type") + " outputs " + str(outputs.shape))
-#     log("data from share memory " + sm_name + " " + str(time.time()-_time))
-#
-#     # 释放
-#     release_share_memory(get_share_memory(sm_name))
-#
-#     # 重置share memory list
-#     finish_sm_list[-1] = "0"
-#     finish_sm_list[0] = "0"
-#
-#     close_share_memory_list(finish_sm_list)
-#     return outputs, gpu_time
-
-
-# def interface_pool(interface_type):
-#     try:
-#         ip_port_dict = _global.get("ip_port")
-#         ip_list = list(ip_port_dict.keys())
-#         _ip = random.choice(ip_list)
-#         if interface_type != 'office':
-#             _port = ip_port_dict.get(_ip).get(interface_type)[0]
-#         else:
-#             _port = random.choice(ip_port_dict.get(_ip).get(interface_type))
-#         log(_ip + ":" + _port)
-#         return _ip + ":" + _port
-#     except Exception as e:
-#         traceback.print_exc()
-#         return [-1]
-
-
-# def ip_pool(interface_type, _random=False):
-#     ip_flag_name = interface_type + '_ip_flag'
-#     ip_flag = globals().get(ip_flag_name)
-#     if ip_flag is None:
-#         if _random:
-#             _r = random.randint(0, len(interface_ip_list)-1)
-#             ip_flag = _r
-#             globals().update({ip_flag_name: ip_flag})
-#             ip_index = _r
-#         else:
-#             ip_flag = 0
-#             globals().update({ip_flag_name: ip_flag})
-#             ip_index = 0
-#     else:
-#         ip_index = ip_flag % len(interface_ip_list)
-#     ip_flag += 1
-#
-#     if ip_flag >= 10000:
-#         ip_flag = 0
-#     globals().update({ip_flag_name: ip_flag})
-#
-#     log("ip_pool " + interface_type + " " + str(ip_flag) + " " + str(interface_ip_list[ip_index]))
-#     return interface_ip_list[ip_index]
-#
-#
-# def port_pool(interface_type, _random=False):
-#     port_flag_name = interface_type + '_port_flag'
-#
-#     port_flag = globals().get(port_flag_name)
-#     if port_flag is None:
-#         if _random:
-#             if interface_type == "ocr":
-#                 _r = random.randint(0, len(ocr_port_list)-1)
-#             elif interface_type == "otr":
-#                 _r = random.randint(0, len(otr_port_list)-1)
-#             else:
-#                 _r = random.randint(0, len(soffice_port_list)-1)
-#             port_flag = _r
-#             globals().update({port_flag_name: port_flag})
-#             port_index = _r
-#         else:
-#             port_flag = 0
-#             globals().update({port_flag_name: port_flag})
-#             port_index = 0
-#     else:
-#         if interface_type == "ocr":
-#             port_index = port_flag % len(ocr_port_list)
-#         elif interface_type == "otr":
-#             port_index = port_flag % len(otr_port_list)
-#         else:
-#             port_index = port_flag % len(soffice_port_list)
-#     port_flag += 1
-#
-#     if port_flag >= 10000:
-#         port_flag = 0
-#     globals().update({port_flag_name: port_flag})
-#
-#     if interface_type == "ocr":
-#         log("port_pool " + interface_type + " " + str(port_flag) + " " + ocr_port_list[port_index])
-#         return ocr_port_list[port_index]
-#     elif interface_type == "otr":
-#         log("port_pool " + interface_type + " " + str(port_flag) + " " + otr_port_list[port_index])
-#         return otr_port_list[port_index]
-#     else:
-#         log("port_pool " + interface_type + " " + str(port_flag) + " " + soffice_port_list[port_index])
-#         return soffice_port_list[port_index]
-
-
 if __name__ == "__main__":
     _global._init()
     set_flask_global()

파일 크기가 너무 크기때문에 변경 상태를 표시하지 않습니다.
+ 345 - 437
format_convert/convert_pdf.py


+ 59 - 43
format_convert/convert_test.py

@@ -5,13 +5,10 @@ import random
 import sys
 import time
 from glob import glob
-from multiprocessing import Process
-
-from bs4 import BeautifulSoup
-
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert.utils import get_platform, request_post, get_md5_from_bytes
 from format_convert.convert import to_html
+import multiprocessing as mp
 
 
 def test_one(p, page_no_range=None, from_remote=False):
@@ -22,13 +19,13 @@ def test_one(p, page_no_range=None, from_remote=False):
 
     _md5 = get_md5_from_bytes(file_bytes)
 
-    data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": 100, 'page_no': page_no_range}
+    data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": _md5, 'page_no': page_no_range}
     if from_remote:
-        # _url = 'http://121.46.18.113:15010/convert'
+        _url = 'http://121.46.18.113:15010/convert'
         # _url = 'http://192.168.2.103:15010/convert'
         # _url = 'http://192.168.2.102:15011/convert'
         # _url = 'http://172.16.160.65:15010/convert'
-        _url = 'http://127.0.0.1:15010/convert'
+        # _url = 'http://127.0.0.1:15010/convert'
         result = json.loads(request_post(_url, data, time_out=10000))
         text_str = ""
         for t in result.get("result_html"):
@@ -39,6 +36,7 @@ def test_one(p, page_no_range=None, from_remote=False):
         print("only support remote!")
 
     print(_md5)
+    print('第', page_no_range.split(',')[0], '页到第', page_no_range.split(',')[-1], '页')
     print("result_text", result.get("result_text")[0][:20])
     print("is_success", result.get("is_success"))
     print(time.time()-start_time)
@@ -57,47 +55,65 @@ def test_duplicate(path_list, process_no=None):
             test_one(p, from_remote=True)
 
 
+def test_maxcompute(p, page_no_range=None):
+    from format_convert import convert
+    start_time = time.time()
+    with open(p, "rb") as f:
+        file_bytes = f.read()
+    file_base64 = base64.b64encode(file_bytes)
+    _md5 = get_md5_from_bytes(file_bytes)
+
+    data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": _md5, 'page_no': page_no_range}
+    result = convert.convert(data)
+    text_str = ""
+    for t in result.get("result_html"):
+        text_str += t
+    to_html(os.path.dirname(os.path.abspath(__file__)) + "/../result.html",
+            text_str)
+
+    print(_md5)
+    print('第', page_no_range.split(',')[0], '页到第', page_no_range.split(',')[-1], '页')
+    print("result_text", result.get("result_text")[0][:20])
+    print("is_success", result.get("is_success"))
+    print(time.time()-start_time)
+
+
 if __name__ == '__main__':
     if get_platform() == "Windows":
         # file_path = "C:/Users/Administrator/Desktop/2.png"
-        # file_path = "C:/Users/Administrator/Desktop/test_xls/merge_cell.xlsx"
-        # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/20210609202634853485.xlsx"
+        file_path = "C:/Users/Administrator/Desktop/test_xls/error4.xls"
+        # file_path = "C:/Users/Administrator/Desktop/test_doc/error5.doc"
+        # file_path = "D:/BIDI_DOC/比地_文档/1677829036789.pdf"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
-        # file_path = "C:/Users/Administrator/Downloads/20210508190133924ba.pdf"
-        # file_path = "C:/Users/Administrator/Desktop/test_doc/error8.doc"
-        # file_path = "C:/Users/Administrator/Desktop/test_image/error10.png"
+        # file_path = "C:/Users/Administrator/Downloads/1688432101601.xlsx"
+        # file_path = "C:/Users/Administrator/Desktop/test_doc/error14.docx"
+        # file_path = "C:/Users/Administrator/Desktop/test_image/error36.png"
         # file_path = "C:/Users/Administrator/Desktop/test_b_table/error1.png"
-        file_path = "C:/Users/Administrator/Desktop/test_pdf/error1.pdf"
+        # file_path = "C:/Users/Administrator/Desktop/test_pdf/表格连接error/error7.pdf"
         # file_path = "C:/save_b_table/0-0895e32470613dd7be1139eefd1342c4.png"
     else:
         file_path = "1660296734009.pdf"
-    test_one(file_path, page_no_range='13,14', from_remote=True)
-
-    # paths = glob("C:/Users/Administrator/Desktop/test_image/*")
-    # for file_path in paths:
-    #     test_one(file_path, from_remote=True)
-
-    # if get_platform() == "Windows":
-    #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623328459080.doc",
-    #     #                   "D:/BIDI_DOC/比地_文档/2022/Test_Interface/94961e1987d1090e.xls",
-    #     #                   "D:/BIDI_DOC/比地_文档/2022/Test_Interface/11111111.rar"]
-    #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623328459080.doc",
-    #     #                   "D:/BIDI_DOC/比地_文档/2022/Test_Interface/94961e1987d1090e.xls"]
-    #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623423836610.pdf"]
-    #     file_path_list = ["C:/Users/Administrator/Desktop/error16.jpg"]
-    # else:
-    #     file_path_list = ["1623423836610.pdf"]
-    # start_time = time.time()
-    # p_list = []
-    # for j in range(3):
-    #     p = Process(target=test_duplicate, args=(file_path_list, j, ))
-    #     p.start()
-    #     p_list.append(p)
-    # for p in p_list:
-    #     p.join()
-    # print("finish", time.time() - start_time)
-
-    # with open(file_path, 'r') as f:
-    #     t = f.read()
-    # soup = BeautifulSoup(t, 'lxml')
-    # print(soup.text)
+
+    test_one(file_path, page_no_range='1,-1', from_remote=True)
+
+    file_path = "C:/Users/Administrator/Downloads/"
+    # file_path = r"C:\Users\Administrator\Desktop\test_pdf\直接读表格线error/"
+    # file_path = r"C:\Users\Administrator\Desktop\test_pdf\表格连接error/"
+    test_pdf_list = [['6df7f2bd5e8cac99a15a6c012e0d82a8.pdf', '34,52'],
+                     ['ca6a86753400d6dd6a1b324c5678b7fb.pdf', '18,69'],
+                     ['a8380bf795c71caf8185fb11395df138.pdf', '27,38'],
+                     ['7fd2ce6b08d086c98158b6f2fa0293b0.pdf', '32,48'],
+                     ['dd1adb4dc2014c7abcf403ef15a01eb5.pdf', '2,12'],
+                     ['error50.pdf', '1,-1'],
+                     ['error59.pdf', '1,-1'],
+                     ['error51.pdf', '1,-1'],
+                     ['error7.pdf', '39,57'],
+                     ]
+    index = 1
+    # test_one(file_path+test_pdf_list[index][0], page_no_range=test_pdf_list[index][1], from_remote=True)
+
+
+    # 测试maxcompute模式
+    # _process = mp.Process(target=test_maxcompute, args=(file_path, '1,-1',))
+    # _process.start()
+    # _process.join()

+ 1 - 1
format_convert/convert_zip.py

@@ -7,7 +7,7 @@ sys.path.append(os.path.dirname(__file__) + "/../")
 from format_convert.convert_tree import _Document, _Page, _Sentence
 import logging
 import traceback
-import my_zipfile as zipfile
+import format_convert.my_zipfile as zipfile
 from format_convert import get_memory_info
 from format_convert.utils import get_platform, rename_inner_files, judge_error_code, judge_format, get_logger, log, \
     memory_decorator

+ 117 - 0
format_convert/interface_new.yml

@@ -0,0 +1,117 @@
+{
+  "MASTER": {
+    "ip": "http://0.0.0.0",
+
+    "path": {
+      "python": "/data/anaconda3/envs/convert3/bin/python",
+      "gunicorn": "/data/anaconda3/envs/convert3/bin/gunicorn",
+      "project": "/data/fangjiasheng/format_conversion_maxcompute/"
+    },
+
+    "convert": {
+      "port": [15010],
+      "port_num": [3],
+      "gpu": [-1]
+    },
+
+    "ocr": {
+      "port": [17000, 17001],
+      "port_num": [3, 1],
+      "gpu": [0, 1]
+    },
+
+    "otr": {
+      "port": [ 18000, 18001 ],
+      "port_num": [ 0, 2 ],
+      "gpu": [ 0, 1 ]
+    },
+
+    "idc": {
+      "port": [ 18020 ],
+      "port_num": [ 1 ],
+      "gpu": [ -1 ]
+    },
+
+    "isr": {
+      "port": [ 18040 ],
+      "port_num": [ 3 ],
+      "gpu": [ 0 ]
+    },
+
+    "atc": {
+      "port": [ 18060 ],
+      "port_num": [ 2 ],
+      "gpu": [ -1 ]
+    },
+
+    "yolo": {
+      "port": [ 18080 ],
+      "port_num": [ 2 ],
+      "gpu": [ 0 ]
+    },
+
+    "office": {
+      "port": [ 16000 ],
+      "port_num": [ 25 ],
+      "gpu": []
+    }
+  },
+
+  "SLAVE": {
+    "ip": "",
+
+    "path": {
+      "python": "/data/anaconda3/envs/convert3/bin/python",
+      "gunicorn": "/data/anaconda3/envs/convert3/bin/gunicorn",
+      "project": "/data/fangjiasheng/format_conversion_maxcompute/"
+    },
+
+    "convert": {
+      "port": [ 15010 ],
+      "port_num": [ 3 ],
+      "gpu": [ -1 ]
+    },
+
+    "ocr": {
+      "port": [ 17000, 17001 ],
+      "port_num": [ 3, 1 ],
+      "gpu": [ 0, 1 ]
+    },
+
+    "otr": {
+      "port": [ 18000, 18001 ],
+      "port_num": [ 0, 2 ],
+      "gpu": [ 0, 1 ]
+    },
+
+    "idc": {
+      "port": [ 18020 ],
+      "port_num": [ 1 ],
+      "gpu": [ -1 ]
+    },
+
+    "isr": {
+      "port": [ 18040 ],
+      "port_num": [ 3 ],
+      "gpu": [ 0 ]
+    },
+
+    "atc": {
+      "port": [ 18060 ],
+      "port_num": [ 2 ],
+      "gpu": [ -1 ]
+    },
+
+    "yolo": {
+      "port": [ 18080 ],
+      "port_num": [ 2 ],
+      "gpu": [ 0 ]
+    },
+
+    "office": {
+      "port": [ 16000 ],
+      "port_num": [ 25 ],
+      "gpu": []
+    }
+  }
+}

+ 0 - 4
format_convert/kill_all.py

@@ -9,12 +9,8 @@ import time
 ip_port_dict = get_ip_port()
 ip = get_using_ip()
 
-if ip == 'http://127.0.0.1':
-    ip = 'http://0.0.0.0'
-
 python_path = get_args_from_config(ip_port_dict, ip, "python_path")[0]
 project_path = get_args_from_config(ip_port_dict, ip, "project_path")[0]
-gunicorn_path = get_args_from_config(ip_port_dict, ip, "gunicorn_path")[0]
 
 
 def kill():

+ 204 - 135
format_convert/monitor_process_config.py

@@ -8,78 +8,131 @@ import psutil
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert.utils import get_ip_port, get_intranet_ip, get_args_from_config, get_all_ip, get_using_ip
 
-
+# 解析配置文件
 ip_port_dict = get_ip_port()
 ip = get_using_ip()
 print("local ip:", ip)
 
-if ip == 'http://127.0.0.1':
-    ip = 'http://0.0.0.0'
-
-# 获取各个参数
-convert_port_list = get_args_from_config(ip_port_dict, ip, "convert", "MASTER")
-if convert_port_list:
-    convert_port_list = convert_port_list[0]
-ocr_port_list = get_args_from_config(ip_port_dict, ip, "ocr")
-otr_port_list = get_args_from_config(ip_port_dict, ip, "otr")
-idc_port_list = get_args_from_config(ip_port_dict, ip, "idc")
-isr_port_list = get_args_from_config(ip_port_dict, ip, "isr")
-atc_port_list = get_args_from_config(ip_port_dict, ip, "atc")
-yolo_port_list = get_args_from_config(ip_port_dict, ip, "yolo")
-soffice_port_list = get_args_from_config(ip_port_dict, ip, "office", "MASTER")
-if soffice_port_list:
-    soffice_port_list = soffice_port_list[0]
-python_path_list = get_args_from_config(ip_port_dict, ip, "python_path")
-project_path_list = get_args_from_config(ip_port_dict, ip, "project_path")
-gunicorn_path_list = get_args_from_config(ip_port_dict, ip, "gunicorn_path")
+# 自定义输出
 std_out = " >>/convert.out 2>&1 &"
 std_out_gpu = " >>/gpu.out 2>&1 &"
 std_out_schedule = " >>/schedule.out 2>&1 &"
 
-print("convert_port_list", convert_port_list)
-print("ocr_port_list", ocr_port_list)
-print("otr_port_list", otr_port_list)
-print("idc_port_list", idc_port_list)
-print("isr_port_list", isr_port_list)
-print("atc_port_list", atc_port_list)
-print("yolo_port_list", yolo_port_list)
-print("soffice_port_list", soffice_port_list)
-
-# 根据port生成gunicorn语句
-ocr_comm_list = []
-otr_comm_list = []
-isr_comm_list = []
-idc_comm_list = []
-atc_comm_list = []
-yolo_comm_list = []
-for i in range(len(ocr_port_list)):
-    ocr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(ocr_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/ocr ocr_interface:app" + std_out_gpu)
-for i in range(len(otr_port_list)):
-    otr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(otr_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/otr otr_interface:app" + std_out_gpu)
-for i in range(len(idc_port_list)):
-    idc_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(idc_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/idc idc_interface:app" + std_out_gpu)
-for i in range(len(isr_port_list)):
-    isr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(isr_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/isr isr_interface:app" + std_out_gpu)
-for i in range(len(atc_port_list)):
-    atc_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(atc_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/atc atc_interface:app" + std_out_gpu)
-for i in range(len(yolo_port_list)):
-    yolo_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(yolo_port_list[i]))
-                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
-                         + project_path_list[i] + "/botr/yolov8 yolo_interface:app" + std_out_gpu)
-
-convert_comm = "nohup " + gunicorn_path_list[0] + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
-               + project_path_list[0] + "/format_convert convert:app" + std_out
-soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
+# 获取接口各个参数,提前生成命令
+python_path = get_args_from_config(ip_port_dict, ip, "python_path")[0]
+project_path = get_args_from_config(ip_port_dict, ip, "project_path")[0]
+gunicorn_path = get_args_from_config(ip_port_dict, ip, "gunicorn_path")[0]
+interface_list = ['convert', 'ocr', 'otr', 'idc', 'isr', 'atc', 'yolo', 'office']
+comm_dict = {}
+interface_port_dict = {}
+for name in interface_list:
+    if get_args_from_config(ip_port_dict, ip, name, 'MASTER'):
+        port_list, num_list, gpu_list = get_args_from_config(ip_port_dict, ip, name, 'MASTER')[0]
+    else:
+        port_list, num_list, gpu_list = get_args_from_config(ip_port_dict, ip, name)[0]
+
+    interface_port_dict[name] = [port_list, num_list, gpu_list]
+
+    for i, port in enumerate(port_list):
+        port_num = num_list[i]
+        if int(port_num) == 0:
+            continue
+
+        # 设置gpu
+        if gpu_list:
+            gpu = gpu_list[i]
+        else:
+            gpu = -1
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
+        gpu_comm = 'export CUDA_VISIBLE_DEVICES=' + str(gpu) + ' && '
+
+        # 设置命令
+        if name == 'convert':
+            comm = "nohup " + gunicorn_path + " -w " + str(port_num) + " -t 300 --keep-alive 600 -b 0.0.0.0:" + str(port) + " --chdir " + project_path + "format_convert" + ' ' + name + ":app" + std_out
+        elif name == 'yolo':
+            comm = "nohup " + gunicorn_path + " -w " + str(port_num) + " -t 300 --keep-alive 600 -b 0.0.0.0:" + str(port) + " --chdir " + project_path + "/botr/yolov8" + ' ' + name + "_interface:app" + std_out_gpu
+        elif name == 'office':
+            comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
+            office_port_comm_list = []
+            for office_port in range(port, port + port_num):
+                office_port_comm_list = re.sub("#", str(office_port), comm)
+            comm_dict[name] = office_port_comm_list
+        else:
+            comm = "nohup " + gunicorn_path + " -w " + str(port_num) + " -t 300 --keep-alive 600 -b 0.0.0.0:" + str(port) + " --chdir " + project_path + "/" + name + ' ' + name + "_interface:app" + std_out_gpu
+
+        if name == 'office':
+            continue
+
+        if name in comm_dict.keys():
+            comm_dict[name] += [gpu_comm + comm]
+        else:
+            comm_dict[name] = [gpu_comm + comm]
+
+    # print(name, port_list, num_list, gpu_list)
+
+# convert_port_list = get_args_from_config(ip_port_dict, ip, "convert", "MASTER")
+# if convert_port_list:
+#     convert_port_list = convert_port_list[0]
+# ocr_port_list = get_args_from_config(ip_port_dict, ip, "ocr")
+# otr_port_list = get_args_from_config(ip_port_dict, ip, "otr")
+# idc_port_list = get_args_from_config(ip_port_dict, ip, "idc")
+# isr_port_list = get_args_from_config(ip_port_dict, ip, "isr")
+# atc_port_list = get_args_from_config(ip_port_dict, ip, "atc")
+# yolo_port_list = get_args_from_config(ip_port_dict, ip, "yolo")
+# soffice_port_list = get_args_from_config(ip_port_dict, ip, "office", "MASTER")
+# if soffice_port_list:
+#     soffice_port_list = soffice_port_list[0]
+# python_path_list = get_args_from_config(ip_port_dict, ip, "python_path")
+# project_path_list = get_args_from_config(ip_port_dict, ip, "project_path")
+# gunicorn_path_list = get_args_from_config(ip_port_dict, ip, "gunicorn_path")
+# std_out = " >>/convert.out 2>&1 &"
+# std_out_gpu = " >>/gpu.out 2>&1 &"
+# std_out_schedule = " >>/schedule.out 2>&1 &"
+#
+# print("convert_port_list", convert_port_list)
+# print("ocr_port_list", ocr_port_list)
+# print("otr_port_list", otr_port_list)
+# print("idc_port_list", idc_port_list)
+# print("isr_port_list", isr_port_list)
+# print("atc_port_list", atc_port_list)
+# print("yolo_port_list", yolo_port_list)
+# print("soffice_port_list", soffice_port_list)
+#
+# # 根据port生成gunicorn语句
+# ocr_comm_list = []
+# otr_comm_list = []
+# isr_comm_list = []
+# idc_comm_list = []
+# atc_comm_list = []
+# yolo_comm_list = []
+# for i in range(len(ocr_port_list)):
+#     ocr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(ocr_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/ocr ocr_interface:app" + std_out_gpu)
+# for i in range(len(otr_port_list)):
+#     otr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(otr_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/otr otr_interface:app" + std_out_gpu)
+# for i in range(len(idc_port_list)):
+#     idc_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(idc_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/idc idc_interface:app" + std_out_gpu)
+# for i in range(len(isr_port_list)):
+#     isr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(isr_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/isr isr_interface:app" + std_out_gpu)
+# for i in range(len(atc_port_list)):
+#     atc_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(atc_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/atc atc_interface:app" + std_out_gpu)
+# for i in range(len(yolo_port_list)):
+#     yolo_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(yolo_port_list[i]))
+#                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+#                          + project_path_list[i] + "/botr/yolov8 yolo_interface:app" + std_out_gpu)
+#
+# convert_comm = "nohup " + gunicorn_path_list[0] + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
+#                + project_path_list[0] + "/format_convert convert:app" + std_out
+# soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
 
 
 def get_port():
@@ -92,29 +145,38 @@ def get_port():
     return current_port_list
 
 
-def restart(process_type, port, index=0):
-    if process_type == "convert":
-        _comm = re.sub("#", port, convert_comm)
-    elif process_type == "ocr":
-        _comm = re.sub("#", port, ocr_comm_list[index])
-    elif process_type == "otr":
-        _comm = re.sub("#", port, otr_comm_list[index])
-    elif process_type == "soffice":
-        _comm = re.sub("#", port, soffice_comm)
-    elif process_type == "idc":
-        _comm = re.sub("#", port, idc_comm_list[index])
-    elif process_type == "isr":
-        _comm = re.sub("#", port, isr_comm_list[index])
-    elif process_type == "atc":
-        _comm = re.sub("#", port, atc_comm_list[index])
-    elif process_type == "yolo":
-        _comm = re.sub("#", port, yolo_comm_list[index])
-    else:
-        _comm = "netstat -nltp"
-        print("no process_type", process_type)
-    # os.system("echo $(date +%F%n%T)")
-    print(datetime.datetime.now(), "restart comm", _comm)
-    os.system(_comm)
+def restart(interface_type, port, index=0):
+    # if process_type == "convert":
+    #     _comm = re.sub("#", port, convert_comm)
+    # elif process_type == "ocr":
+    #     _comm = re.sub("#", port, ocr_comm_list[index])
+    # elif process_type == "otr":
+    #     _comm = re.sub("#", port, otr_comm_list[index])
+    # elif process_type == "soffice":
+    #     _comm = re.sub("#", port, soffice_comm)
+    # elif process_type == "idc":
+    #     _comm = re.sub("#", port, idc_comm_list[index])
+    # elif process_type == "isr":
+    #     _comm = re.sub("#", port, isr_comm_list[index])
+    # elif process_type == "atc":
+    #     _comm = re.sub("#", port, atc_comm_list[index])
+    # elif process_type == "yolo":
+    #     _comm = re.sub("#", port, yolo_comm_list[index])
+    # else:
+    #     _comm = "netstat -nltp"
+    #     print("no process_type", process_type)
+    #
+
+    _comm_list = comm_dict.get(interface_type)
+
+    if not _comm_list:
+        print('monitor_process_config restart command error! check config!')
+        raise
+
+    for _comm in _comm_list:
+        if str(port) in _comm:
+            print(datetime.datetime.now(), "restart comm", _comm)
+            os.system(_comm)
 
 
 def kill_soffice(limit_sec=30):
@@ -180,53 +242,60 @@ def kill_nested_timeout_process():
 
 
 def monitor():
-    current_port_list = get_port()
-
-    if convert_port_list:
-        for p in convert_port_list[:1]:
-            if p not in current_port_list:
-                restart("convert", p)
-
-    if ocr_port_list:
-        for j in range(len(ocr_port_list)):
-            for p in ocr_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("ocr", p, index=j)
-
-    if otr_port_list:
-        for j in range(len(otr_port_list)):
-            for p in otr_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("otr", p, index=j)
-
-    if idc_port_list:
-        for j in range(len(idc_port_list)):
-            for p in idc_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("idc", p, index=j)
-
-    if isr_port_list:
-        for j in range(len(isr_port_list)):
-            for p in isr_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("isr", p, index=j)
-
-    if atc_port_list:
-        for j in range(len(atc_port_list)):
-            for p in atc_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("atc", p, index=j)
-
-    if yolo_port_list:
-        for j in range(len(yolo_port_list)):
-            for p in yolo_port_list[j][:1]:
-                if p not in current_port_list:
-                    restart("yolo", p, index=j)
-
-    if soffice_port_list:
-        for p in soffice_port_list:
-            if p not in current_port_list:
-                restart("soffice", p)
+    for _name in interface_list:
+        if interface_port_dict.get(_name):
+            _port_list, _num_list, _gpu_list = interface_port_dict.get(_name)
+            current_port_list = get_port()
+            for j, p in enumerate(_port_list):
+                if str(p) not in current_port_list:
+                    restart(_name, p)
+
+
+    # if convert_port_list:
+    #     for p in convert_port_list[:1]:
+    #         if p not in current_port_list:
+    #             restart("convert", p)
+    #
+    # if ocr_port_list:
+    #     for j in range(len(ocr_port_list)):
+    #         for p in ocr_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("ocr", p, index=j)
+    #
+    # if otr_port_list:
+    #     for j in range(len(otr_port_list)):
+    #         for p in otr_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("otr", p, index=j)
+    #
+    # if idc_port_list:
+    #     for j in range(len(idc_port_list)):
+    #         for p in idc_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("idc", p, index=j)
+    #
+    # if isr_port_list:
+    #     for j in range(len(isr_port_list)):
+    #         for p in isr_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("isr", p, index=j)
+    #
+    # if atc_port_list:
+    #     for j in range(len(atc_port_list)):
+    #         for p in atc_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("atc", p, index=j)
+    #
+    # if yolo_port_list:
+    #     for j in range(len(yolo_port_list)):
+    #         for p in yolo_port_list[j][:1]:
+    #             if p not in current_port_list:
+    #                 restart("yolo", p, index=j)
+    #
+    # if soffice_port_list:
+    #     for p in soffice_port_list:
+    #         if p not in current_port_list:
+    #             restart("soffice", p)
 
     kill_soffice()
 
@@ -239,7 +308,7 @@ def monitor():
 
 
 if __name__ == "__main__":
-    for i in range(6):
+    for i in range(3):
         # os.system("echo $(date +%F%n%T)")
         monitor()
         time.sleep(10)

+ 0 - 0
my_zipfile.py → format_convert/my_zipfile.py


+ 115 - 74
format_convert/utils.py

@@ -1453,12 +1453,15 @@ def my_subprocess_call(*popenargs, timeout=None):
 
 
 def parse_yaml():
-    yaml_path = os.path.dirname(os.path.abspath(__file__)) + "/interface.yml"
-    with open(yaml_path, "r", encoding='utf-8') as f:
-        cfg = f.read()
+    yaml_path = os.path.dirname(os.path.abspath(__file__)) + "/interface_new.yml"
+    # with open(yaml_path, "r", encoding='utf-8') as f:
+    #     cfg = f.read()
+    #
+    # params = yaml.load(cfg, Loader=yaml.SafeLoader)
 
-    params = yaml.load(cfg, Loader=yaml.SafeLoader)
-    return params
+    with open(yaml_path, "r", encoding='utf-8') as f:
+        _dict = json.load(f)
+    return _dict
 
 
 def get_ip_port(node_type=None, interface_type=None):
@@ -1477,60 +1480,46 @@ def get_ip_port(node_type=None, interface_type=None):
     # 循环 master slave
     for type1 in node_type_list:
         node_type = type1.upper()
-        ip_list = params.get(node_type).get("ip")
+        ip = params.get(node_type).get("ip")
+        if not ip:
+            continue
+
+        if ip_port_dict.get(ip):
+            ip_port_dict.get(ip).update({node_type: {}})
+        else:
+            ip_port_dict.update({ip: {node_type: {}}})
+
+        # 有IP时,循环多个参数
+        for type2 in interface_type_list:
+            python_path = None
+            project_path = None
+            gunicorn_path = None
+            port_list = []
+            interface_type = type2
+
+            if not params.get(node_type).get(interface_type):
+                continue
 
-        # 循环多个IP
-        for j in range(len(ip_list)):
-            _ip = ip_list[j]
-            if ip_port_dict.get(_ip):
-                ip_port_dict.get(_ip).update({node_type: {}})
+            if interface_type == "path":
+                python_path = params.get(node_type).get(interface_type).get("python")
+                project_path = params.get(node_type).get(interface_type).get("project")
+                gunicorn_path = params.get(node_type).get(interface_type).get("gunicorn")
             else:
-                ip_port_dict.update({_ip: {node_type: {}}})
-
-            # 有IP时,循环多个参数
-            for type2 in interface_type_list:
-                python_path = None
-                project_path = None
-                gunicorn_path = None
-                processes = 0
-                port_list = []
-                interface_type = type2.upper()
-                # if interface_type in ["convert".upper()]:
-                #     _port = params.get(node_type).get(interface_type).get("port")
-                #     if _port is None:
-                #         port_list = []
-                #     else:
-                #         if interface_type == "convert".upper():
-                #             processes = params.get(node_type).get(interface_type).get("processes")[j]
-                #         port_list = [str(_port[j])]*int(processes)
-                #         # port_list = [str(_port)]
-                if interface_type == "path".upper():
-                    python_path = params.get(node_type).get(interface_type).get("python")[j]
-                    project_path = params.get(node_type).get(interface_type).get("project")[j]
-                    gunicorn_path = params.get(node_type).get(interface_type).get("gunicorn")[j]
+                port = params.get(node_type).get(interface_type).get("port")
+                port_num = params.get(node_type).get(interface_type).get("port_num")
+                gpu_no = params.get(node_type).get(interface_type).get("gpu")
+                if port is None or port_num is None:
+                    port_list = []
                 else:
-                    port_start = params.get(node_type).get(interface_type).get("port_start")
-                    port_no = params.get(node_type).get(interface_type).get("port_no")
-                    if port_start is None or port_no is None:
-                        port_list = []
-                    else:
-                        if interface_type in ["office".upper()]:
-                            port_list = [str(x) for x in range(port_start[j], port_start[j] + port_no[j], 1)]
-                        else:
-                            port_list = [str(port_start[j])] * port_no[j]
-                # if ip_list:
-                #     for i in range(len(ip_list)):
-
-                # 参数放入dict
-                if port_list:
-                    ip_port_dict.get(_ip).get(node_type).update({interface_type.lower(): port_list})
-                if processes:
-                    ip_port_dict.get(_ip).get(node_type).update({interface_type.lower() + "_processes": processes})
-                if project_path and python_path and gunicorn_path:
-                    ip_port_dict.get(_ip).get(node_type).update({"project_path": project_path,
-                                                                 "python_path": python_path,
-                                                                 "gunicorn_path": gunicorn_path})
-                # print("ip_port_dict", ip_port_dict)
+                    port_list = [port, port_num, gpu_no]
+
+            # 参数放入dict
+            if port_list:
+                ip_port_dict.get(ip).get(node_type).update({interface_type: port_list})
+            if project_path and python_path and gunicorn_path:
+                ip_port_dict.get(ip).get(node_type).update({"project_path": project_path,
+                                                            "python_path": python_path,
+                                                            "gunicorn_path": gunicorn_path})
     return ip_port_dict
 
 
@@ -1616,7 +1605,7 @@ def get_intranet_ip():
 
 def get_all_ip():
     if get_platform() == "Windows":
-        ips = ['127.0.0.1']
+        ips = ['0.0.0.0']
     else:
         ips = [ip.split('/')[0] for ip in os.popen("ip addr | grep 'inet '|awk '{print $2}'").readlines()]
     for i in range(len(ips)):
@@ -1627,11 +1616,14 @@ def get_all_ip():
 def get_using_ip():
     ip_port_dict = get_ip_port()
     ips = get_all_ip()
-    ip = "http://127.0.0.1"
     for key in ip_port_dict.keys():
         if key in ips:
             ip = key
             break
+
+    # ip = "http://127.0.0.1"
+    if ip == 'http://127.0.0.1':
+        ip = 'http://0.0.0.0'
     return ip
 
 
@@ -1700,26 +1692,16 @@ def set_flask_global():
     ip_port_flag = {}
     # ip_flag = []
     ip_port_dict = get_ip_port()
+    print(ip_port_dict)
     for _k in ip_port_dict.keys():
+        print(_k)
         ip_port_flag.update({_k: {}})
         for interface in ["ocr", "otr", "convert", "idc", "isr", "atc", 'yolo', "office"]:
-            if ip_port_dict.get(_k).get("MASTER"):
-                if ip_port_dict.get(_k).get("MASTER").get(interface):
+            if ip_port_dict.get(_k).get("MASTER") and ip_port_dict.get(_k).get("MASTER").get(interface):
                     ip_port_flag[_k][interface] = 0
             else:
-                if ip_port_dict.get(_k).get("SLAVE").get(interface):
+                if ip_port_dict.get(_k).get("SLAVE") and ip_port_dict.get(_k).get("SLAVE").get(interface):
                     ip_port_flag[_k][interface] = 0
-        # ip_port_flag.update({_k: {"ocr": 0,
-        #                           "otr": 0,
-        #                           "convert": 0,
-        #                           "idc": 0,
-        #                           "isr": 0,
-        #                           "office": 0
-        #                           }})
-        # if ip_port_dict.get(_k).get("MASTER"):
-        #     ip_flag.append([_k+"_master", 0])
-        # if ip_port_dict.get(_k).get("SLAVE"):
-        #     ip_flag.append([_k+"_slave", 0])
     _global.update({"ip_port_flag": ip_port_flag})
     _global.update({"ip_port": ip_port_dict})
     # _global.update({"ip_flag": ip_flag})
@@ -2026,7 +2008,7 @@ def file_lock(file_name):
 
 def get_garble_code():
     reg_str = '[ÿÝØÐÙÚÛÜÒÓÔÕÖÊÄẨòóôäåüúîïìþ¡¢£¤§èéêëȟš' + \
-              'Ϸᱦ¼ŒÞ¾Çœø‡Æ�ϐ㏫⮰ڝⶹӇⰚڣༀងϦȠ⚓Ⴭᐬ⩔ⅮⰚࡦࣽ' + \
+              'Ϸᱦ¼ŒÞ¾Çœø‡Æ�ϐ㏫⮰ڝⶹӇⰚڣༀងϦȠ⚓Ⴭᐬ⩔ⅮⰚࡦࣽ' + \
               '䕆㶃䌛㻰䙹䔮㔭䶰䰬䉰䶰䘔䉥喌䶥䶰䛳䉙䄠' + \
               ''.join(['\\x0' + str(x) for x in range(1, 10)]) + \
               ''.join(['\\x' + str(x) for x in range(10, 20)]) + \
@@ -2045,6 +2027,63 @@ def line_is_cross(A, B, C, D):
         return False
 
 
+def line_iou(line1, line2, axis=0):
+    inter = min(line1[1][axis], line2[1][axis]) - max(line1[0][axis], line2[0][axis])
+    # union = max(line1[1][axis], line2[1][axis]) - min(line1[0][axis], line2[0][axis])
+    union = min(abs(line1[0][axis]-line1[1][axis]), abs(line2[0][axis]-line2[1][axis]))
+    if union in [0, 0.]:
+        iou = 0.
+    else:
+        iou = inter / union
+    return iou
+
+
+def bbox_iou(bbox1, bbox2):
+    x1_min, y1_min, x1_max, y1_max = bbox1
+    x2_min, y2_min, x2_max, y2_max = bbox2
+
+    # 计算矩形框1的宽度、高度和面积
+    width1 = x1_max - x1_min
+    height1 = y1_max - y1_min
+    area1 = width1 * height1
+
+    # 计算矩形框2的宽度、高度和面积
+    width2 = x2_max - x2_min
+    height2 = y2_max - y2_min
+    area2 = width2 * height2
+
+    # 计算相交矩形框的左上角和右下角坐标
+    x_intersection_min = max(x1_min, x2_min)
+    y_intersection_min = max(y1_min, y2_min)
+    x_intersection_max = min(x1_max, x2_max)
+    y_intersection_max = min(y1_max, y2_max)
+
+    # 计算相交矩形框的宽度和高度
+    intersection_width = max(0, x_intersection_max - x_intersection_min)
+    intersection_height = max(0, y_intersection_max - y_intersection_min)
+
+    # 计算相交矩形框的面积
+    intersection_area = intersection_width * intersection_height
+
+    # 判断包含关系并调整相交面积
+    if (x1_min <= x2_min) and (y1_min <= y2_min) and (x1_max >= x2_max) and (y1_max >= y2_max):
+        union_area = area2
+    elif (x2_min <= x1_min) and (y2_min <= y1_min) and (x2_max >= x1_max) and (y2_max >= y1_max):
+        union_area = area1
+    else:
+        # 计算并集矩形框的面积
+        # union_area = area1 + area2 - intersection_area
+        union_area = min(area1, area2)
+
+    # 计算IoU
+    if int(union_area) == 0:
+        iou = 0
+    else:
+        iou = intersection_area / union_area
+
+    return iou
+
+
 if __name__ == "__main__":
     # strs = r"D:\Project\temp\04384fcc9e8911ecbd2844f971944973\043876ca9e8911eca5e144f971944973_rar\1624114035529.jpeg"
     # print(slash_replace(strs))
@@ -2075,9 +2114,11 @@ if __name__ == "__main__":
 
     print(get_ip_port())
     # set_flask_global()
-    # print(get_all_ip())
+    print(get_all_ip())
     print(get_args_from_config(get_ip_port(), get_all_ip()[0], "idc"))
     print(get_args_from_config(get_ip_port(), get_all_ip()[0], "atc"))
+    print(get_args_from_config(get_ip_port(), get_all_ip()[0], "ocr"))
+    print(get_args_from_config(get_ip_port(), get_all_ip()[0], 'convert', 'MASTER'))
     # print(get_args_from_config(get_ip_port(), "http://127.0.0.1", "gunicorn_path"))
     # print(get_intranet_ip())
     # _path = "C:/Users/Administrator/Downloads/3.png"

+ 7 - 4
ocr/ocr_interface.py

@@ -37,7 +37,9 @@ def _ocr():
         _md5 = request.form.get("md5")
         only_rec = request.form.get("only_rec")
         if only_rec is None:
-            only_rec = False
+            only_rec = 0
+        else:
+            only_rec = int(only_rec)
         _global.update({"md5": _md5})
         ocr_model = globals().get("global_ocr_model")
         if ocr_model is None:
@@ -55,7 +57,7 @@ def _ocr():
         log("ocr interface finish time " + str(time.time()-start_time))
 
 
-def ocr(data, ocr_model, only_rec=False):
+def ocr(data, ocr_model, only_rec=0):
     log("into ocr_interface ocr")
     try:
         img_data = base64.b64decode(data)
@@ -65,7 +67,7 @@ def ocr(data, ocr_model, only_rec=False):
         return {"text": str([-5]), "bbox": str([-5])}
 
 
-def picture2text(img_data, ocr_model, only_rec=False):
+def picture2text(img_data, ocr_model, only_rec=0):
     log("into ocr_interface picture2text")
     try:
         # 二进制数据流转np.ndarray [np.uint8: 8位像素]
@@ -121,6 +123,7 @@ class OcrModels:
     def __init__(self):
         from ocr.paddleocr import PaddleOCR
         try:
+            log('----------- init ocr model ---------------')
             self.ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
         except:
             print(traceback.print_exc())
@@ -131,7 +134,7 @@ class OcrModels:
 
 
 def test_ocr_model(from_remote=True):
-    file_path = "C:/Users/Administrator/Desktop/2.png"
+    file_path = "error8.png"
     with open(file_path, "rb") as f:
         file_bytes = f.read()
     file_base64 = base64.b64encode(file_bytes)

+ 8 - 3
ocr/paddleocr.py

@@ -31,6 +31,7 @@ from tqdm import tqdm
 os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
 from ocr.tools.infer import predict_system
 from ocr.ppocr.utils.logging import get_logger
+from format_convert.max_compute_config import max_compute
 
 logger = get_logger()
 from ocr.ppocr.utils.utility import check_and_read_gif, get_image_file_list
@@ -187,8 +188,13 @@ def parse_args(mMain=True, add_help=True):
         parser.add_argument("--use_angle_cls", type=str2bool, default=False)
         return parser.parse_args()
     else:
+        if max_compute:
+            use_gpu = False
+        else:
+            use_gpu = True
+
         return argparse.Namespace(
-            use_gpu=True,
+            use_gpu=use_gpu,
             ir_optim=True,
             use_tensorrt=False,
             gpu_mem=8000,
@@ -258,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem):
         if postprocess_params.cls_model_dir is None:
             postprocess_params.cls_model_dir = os.path.join(
                 BASE_DIR, '{}/cls'.format(VERSION))
-        print(postprocess_params)
+        logger.info(postprocess_params)
 
         # download model
         maybe_download(postprocess_params.det_model_dir, model_urls['det'])
@@ -287,7 +293,6 @@ class PaddleOCR(predict_system.TextSystem):
             det: use text detection or not, if false, only rec will be exec. default is True
             rec: use text recognition or not, if false, only det will be exec. default is True
         """
-        print("det, rec, cls", det, rec, cls)
         assert isinstance(img, (np.ndarray, list, str))
         if isinstance(img, list) and det == True:
             logger.error('When input a list of images, det must be false')

+ 24 - 9
ocr/tools/infer/predict_rec.py

@@ -78,18 +78,32 @@ class TextRecognizer(object):
             utility.create_predictor(args, 'rec', logger)
 
     def resize_norm_img(self, img, max_wh_ratio):
+        h, w = img.shape[:2]
         imgC, imgH, imgW = self.rec_image_shape
         assert imgC == img.shape[2]
-        if self.character_type == "ch":
-            imgW = int((32 * max_wh_ratio))
-        h, w = img.shape[:2]
-        ratio = w / float(h)
-        if math.ceil(imgH * ratio) > imgW:
-            resized_w = imgW
+
+        if max_wh_ratio < 0.1:
+            if h > imgW:
+                resized_image = cv2.resize(img, (w, imgW))
+            else:
+                resized_image = img
+
         else:
-            resized_w = int(math.ceil(imgH * ratio))
-        # print("predict_rec.py resize_norm_img resize shape", (resized_w, imgH))
-        resized_image = cv2.resize(img, (resized_w, imgH))
+            if self.character_type == "ch":
+                imgW = int((32 * max_wh_ratio))
+
+            ratio = w / float(h)
+            if math.ceil(imgH * ratio) > imgW:
+                resized_w = imgW
+            else:
+                resized_w = int(math.ceil(imgH * ratio))
+
+            try:
+                resized_image = cv2.resize(img, (resized_w, imgH))
+            except:
+                log("predict_rec.py resize_norm_img resize shape " + str((resized_w, imgH, imgW, h, w, ratio, max_wh_ratio)) + ' ' + str(self.rec_image_shape))
+                raise
+
         resized_image = resized_image.astype('float32')
         resized_image = resized_image.transpose((2, 0, 1)) / 255
         resized_image -= 0.5
@@ -194,6 +208,7 @@ class TextRecognizer(object):
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
                 if self.rec_algorithm != "SRN":
+
                     norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                     max_wh_ratio)
                     norm_img = norm_img[np.newaxis, :]

+ 29 - 2
otr/table_line_new.py

@@ -173,15 +173,21 @@ def table_line_pdf(line_list, page_w, page_h, is_test=0):
     img_new = np.full([int(page_h+1), int(page_w+1), 3], 255, dtype=np.uint8)
     img_show = copy.deepcopy(img_new)
 
+    show(line_list, title="table_line_pdf start", mode=2, is_test=is_test)
+
     # 分成横竖线
     start_time = time.time()
     row_line_list = []
     col_line_list = []
     for line in line_list:
+        # 可能有斜线
         if line[0] == line[2]:
             col_line_list.append(line)
         elif line[1] == line[3]:
             row_line_list.append(line)
+        else:
+            if is_test:
+                print(line)
     log("pdf divide rows and cols " + str(time.time() - start_time))
     show(row_line_list + col_line_list, title="divide", mode=2, is_test=is_test)
 
@@ -189,7 +195,13 @@ def table_line_pdf(line_list, page_w, page_h, is_test=0):
     if not row_line_list or not col_line_list:
         return []
 
+    # 合并线
+    row_line_list = merge_line(row_line_list, axis=0)
+    col_line_list = merge_line(col_line_list, axis=1)
+    show(row_line_list + col_line_list, title="merge", mode=2, is_test=is_test)
+
     # 计算交点
+    print('img_new.shape', img_new.shape)
     cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
     if not cross_points:
         return []
@@ -248,6 +260,13 @@ def table_line_pdf(line_list, page_w, page_h, is_test=0):
         sub_col_line_list = area_col_line_list[i]
         sub_point_list = area_point_list[i]
 
+        # 验证轮廓的4个交点
+        sub_row_line_list, sub_col_line_list = fix_4_points(sub_point_list, sub_row_line_list, sub_col_line_list)
+
+        # 把四个边线在加一次
+        sub_point_list = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
+        sub_row_line_list, sub_col_line_list = add_outline(sub_point_list, sub_row_line_list, sub_col_line_list)
+
         # 修复内部缺线
         start_time = time.time()
         sub_row_line_list, sub_col_line_list = fix_inner(sub_row_line_list, sub_col_line_list, sub_point_list)
@@ -258,6 +277,11 @@ def table_line_pdf(line_list, page_w, page_h, is_test=0):
         start_time = time.time()
         cross_points = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
         show(cross_points, title="get_points3", img=img_show, mode=4, is_test=is_test)
+        area_point_list[i] = cross_points
+
+        # 合并线
+        area_row_line_list[i] = merge_line(sub_row_line_list, axis=0)
+        area_col_line_list[i] = merge_line(sub_col_line_list, axis=1)
 
     row_line_list = [y for x in area_row_line_list for y in x]
     col_line_list = [y for x in area_col_line_list for y in x]
@@ -628,18 +652,21 @@ def merge_line(lines, axis, threshold=5):
     return result_lines
 
 
-def get_points(row_lines, col_lines, image_size):
+def get_points(row_lines, col_lines, image_size, threshold=5):
     # 创建空图
     row_img = np.zeros(image_size, np.uint8)
     col_img = np.zeros(image_size, np.uint8)
 
     # 画线
-    threshold = 5
+    # threshold = 5
     for row in row_lines:
         cv2.line(row_img, (int(row[0] - threshold), int(row[1])), (int(row[2] + threshold), int(row[3])), (255, 255, 255), 1)
     for col in col_lines:
         cv2.line(col_img, (int(col[0]), int(col[1] - threshold)), (int(col[2]), int(col[3] + threshold)), (255, 255, 255), 1)
 
+    # cv2.imshow('get_points', row_img+col_img)
+    # cv2.waitKey(0)
+
     # 求出交点
     point_img = np.bitwise_and(row_img, col_img)
     # cv2.imwrite("get_points.jpg", row_img+col_img)

이 변경점에서 너무 많은 파일들이 변경되어 몇몇 파일들은 표시되지 않았습니다.