123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647 |
- # encoding=utf8
- import base64
- import json
- import multiprocessing
- import os
- import pickle
- import random
- import sys
- import time
- import uuid
- import cv2
- import torch
- from werkzeug.exceptions import NotFound
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from botr.yolov8.yolo_interface import yolo
- from botr.yolov8.model import Predictor
- from atc.atc_interface import AtcModels, atc
- from idc.idc_interface import IdcModels, idc
- from isr.isr_interface import IsrModels, isr
- import traceback
- import requests
- from format_convert import _global
- from format_convert.utils import get_platform, get_sequential_data, judge_error_code, request_post, get_ip_port, \
- get_intranet_ip, get_logger, log, get_args_from_config, get_using_ip, np2bytes, set_flask_global
- from ocr.ocr_interface import ocr, OcrModels
- from otr.otr_interface import otr, OtrModels
- from format_convert.libreoffice_interface import office_convert
- import numpy as np
- from format_convert.max_compute_config import max_compute
# Flag imported from format_convert.max_compute_config.
MAX_COMPUTE = max_compute

# FROM_REMOTE decides whether the from_*_interface helpers POST to remote
# services (True) or load and run the models in this process (False).
if get_platform() != "Windows":
    FROM_REMOTE = True
else:
    # On Windows the models are run in-process.
    FROM_REMOTE = False
    # Debug switch: when enabled, register localhost OCR/OTR ports in _global.
    only_test_ocr = False
    if only_test_ocr:
        ip_port_dict = get_ip_port()
        ip_port_flag = {
            host: {"ocr": 0, "otr": 0, "convert": 0, "office": 0}
            for host in ip_port_dict.keys()
        }
        _global.update({"ip_port_flag": ip_port_flag})
        ip_port_dict["http://127.0.0.1"]["ocr"] = ["17000"]
        ip_port_dict["http://127.0.0.1"]["otr"] = ["18000"]
        _global.update({"ip_port": ip_port_dict})

# A truthy MAX_COMPUTE forces in-process execution on every platform.
if MAX_COMPUTE:
    FROM_REMOTE = False

lock = multiprocessing.RLock()

# Redis client used by from_gpu_interface_redis. It is disabled (None); assign a
# redis.StrictRedis(...) connection here to enable that code path. Keep
# credentials out of source control.
redis_db = None
def from_office_interface(src_path, dest_path, target_format, retry_times=1, from_remote=FROM_REMOTE):
    """Convert a document with LibreOffice, either remotely or in-process.

    Args:
        src_path: Path of the source document.
        dest_path: Output location prefix. In remote mode the result is written
            to ``dest_path + <basename of src_path up to its first '.'> + "." +
            target_format`` (plain concatenation, so dest_path presumably ends
            with a path separator — confirm with callers).
        target_format: Target extension handed to the converter.
        retry_times: Retry count forwarded to the converter (local or remote).
        from_remote: When True, POST the file to an "office" interface chosen
            by ``interface_pool_gunicorn``; otherwise call ``office_convert``.

    Returns:
        The converted file path on success. Otherwise an error-code list:
        whatever the pool, the remote service or ``office_convert`` returned,
        ``[-5]`` on ``TimeoutError`` and ``[-1]`` on any other exception.
    """
    try:
        if not from_remote:
            # office_convert yields either the output path or an error-code
            # list; both are passed straight through to the caller.
            return office_convert(src_path, dest_path, target_format, retry_times)

        # Two separate budgets. A list response (interface unreachable) has a
        # budget of 1, so it is returned immediately. An error code inside a
        # normal response is retried once.
        retry_times_1 = 1
        retry_times_2 = 2
        while retry_times_1 and retry_times_2:
            ip_port = interface_pool_gunicorn("office")
            if judge_error_code(ip_port):
                return ip_port
            _url = ip_port + "/soffice"
            with open(src_path, "rb") as f:
                base64_stream = base64.b64encode(f.read())
            start_time = time.time()
            log('office _url ' + str(_url))
            r = json.loads(request_post(_url, {"src_path": src_path,
                                               "dest_path": dest_path,
                                               "file": base64_stream,
                                               "target_format": target_format,
                                               "retry_times": retry_times}, time_out=25))
            log("get interface return")
            log("office use time " + str(time.time() - start_time))
            if isinstance(r, list):
                # Interface unreachable: retry on another port while budget remains.
                if retry_times_1 <= 1:
                    return r
                retry_times_1 -= 1
                log("retry post office_interface... left times " + str(retry_times_1))
                continue
            file_str = r.get("data")
            if judge_error_code(file_str):
                if retry_times_2 <= 1:
                    return file_str
                retry_times_2 -= 1
                continue
            # NOTE(security): eval() on a string returned by a remote service can
            # execute arbitrary code. The payload is presumably a bytes literal
            # such as "b'...'". If so, ast.literal_eval is a safe drop-in; confirm
            # against the server implementation before switching.
            file_bytes = eval(file_str)
            uid1 = src_path.split(os.sep)[-1].split(".")[0]
            file_path = dest_path + uid1 + "." + target_format
            # exist_ok avoids the check-then-create race of an exists() guard.
            os.makedirs(os.path.dirname(file_path), mode=0o777, exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_bytes)
            return file_path
        # Unreachable with the budgets above (every iteration returns or
        # continues with budget left). Kept as a defensive generic failure.
        return [-1]
    except TimeoutError:
        log("from_office_interface timeout error!")
        return [-5]
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        log("from_office_interface error!")
        traceback.print_exc()
        return [-1]
def from_ocr_interface(image_stream, is_table=0, only_rec=0, from_remote=FROM_REMOTE):
    """Run OCR on an image, either through a remote service or in-process.

    Args:
        image_stream: Raw image bytes; base64-encoded before dispatch.
        is_table: When truthy, return ``(text_list, bbox_list)`` instead of
            assembled text. Every error path then returns the error code as a
            pair too, so ``a, b = from_ocr_interface(..., is_table=1)`` always
            unpacks.
        only_rec: Forwarded unchanged to the OCR service / model.
        from_remote: When True, POST to an "ocr" interface (up to 3 attempts
            while the response is a list). Otherwise lazily create the
            module-global OCR model and call ``ocr`` directly.

    Returns:
        With ``is_table``: ``(text_list, bbox_list)`` or ``(code, code)``.
        Otherwise: the text built by ``get_sequential_data(..., html=True)``,
        ``""`` when either list is empty, or an error-code list (``[-5]``
        timeout, ``[-2]`` connection error, ``[-1]`` unexpected failure, or a
        code returned by the pool / service).
    """
    log("into from_ocr_interface")

    def _error(code):
        # Table callers unpack two values, so error codes are paired for them.
        if is_table:
            return code, code
        return code

    try:
        base64_stream = base64.b64encode(image_stream)
        try:
            if from_remote:
                retry_times_1 = 3
                while retry_times_1:
                    ip_port = interface_pool_gunicorn("ocr")
                    if judge_error_code(ip_port):
                        return _error(ip_port)
                    _url = ip_port + "/ocr"
                    r = json.loads(request_post(_url, {"data": base64_stream,
                                                       "md5": _global.get("md5"),
                                                       "only_rec": only_rec
                                                       },
                                                time_out=60))
                    log("get ocr interface return")
                    if isinstance(r, list):
                        # Interface unreachable: retry on another port.
                        if retry_times_1 <= 1:
                            return _error(r)
                        retry_times_1 -= 1
                        log("retry post ocr_interface... left times " + str(retry_times_1))
                        continue
                    if judge_error_code(r):
                        return _error(r)
                    break
            else:
                if globals().get("global_ocr_model") is None:
                    print("=========== init ocr model ===========")
                    globals().update({"global_ocr_model": OcrModels().get_model()})
                r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"), only_rec=only_rec)
        except TimeoutError:
            return _error([-5])
        except requests.exceptions.ConnectionError:
            return _error([-2])

        # NOTE(security): eval() on strings that may come from a remote service
        # can execute arbitrary code. If these are plain list literals,
        # ast.literal_eval is a safe replacement. Confirm with the service.
        text_list = eval(r.get("text"))
        bbox_list = eval(r.get("bbox"))
        if text_list is None:
            text_list = []
        if bbox_list is None:
            bbox_list = []
        if is_table:
            return text_list, bbox_list
        if not (text_list and bbox_list):
            return ""
        # get_sequential_data returns either text or an error code; both are
        # returned as-is.
        return get_sequential_data(text_list, bbox_list, html=True)
    except Exception:
        log("from_ocr_interface error!")
        # format_exc returns the trace text; print_exc returns None.
        log(traceback.format_exc())
        return _error([-1])
def from_gpu_interface_redis(_dict, model_type, predictor_type):
    """Hand a job to a worker through Redis and block until its result appears.

    ``_dict`` is tagged in place with ``predictor_type``, ``model_type`` and a
    fresh ``uuid``. It is then pickled and appended to the Redis list
    ``"producer_" + model_type``. The function polls the Redis hash
    ``"consumer_" + model_type`` every 0.2 s for a field named by that uuid and
    unpickles its value.

    Returns:
        The unpickled result, or ``[-2]`` on any failure. Failures include no
        result within 300 seconds and the module-level ``redis_db`` being None.
    """
    log("into from_gpu_interface")
    start_time = time.time()
    try:
        job_id = uuid.uuid1().hex
        _dict.update({"predictor_type": predictor_type, "model_type": model_type,
                      "uuid": job_id})
        queue_key = "producer_" + model_type
        result_key = "consumer_" + model_type

        push_start = time.time()
        log("pickle.dumps(_dict)" + str(_dict))
        redis_db.rpush(queue_key, pickle.dumps(_dict))
        log(queue_key + " len " + str(redis_db.llen(queue_key)))
        log("to " + queue_key + " time " + str(time.time() - push_start))

        wait_start = time.time()
        time_out = 300
        while True:
            time.sleep(0.2)
            if time.time() - wait_start > time_out:
                # Timed out; turned into [-2] by the handler below.
                raise Exception
            if not redis_db.hexists(result_key, job_id):
                continue
            fetch_start = time.time()
            raw_result = redis_db.hget(result_key, job_id)
            log("from " + result_key + " time " + str(time.time() - fetch_start))
            break

        # NOTE(security): pickle.loads must only ever see data written by
        # trusted workers.
        result = pickle.loads(raw_result)
        log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict))
            + " - time " + str(time.time() - start_time))
        return result
    except Exception:
        log("from_gpu_interface error!")
        log("from_gpu_interface failed " + str(time.time() - start_time))
        traceback.print_exc()
        return [-2]
def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE):
    """Run the OTR model on an image and return its "list_line" result.

    Args:
        image_stream: Raw image bytes; base64-encoded before dispatch.
        is_from_pdf: Forwarded unchanged to the OTR service / model.
        from_remote: When True, POST to an "otr" interface (up to 3 attempts
            while the response is a list). Otherwise lazily create the
            module-global OTR model and call ``otr`` directly.

    Returns:
        The evaluated "list_line" value, or an error-code list: ``[-5]``
        timeout, ``[-2]`` connection error, ``[-1]`` unexpected failure, or a
        code returned by the pool / service.
    """
    log("into from_otr_interface")
    try:
        base64_stream = base64.b64encode(image_stream)
        try:
            if from_remote:
                log("from remote")
                retry_times_1 = 3
                while retry_times_1:
                    ip_port = interface_pool_gunicorn("otr")
                    if judge_error_code(ip_port):
                        return ip_port
                    _url = ip_port + "/otr"
                    r = json.loads(request_post(_url, {"data": base64_stream,
                                                       "is_from_pdf": is_from_pdf,
                                                       "md5": _global.get("md5")}, time_out=60))
                    log("get interface return")
                    if isinstance(r, list):
                        # Interface unreachable: retry on another port.
                        if retry_times_1 <= 1:
                            return r
                        retry_times_1 -= 1
                        log("retry post otr_interface... left times " + str(retry_times_1))
                        continue
                    if judge_error_code(r):
                        return r
                    break
            else:
                log("from local")
                log("otr_model " + str(globals().get("global_otr_model")))
                if globals().get("global_otr_model") is None:
                    print("=========== init otr model ===========")
                    globals().update({"global_otr_model": OtrModels().get_model()})
                    log("init finish")
                r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"), is_from_pdf=is_from_pdf)
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            log("from_otr_interface")
            # print_exc writes the trace itself (and returns None); printing
            # its return value only emitted a stray "None".
            traceback.print_exc()
            return [-2]

        # NOTE(security): eval() on output that may come from a remote service
        # can execute arbitrary code. If it is a plain literal,
        # ast.literal_eval is a safe replacement. Confirm with the service.
        list_line = eval(r.get("list_line"))
        return list_line
    except Exception:
        log("from_otr_interface error!")
        traceback.print_exc()
        return [-1]
def from_isr_interface(image_stream, from_remote=FROM_REMOTE):
    """Run the ISR models on an image and return the image they produce.

    In remote mode the base64 image is POSTed to an "isr" interface, with up
    to three attempts while the response is a list. The base64 "image" field
    of the reply is then decoded with OpenCV. In local mode the module-global
    ISR models are created on first use and ``isr`` is called directly.

    Returns:
        The "image" value: an OpenCV array in remote mode, or whatever ``isr``
        put there in local mode. ``[1]`` (no seal detected) is passed through
        as-is. Other failures return an error-code list: ``[-5]`` timeout,
        ``[-2]`` connection error, ``[-11]`` anything else, or a code from the
        pool / service.
    """
    log("into from_isr_interface")
    start_time = time.time()
    try:
        base64_stream = base64.b64encode(image_stream)
        try:
            if not from_remote:
                if globals().get("global_isr_model") is None:
                    print("=========== init isr model ===========")
                    yolo_model, seal_model = IsrModels().get_model()
                    globals().update({"global_isr_yolo_model": yolo_model,
                                      "global_isr_model": seal_model})
                response = isr(data=base64_stream,
                               isr_yolo_model=globals().get("global_isr_yolo_model"),
                               isr_model=globals().get("global_isr_model"))
            else:
                # retries_left counts the attempts still available after this one.
                for retries_left in (2, 1, 0):
                    ip_port = interface_pool_gunicorn("isr")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": base64_stream, "md5": _global.get("md5")}
                    response = json.loads(request_post(ip_port + "/isr", payload, time_out=60))
                    log("get interface return")
                    if isinstance(response, list):
                        # Interface unreachable: try another port if any attempt remains.
                        if not retries_left:
                            return response
                        log("retry post isr_interface... left times " + str(retries_left))
                        continue
                    if judge_error_code(response):
                        return response
                    break
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            return [-2]

        if not from_remote:
            image_np = response.get("image")
        else:
            image_string = response.get("image")
            if judge_error_code(image_string):
                return image_string
            # [1] means no seal was detected; hand it back unchanged.
            if isinstance(image_string, list) and image_string == [1]:
                return image_string
            raw = base64.b64decode(image_string.encode("utf-8"))
            image_np = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), 1)
        log("from_isr_interface cost time " + str(time.time() - start_time))
        return image_np
    except Exception:
        log("from_isr_interface error!")
        traceback.print_exc()
        return [-11]
def from_idc_interface(image_stream, from_remote=FROM_REMOTE):
    """Run the IDC model on an image and return the "angle" it reports.

    In remote mode the base64 image is POSTed to an "idc" interface, with up
    to three attempts while the response is a list. In local mode the
    module-global IDC model is created on first use and ``idc`` is called
    directly. The meaning of "angle" (presumably an orientation) is defined
    by the model; confirm there.

    Returns:
        The "angle" value, or an error-code list: ``[-5]`` timeout, ``[-2]``
        connection error, ``[-11]`` anything else, or a code from the pool /
        service.
    """
    log("into from_idc_interface")
    start_time = time.time()
    try:
        base64_stream = base64.b64encode(image_stream)
        try:
            if not from_remote:
                if globals().get("global_idc_model") is None:
                    print("=========== init idc model ===========")
                    globals().update({"global_idc_model": IdcModels().get_model()})
                response = idc(data=base64_stream,
                               model=globals().get("global_idc_model"))
            else:
                # retries_left counts the attempts still available after this one.
                for retries_left in (2, 1, 0):
                    ip_port = interface_pool_gunicorn("idc")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": base64_stream, "md5": _global.get("md5")}
                    response = json.loads(request_post(ip_port + "/idc", payload, time_out=60))
                    log("get interface return")
                    if isinstance(response, list):
                        # Interface unreachable: try another port if any attempt remains.
                        if not retries_left:
                            return response
                        log("retry post idc_interface... left times " + str(retries_left))
                        continue
                    if judge_error_code(response):
                        return response
                    break
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            return [-2]

        angle = response.get("angle")
        log("from_idc_interface cost time " + str(time.time() - start_time))
        return angle
    except Exception:
        log("from_idc_interface error!")
        traceback.print_exc()
        return [-11]
def from_atc_interface(text, from_remote=FROM_REMOTE):
    """Run the ATC model on text and return its "classification".

    In remote mode the text is POSTed as-is (not base64-encoded) to an "atc"
    interface, with up to three attempts while the response is a list. In
    local mode the module-global ATC model is created on first use and
    ``atc`` is called directly.

    Returns:
        The "classification" value, or an error-code list: ``[-5]`` timeout,
        ``[-2]`` connection error, ``[-11]`` anything else, or a code from the
        pool / service.
    """
    log("into from_atc_interface")
    start_time = time.time()
    try:
        try:
            if not from_remote:
                if globals().get("global_atc_model") is None:
                    print("=========== init atc model ===========")
                    globals().update({"global_atc_model": AtcModels().get_model()})
                response = atc(data=text,
                               model=globals().get("global_atc_model"))
            else:
                # retries_left counts the attempts still available after this one.
                for retries_left in (2, 1, 0):
                    ip_port = interface_pool_gunicorn("atc")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": text, "md5": _global.get("md5")}
                    response = json.loads(request_post(ip_port + "/atc", payload, time_out=60))
                    log("get interface return")
                    if isinstance(response, list):
                        # Interface unreachable: try another port if any attempt remains.
                        if not retries_left:
                            return response
                        log("retry post atc_interface... left times " + str(retries_left))
                        continue
                    if judge_error_code(response):
                        return response
                    break
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            return [-2]

        classification = response.get("classification")
        log("from_atc_interface cost time " + str(time.time() - start_time))
        return classification
    except Exception:
        log("from_atc_interface error!")
        traceback.print_exc()
        return [-11]
def from_yolo_interface(image_stream, from_remote=FROM_REMOTE):
    """Run the YOLO predictor on an image and return its "b_table_list".

    In remote mode the base64 image is POSTed to a "yolo" interface, with up
    to three attempts while the response is a list. In local mode a
    module-global ``Predictor`` (640 px input, weights at
    ``botr/yolov8/weights.pt``, CUDA when available) is built on first use
    and ``yolo`` is called directly.

    Returns:
        The "b_table_list" value, or an error-code list: ``[-5]`` timeout,
        ``[-2]`` connection error, ``[-11]`` anything else, or a code from the
        pool / service.
    """
    log("into from_yolo_interface")
    start_time = time.time()
    try:
        base64_stream = base64.b64encode(image_stream)
        try:
            if not from_remote:
                if globals().get("global_yolo_predictor") is None:
                    print("=========== init yolo model ===========")
                    project_root = os.path.abspath(os.path.dirname(__file__)) + '/../'
                    weights_path = project_root + 'botr/yolov8/weights.pt'
                    use_cuda = torch.cuda.is_available()
                    device = torch.device("cuda" if use_cuda else "cpu")
                    globals().update({"global_yolo_predictor": Predictor(640, device, weights_path)})
                response = yolo(data=base64_stream,
                                predictor=globals().get("global_yolo_predictor"))
            else:
                # retries_left counts the attempts still available after this one.
                for retries_left in (2, 1, 0):
                    ip_port = interface_pool_gunicorn("yolo")
                    if judge_error_code(ip_port):
                        return ip_port
                    target = ip_port + "/yolo"
                    log('yolo _url ' + target)
                    payload = {"data": base64_stream, "md5": _global.get("md5")}
                    response = json.loads(request_post(target, payload, time_out=60))
                    log("get interface return")
                    if isinstance(response, list):
                        # Interface unreachable: try another port if any attempt remains.
                        if not retries_left:
                            return response
                        log("retry post yolo_interface... left times " + str(retries_left))
                        continue
                    if judge_error_code(response):
                        return response
                    break
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            return [-2]

        b_table_list = response.get("b_table_list")
        log("from_yolo_interface cost time " + str(time.time() - start_time))
        return b_table_list
    except Exception:
        log("from_yolo_interface error!")
        traceback.print_exc()
        return [-11]
def interface_pool_gunicorn(interface_type):
    """Choose the "ip:port" address to call for ``interface_type``.

    Candidate IPs are the keys of ``_global["ip_port_flag"]`` that configure
    this interface according to ``get_args_from_config``. The chosen IP is the
    one with the lowest request counter divided by its total number of port
    instances. Its ports are weighted by their per-port instance counts.
    Office ports are then offset by their position and picked round-robin from
    the chosen IP's counter; other interfaces get a random port. Finally the
    chosen IP's counter is incremented (wrapping to 0 at 10000) and written
    back to ``_global``.

    Args:
        interface_type: Interface key such as "ocr", "otr", "office", ...

    Returns:
        "ip:port" on success. ``[-2]`` when the ip/port tables are missing from
        ``_global``. ``[-15]`` when no office interface is available. ``[-1]``
        when no IP offers another interface type, or on any unexpected error.
    """
    if get_platform() == 'Windows':
        set_flask_global()
    ip_port_flag_dict = _global.get("ip_port_flag")
    ip_port_dict = _global.get("ip_port")
    try:
        if ip_port_dict is None or ip_port_flag_dict is None:
            print('_global', _global.get_dict())
            raise NotFound

        # Load balancing: pick the IP with the fewest requests per port instance.
        min_cnt = 10000.
        selected_cnt = 0
        _ip = None
        port_list = []
        for key in ip_port_flag_dict.keys():
            temp_port_list = get_args_from_config(ip_port_dict, key, interface_type)
            if not temp_port_list:
                continue
            # Ports of this interface on this IP and the number of instances
            # behind each port (there may be several GPU-backed instances).
            _port_list, _port_num_list, _ = temp_port_list[0]
            total_port_num = sum(_port_num_list)
            if total_port_num == 0:
                continue
            interface_cnt = ip_port_flag_dict.get(key).get(interface_type)
            if interface_cnt is None:
                continue
            load = interface_cnt / total_port_num
            if load < min_cnt:
                _ip = key
                # Store the same normalised load that is compared above. The
                # old code divided by len(temp_port_list[0]), which is the
                # 3-tuple length, so later IPs were compared against a
                # mis-scaled minimum.
                min_cnt = load
                # Keep the chosen IP's own counter for the office round-robin,
                # not whichever IP happened to be iterated last.
                selected_cnt = interface_cnt
                # Weight each port by its instance count.
                port_list = []
                for port, num in zip(_port_list, _port_num_list):
                    port_list += [port] * num

        if not port_list:
            if interface_type == "office":
                raise ConnectionError
            log('no ' + str(interface_type) + ' interface available!')
            return [-1]

        if interface_type == "office":
            # Entry k of the weighted list maps to port port_list[k] + k.
            port_list = [str(port_list[k] + k) for k in range(len(port_list))]
            if min_cnt == 0:
                # Counter still at zero: start on a random port and seed the
                # counter from that port's last two digits.
                _port = port_list[random.randint(0, len(port_list) - 1)]
                ip_port_flag_dict[_ip][interface_type] = int(_port[-2:])
            else:
                # Afterwards rotate through the ports using the counter.
                _port = port_list[selected_cnt % len(port_list)]
        else:
            # Behind gunicorn, pick a port at random.
            _port = random.choice(port_list)

        # Bump the chosen IP's counter, wrapping around at 10000.
        if ip_port_flag_dict[_ip][interface_type] >= 10000:
            ip_port_flag_dict[_ip][interface_type] = 0
        else:
            ip_port_flag_dict[_ip][interface_type] += 1
        _global.update({"ip_port_flag": ip_port_flag_dict})
        ip_port = _ip + ":" + str(_port)
        log(interface_type)
        log(ip_port)
        return ip_port
    except NotFound:
        log("ip_port or ip_port_dict is None! checkout config")
        return [-2]
    except ConnectionError:
        log('no office interface running!')
        return [-15]
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return [-1]
if __name__ == "__main__":
    # Manual smoke test: call from_yolo_interface remotely on one image and
    # draw the returned boxes. An optional argv[1] overrides the image path.
    _global._init()
    set_flask_global()
    image_path = sys.argv[1] if len(sys.argv) > 1 \
        else r"C:/Users/Administrator/Desktop/test_b_table/error11.png"
    _img = cv2.imread(image_path)
    if _img is None:
        # cv2.imread returns None instead of raising when the file is unreadable.
        sys.exit("cannot read image: " + image_path)
    _img_bytes = np2bytes(_img)
    b_list = from_yolo_interface(_img_bytes, from_remote=True)
    # Error codes such as [-2] are lists of ints; iterating them as box groups
    # would raise TypeError, so stop with a readable message instead.
    if b_list is None or judge_error_code(b_list):
        sys.exit("from_yolo_interface failed: " + str(b_list))
    for box_group in b_list:
        for b in box_group:
            cv2.rectangle(_img, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (0, 0, 255), 2)
    cv2.namedWindow('img', cv2.WINDOW_NORMAL)
    cv2.imshow('img', _img)
    cv2.waitKey(0)
|