12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034 |
- # encoding=utf8
- import base64
- import inspect
- import json
- import logging
- import multiprocessing
- import os
- import pickle
- import random
- import sys
- import time
- import uuid
- import cv2
- import redis
- from werkzeug.exceptions import NotFound
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from atc.atc_interface import AtcModels, atc
- from idc.idc_interface import IdcModels, idc
- from isr.isr_interface import IsrModels, isr
- import traceback
- import requests
- from format_convert import _global
- from format_convert.utils import get_platform, get_sequential_data, judge_error_code, request_post, get_ip_port, \
- get_intranet_ip, get_logger, log, get_args_from_config, get_using_ip
- from ocr.ocr_interface import ocr, OcrModels
- from otr.otr_interface import otr, OtrModels
- from format_convert.libreoffice_interface import office_convert
- import numpy as np
- from format_convert.max_compute_config import max_compute
MAX_COMPUTE = max_compute

# FROM_REMOTE is the default for the ``from_remote`` argument of the
# from_*_interface functions below: True posts to the HTTP services, False
# loads the models in-process. On Windows the in-process path is the default.
if get_platform() != "Windows":
    FROM_REMOTE = True
else:
    FROM_REMOTE = False
    # Manual switch: when enabled, register fixed local ocr/otr ports in the
    # shared _global configuration.
    only_test_ocr = False
    if only_test_ocr:
        ip_port_dict = get_ip_port()
        ip_port_flag = {}
        for _k in ip_port_dict.keys():
            # One usage counter per interface type for every configured ip.
            ip_port_flag[_k] = {"ocr": 0, "otr": 0, "convert": 0, "office": 0}
        _global.update({"ip_port_flag": ip_port_flag})
        ip_port_dict["http://127.0.0.1"]["ocr"] = ["17000"]
        ip_port_dict["http://127.0.0.1"]["otr"] = ["18000"]
        _global.update({"ip_port": ip_port_dict})

# MAX_COMPUTE always forces the in-process path, whatever the platform.
if MAX_COMPUTE:
    FROM_REMOTE = False

# ip_port_dict = get_ip_port()
# ip = 'http://127.0.0.1'
# ocr_port_list = ip_port_dict.get(ip).get("ocr")
# otr_port_list = ip_port_dict.get(ip).get("otr")

lock = multiprocessing.RLock()

# Connect to the redis database
# redis_db = redis.StrictRedis(host='192.168.2.103', port='6379',
#                              db=1, password='bidi123456', health_check_interval=300)
redis_db = None
def _interface(_dict, time_out=60, retry_times=3, use_zlib=False):
    """POST ``_dict`` to the model service selected by its ``model_type`` key.

    The target address comes from ``interface_pool(model_type)`` and the
    request goes to ``<address>/<model_type>`` with the dict JSON-encoded
    under the "data" field.

    :param _dict: request payload; must contain "model_type".
    :param time_out: timeout forwarded to ``request_post``.
    :param retry_times: number of attempts allowed while the service answers
        with a JSON list (the retry path treats a list reply as a failure).
    :param use_zlib: when true, ``use_zlib=True`` is forwarded to
        ``request_post``; when false the call is identical to the historical
        one. NOTE(review): assumes ``request_post`` accepts a ``use_zlib``
        keyword - confirm against format_convert.utils.
    :return: the decoded JSON reply, the error code returned by the pool,
        ``[-5]`` on ``TimeoutError`` or ``[-2]`` on a requests
        ``ConnectionError``.
    """
    # Only forward the flag when requested so the default call is unchanged.
    extra_kwargs = {"use_zlib": True} if use_zlib else {}
    try:
        model_type = _dict.get("model_type")
        while retry_times:
            ip_port = interface_pool(model_type)
            if judge_error_code(ip_port):
                return ip_port
            _url = ip_port + "/" + model_type
            r = json.loads(request_post(_url,
                                        {"data": json.dumps(_dict),
                                         "model_type": model_type},
                                        time_out=time_out, **extra_kwargs))
            log("get _interface return")
            # A list reply means the call failed: ask the pool again while
            # attempts remain, otherwise hand the reply back as it is.
            if isinstance(r, list) and retry_times > 1:
                retry_times -= 1
                log("retry post _interface... left times " + str(retry_times) + " " + model_type)
                continue
            return r
    except TimeoutError:
        return [-5]
    except requests.exceptions.ConnectionError:
        return [-2]
def from_office_interface(src_path, dest_path, target_format, retry_times=1, from_remote=FROM_REMOTE):
    """Convert the office document at ``src_path`` to ``target_format``.

    With ``from_remote`` true the file is read, base64-encoded and posted to
    the "/soffice" endpoint of an address chosen by
    ``interface_pool_gunicorn("office")``; the converted bytes from the reply
    are written to ``dest_path + <source base name> + "." + target_format``.
    Otherwise the work is delegated to ``office_convert`` in-process.

    :param src_path: path of the document to convert.
    :param dest_path: output prefix; the file name is appended to it directly.
    :param target_format: extension of the converted file.
    :param retry_times: forwarded to the converter (local or remote).
    :param from_remote: use the HTTP service instead of the local converter.
    :return: the output path (or whatever ``office_convert`` returns), or an
        error-code list: the code reported by the pool/service, ``[-5]`` on
        ``TimeoutError``, ``[-1]`` on any other exception.
    """
    try:
        if not from_remote:
            return office_convert(src_path, dest_path, target_format, retry_times)

        # retry_times_1 bounds retries after a list (error) reply,
        # retry_times_2 bounds retries after an error code in the "data" field.
        retry_times_1 = 1
        retry_times_2 = 2
        while retry_times_1 and retry_times_2:
            ip_port = interface_pool_gunicorn("office")
            if judge_error_code(ip_port):
                return ip_port
            _url = ip_port + "/soffice"
            with open(src_path, "rb") as f:
                file_bytes = f.read()
            base64_stream = base64.b64encode(file_bytes)
            start_time = time.time()
            r = json.loads(request_post(_url, {"src_path": src_path,
                                               "dest_path": dest_path,
                                               "file": base64_stream,
                                               "target_format": target_format,
                                               "retry_times": retry_times}, time_out=25))
            log("get interface return")
            log("office use time " + str(time.time()-start_time))
            if isinstance(r, list):
                # The interface could not be reached: retry on another port.
                if retry_times_1 <= 1:
                    return r
                retry_times_1 -= 1
                log("retry post office_interface... left times " + str(retry_times_1))
                continue

            file_str = r.get("data")
            if judge_error_code(file_str):
                if retry_times_2 <= 1:
                    return file_str
                retry_times_2 -= 1
                continue

            # SECURITY NOTE(review): eval() runs text received from the
            # service. Kept because the wire format is not visible here; if
            # the payload is a bytes literal, ast.literal_eval would be safer.
            file_bytes = eval(file_str)
            uid1 = src_path.split(os.sep)[-1].split(".")[0]
            file_path = dest_path + uid1 + "." + target_format
            # exist_ok avoids a race between the existence check and creation.
            os.makedirs(os.path.dirname(file_path), mode=0o777, exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_bytes)
            return file_path
    except TimeoutError:
        log("from_office_interface timeout error!")
        return [-5]
    except Exception:
        log("from_office_interface error!")
        # print_exc() writes the traceback itself and returns None, so it
        # must not be wrapped in print().
        traceback.print_exc()
        return [-1]
def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
    """Run OCR on raw image bytes.

    With ``from_remote`` true the base64-encoded image is posted to the
    "/ocr" endpoint of an address chosen by ``interface_pool_gunicorn("ocr")``
    (up to 3 attempts while the reply is a list); otherwise a lazily created
    module-global OCR model is used in-process.

    :param image_stream: image bytes.
    :param is_table: when true return ``(text_list, bbox_list)``; when false
        return the text assembled by ``get_sequential_data(..., html=True)``
        (``""`` when nothing was recognised).
    :param from_remote: use the HTTP service instead of the local model.
    :return: see ``is_table``. On failure an error-code list is returned
        (``[-5]`` timeout, ``[-2]`` connection error, ``[-1]`` other, or the
        code reported by the pool/service); with ``is_table`` true every
        error is returned as a pair so callers can always unpack two values.
    """
    log("into from_ocr_interface")

    def _error(code):
        # Table callers get the same code in both positions of the pair.
        return (code, code) if is_table else code

    try:
        base64_stream = base64.b64encode(image_stream)
        # Call the interface.
        try:
            if from_remote:
                retry_times_1 = 3
                while retry_times_1:
                    ip_port = interface_pool_gunicorn("ocr")
                    if judge_error_code(ip_port):
                        return _error(ip_port)
                    _url = ip_port + "/ocr"
                    r = json.loads(request_post(_url, {"data": base64_stream,
                                                       "md5": _global.get("md5")},
                                                time_out=60))
                    log("get interface return")
                    if isinstance(r, list):
                        # The interface could not be reached: retry on another port.
                        if retry_times_1 <= 1:
                            return _error(r)
                        retry_times_1 -= 1
                        log("retry post ocr_interface... left times " + str(retry_times_1))
                        continue
                    if judge_error_code(r):
                        return _error(r)
                    break
            else:
                if globals().get("global_ocr_model") is None:
                    print("=========== init ocr model ===========")
                    globals().update({"global_ocr_model": OcrModels().get_model()})
                r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"))
        except TimeoutError:
            return _error([-5])
        except requests.exceptions.ConnectionError:
            return _error([-2])

        # SECURITY NOTE(review): eval() runs text produced by the OCR
        # service; kept because the wire format is not visible here.
        text_list = eval(r.get("text"))
        bbox_list = eval(r.get("bbox"))
        if text_list is None:
            text_list = []
        if bbox_list is None:
            bbox_list = []

        if is_table:
            return text_list, bbox_list
        if text_list and bbox_list:
            # May itself be an error code; it is returned unchanged either way.
            return get_sequential_data(text_list, bbox_list, html=True)
        return ""
    except Exception:
        log("from_ocr_interface error!")
        # Keep the cause visible instead of swallowing it silently.
        traceback.print_exc()
        return _error([-1])
def from_gpu_interface_flask(_dict, model_type, predictor_type):
    """Send ``_dict`` to the GPU model service through ``_interface``.

    ``_dict`` is updated in place with "predictor_type" and "model_type"
    before the call. Returns the reply of ``_interface`` or ``[-2]`` when any
    exception is raised.

    NOTE(review): ``_interface`` is called with a ``use_zlib`` keyword; if
    ``_interface`` does not accept it, the resulting TypeError is caught
    below and every call returns ``[-2]``.
    """
    log("into from_gpu_interface")
    start_time = time.time()
    try:
        # Call the interface; only the "ocr" model type asks for zlib.
        _dict.update(predictor_type=predictor_type, model_type=model_type)
        use_zlib = model_type == "ocr"
        result = _interface(_dict, time_out=30, retry_times=2, use_zlib=use_zlib)
        elapsed = time.time()-start_time
        log("from_gpu_interface finish size " + str(sys.getsizeof(_dict)) + " time " + str(elapsed))
        return result
    except Exception as e:
        log("from_gpu_interface error!")
        log("from_gpu_interface failed " + str(time.time()-start_time))
        traceback.print_exc()
        return [-2]
def from_gpu_interface_redis(_dict, model_type, predictor_type):
    """Send ``_dict`` to a GPU worker through redis and wait for its reply.

    The dict (extended in place with "predictor_type", "model_type" and a
    fresh "uuid") is pickled and pushed onto the list
    ``producer_<model_type>``. The hash ``consumer_<model_type>`` is then
    polled every 0.2 s, for at most 300 s, for a field named after the uuid.
    Returns the unpickled reply, or ``[-2]`` on timeout or any other error.

    NOTE(review): ``redis_db`` is initialised to None at module level in this
    file; unless it is assigned elsewhere the first redis call raises and
    this function returns ``[-2]``.
    """
    log("into from_gpu_interface")
    start_time = time.time()
    try:
        # Call the interface.
        _uuid = uuid.uuid1().hex
        _dict.update(predictor_type=predictor_type, model_type=model_type, uuid=_uuid)
        producer_key = "producer_" + model_type
        consumer_key = "consumer_" + model_type

        push_start = time.time()
        log("pickle.dumps(_dict)" + str(_dict))
        redis_db.rpush(producer_key, pickle.dumps(_dict))
        log(producer_key + " len " + str(redis_db.llen(producer_key)))
        log("to " + producer_key + " time " + str(time.time()-push_start))

        wait_start = time.time()
        time_out = 300
        while True:
            time.sleep(0.2)
            if time.time() - wait_start > time_out:
                raise Exception
            if not redis_db.hexists(consumer_key, _uuid):
                continue
            fetch_start = time.time()
            result = redis_db.hget(consumer_key, _uuid)
            log("from " + consumer_key + " time " + str(time.time()-fetch_start))
            break

        # SECURITY NOTE(review): pickle.loads trusts whatever is stored in redis.
        result = pickle.loads(result)
        total = time.time()-start_time
        log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(total))
        return result
    except Exception as e:
        log("from_gpu_interface error!")
        log("from_gpu_interface failed " + str(time.time()-start_time))
        traceback.print_exc()
        return [-2]
- # def from_gpu_flask_sm(_dict, model_type, predictor_type):
- # log("into from_gpu_share_memory")
- # start_time = time.time()
- # shm = None
- # try:
- # # 放入共享内存
- # _time = time.time()
- # np_data = _dict.get("inputs")
- # shm = to_share_memory(np_data)
- # log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
- #
- # # 调用接口
- # _time = time.time()
- # _dict.pop("inputs")
- # _dict.update({"predictor_type": predictor_type, "model_type": model_type,
- # "sm_name": shm.name, "sm_shape": np_data.shape,
- # "sm_dtype": str(np_data.dtype)})
- # result = _interface(_dict, time_out=30, retry_times=2)
- # log("_interface cost " + str(time.time()-_time))
- #
- # # 读取共享内存
- # _time = time.time()
- # sm_name = result.get("sm_name")
- # sm_shape = result.get("sm_shape")
- # sm_dtype = result.get("sm_dtype")
- # sm_dtype = get_np_type(sm_dtype)
- # if sm_name:
- # outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
- # else:
- # log("from_share_memory failed!")
- # raise Exception
- # log("data from share memory " + sm_name + " " + str(time.time()-_time))
- #
- # log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
- # return {"preds": outputs, "gpu_time": result.get("gpu_time")}
- # except Exception as e:
- # log("from_gpu_interface failed " + str(time.time()-start_time))
- # traceback.print_exc()
- # return [-2]
- # finally:
- # # del b # Unnecessary; merely emphasizing the array is no longer used
- # if shm:
- # try:
- # shm.close()
- # shm.unlink()
- # except FileNotFoundError:
- # log("share memory " + shm.name + " not exists!")
- # except Exception:
- # traceback.print_exc()
- #
- #
- # def from_gpu_share_memory(_dict, model_type, predictor_type):
- # log("into from_gpu_share_memory")
- # start_time = time.time()
- # try:
- # _dict.update({"model_type": model_type, "predictor_type": predictor_type})
- # outputs, gpu_time = share_memory_pool(_dict)
- # log("from_gpu_share_memory finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
- # return {"preds": outputs, "gpu_time": float(gpu_time)}
- # except Exception as e:
- # log("from_gpu_interface failed " + str(time.time()-start_time))
- # traceback.print_exc()
- # return [-2]
def from_otr_interface2(image_stream):
    """Run the in-process OTR model on raw image bytes.

    The model is created lazily and cached as the module global
    ``global_otr_model``.

    :param image_stream: image bytes.
    :return: a 5-tuple ``(points, split_lines, bboxes, outline_points,
        lines)`` taken from the same-named fields of the model reply, with a
        field that evaluates to None replaced by ``[]``. On failure all five
        positions hold the same error-code list: ``[-5]`` timeout, ``[-2]``
        connection error, ``[-1]`` any other exception.
    """
    log("into from_otr_interface")
    try:
        base64_stream = base64.b64encode(image_stream)
        # Call the interface.
        try:
            if globals().get("global_otr_model") is None:
                globals().update({"global_otr_model": OtrModels().get_model()})
                print("=========== init otr model ===========")
            r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"))
        except TimeoutError:
            return [-5], [-5], [-5], [-5], [-5]
        except requests.exceptions.ConnectionError:
            log("from_otr_interface")
            # print_exc() prints the traceback itself and returns None.
            traceback.print_exc()
            return [-2], [-2], [-2], [-2], [-2]

        # Process the result: every field arrives as text and is evaluated.
        # SECURITY NOTE(review): eval() trusts the text produced by the model.
        fields = ("points", "split_lines", "bboxes", "outline_points", "lines")
        values = [eval(r.get(name)) for name in fields]
        points, split_lines, bboxes, outline_points, lines = (
            [] if value is None else value for value in values
        )
        return points, split_lines, bboxes, outline_points, lines
    except Exception:
        log("from_otr_interface error!")
        traceback.print_exc()
        return [-1], [-1], [-1], [-1], [-1]
def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE):
    """Run OTR on raw image bytes and return the "list_line" field.

    With ``from_remote`` true the base64-encoded image is posted to the
    "/otr" endpoint of an address chosen by ``interface_pool_gunicorn("otr")``
    (up to 3 attempts while the reply is a list); otherwise a lazily created
    module-global OTR model is used in-process.

    :param image_stream: image bytes.
    :param is_from_pdf: forwarded unchanged to the service / model.
    :param from_remote: use the HTTP service instead of the local model.
    :return: the evaluated "list_line" field of the reply, or an error-code
        list: the code reported by the pool/service, ``[-5]`` timeout,
        ``[-2]`` connection error, ``[-1]`` any other exception.
    """
    log("into from_otr_interface")
    try:
        base64_stream = base64.b64encode(image_stream)
        # Call the interface.
        try:
            if from_remote:
                log("from remote")
                retry_times_1 = 3
                while retry_times_1:
                    ip_port = interface_pool_gunicorn("otr")
                    if judge_error_code(ip_port):
                        return ip_port
                    _url = ip_port + "/otr"
                    r = json.loads(request_post(_url, {"data": base64_stream,
                                                       "is_from_pdf": is_from_pdf,
                                                       "md5": _global.get("md5")}, time_out=60))
                    log("get interface return")
                    if isinstance(r, list):
                        # The interface could not be reached: retry on another port.
                        if retry_times_1 <= 1:
                            return r
                        retry_times_1 -= 1
                        log("retry post otr_interface... left times " + str(retry_times_1))
                        continue
                    if judge_error_code(r):
                        return r
                    break
            else:
                log("from local")
                log("otr_model " + str(globals().get("global_otr_model")))
                if globals().get("global_otr_model") is None:
                    print("=========== init otr model ===========")
                    globals().update({"global_otr_model": OtrModels().get_model()})
                    log("init finish")
                r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"), is_from_pdf=is_from_pdf)
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError:
            log("from_otr_interface")
            # print_exc() prints the traceback itself and returns None.
            traceback.print_exc()
            return [-2]

        # Process the result.
        # SECURITY NOTE(review): eval() trusts the text produced by the service.
        return eval(r.get("list_line"))
    except Exception:
        log("from_otr_interface error!")
        traceback.print_exc()
        return [-1]
def from_isr_interface(image_stream, from_remote=FROM_REMOTE):
    """Run the ISR model on raw image bytes and return the resulting image.

    Remote mode posts the base64-encoded image to the "/isr" endpoint of an
    address chosen by ``interface_pool_gunicorn("isr")`` (up to 3 attempts
    while the reply is a list) and decodes the base64 "image" field of the
    reply with OpenCV. Local mode calls ``isr`` with two lazily created
    module-global models and returns the reply's "image" field unchanged.

    Returns the image, ``[1]`` when the service reports that no seal was
    detected, or an error-code list: the code reported by the pool/service,
    ``[-5]`` timeout, ``[-2]`` connection error, ``[-11]`` any other exception.
    """
    log("into from_isr_interface")
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    start_time = time.time()
    try:
        base64_stream = base64.b64encode(image_stream)
        # Call the interface.
        try:
            if not from_remote:
                if globals().get("global_isr_model") is None:
                    print("=========== init isr model ===========")
                    isr_yolo_model, isr_model = IsrModels().get_model()
                    globals().update({"global_isr_yolo_model": isr_yolo_model})
                    globals().update({"global_isr_model": isr_model})
                r = isr(data=base64_stream,
                        isr_yolo_model=globals().get("global_isr_yolo_model"),
                        isr_model=globals().get("global_isr_model"))
            else:
                attempts_left = 3
                while attempts_left:
                    ip_port = interface_pool_gunicorn("isr")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": base64_stream, "md5": _global.get("md5")}
                    r = json.loads(request_post(ip_port + "/isr", payload, time_out=60))
                    log("get interface return")
                    if not isinstance(r, list):
                        if judge_error_code(r):
                            return r
                        break
                    # The interface could not be reached: retry on another port.
                    if attempts_left <= 1:
                        return r
                    attempts_left -= 1
                    log("retry post isr_interface... left times " + str(attempts_left))
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError as e:
            return [-2]

        if not from_remote:
            image_np = r.get("image")
        else:
            image_string = r.get("image")
            if judge_error_code(image_string):
                return image_string
            # [1] means no seal was detected: return it directly.
            if isinstance(image_string, list) and image_string == [1]:
                return image_string
            raw = base64.b64decode(image_string.encode("utf-8"))
            image_np = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), 1)
        log("from_isr_interface cost time " + str(time.time()-start_time))
        return image_np
    except Exception as e:
        log("from_isr_interface error!")
        traceback.print_exc()
        return [-11]
def from_idc_interface(image_stream, from_remote=FROM_REMOTE):
    """Run the IDC model on raw image bytes and return the "angle" field.

    Remote mode posts the base64-encoded image to the "/idc" endpoint of an
    address chosen by ``interface_pool_gunicorn("idc")`` (up to 3 attempts
    while the reply is a list); local mode calls ``idc`` with a lazily
    created module-global model.

    Returns the reply's "angle" value, or an error-code list: the code
    reported by the pool/service, ``[-5]`` timeout, ``[-2]`` connection
    error, ``[-10]`` any other exception.
    """
    log("into from_idc_interface")
    start_time = time.time()
    try:
        base64_stream = base64.b64encode(image_stream)
        # Call the interface.
        try:
            if not from_remote:
                if globals().get("global_idc_model") is None:
                    print("=========== init idc model ===========")
                    globals().update({"global_idc_model": IdcModels().get_model()})
                r = idc(data=base64_stream,
                        model=globals().get("global_idc_model"))
            else:
                attempts_left = 3
                while attempts_left:
                    ip_port = interface_pool_gunicorn("idc")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": base64_stream, "md5": _global.get("md5")}
                    r = json.loads(request_post(ip_port + "/idc", payload, time_out=60))
                    log("get interface return")
                    if not isinstance(r, list):
                        if judge_error_code(r):
                            return r
                        break
                    # The interface could not be reached: retry on another port.
                    if attempts_left <= 1:
                        return r
                    attempts_left -= 1
                    log("retry post idc_interface... left times " + str(attempts_left))
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError as e:
            return [-2]

        angle = r.get("angle")
        log("from_idc_interface cost time " + str(time.time()-start_time))
        return angle
    except Exception as e:
        log("from_idc_interface error!")
        traceback.print_exc()
        return [-10]
def from_atc_interface(text, from_remote=FROM_REMOTE):
    """Run the ATC model on ``text`` and return the "classification" field.

    Remote mode posts the text to the "/atc" endpoint of an address chosen by
    ``interface_pool_gunicorn("atc")`` (up to 3 attempts while the reply is a
    list); local mode calls ``atc`` with a lazily created module-global model.

    Returns the reply's "classification" value, or an error-code list: the
    code reported by the pool/service, ``[-5]`` timeout, ``[-2]`` connection
    error, ``[-10]`` any other exception.
    """
    log("into from_atc_interface")
    start_time = time.time()
    try:
        # Call the interface.
        try:
            if not from_remote:
                if globals().get("global_atc_model") is None:
                    print("=========== init atc model ===========")
                    globals().update({"global_atc_model": AtcModels().get_model()})
                r = atc(data=text,
                        model=globals().get("global_atc_model"))
            else:
                attempts_left = 3
                while attempts_left:
                    ip_port = interface_pool_gunicorn("atc")
                    if judge_error_code(ip_port):
                        return ip_port
                    payload = {"data": text, "md5": _global.get("md5")}
                    r = json.loads(request_post(ip_port + "/atc", payload, time_out=60))
                    log("get interface return")
                    if not isinstance(r, list):
                        if judge_error_code(r):
                            return r
                        break
                    # The interface could not be reached: retry on another port.
                    if attempts_left <= 1:
                        return r
                    attempts_left -= 1
                    log("retry post atc_interface... left times " + str(attempts_left))
        except TimeoutError:
            return [-5]
        except requests.exceptions.ConnectionError as e:
            return [-2]

        classification = r.get("classification")
        log("from_atc_interface cost time " + str(time.time()-start_time))
        return classification
    except Exception as e:
        log("from_atc_interface error!")
        traceback.print_exc()
        return [-10]
- # def from_schedule_interface(interface_type):
- # try:
- # _ip = "http://" + get_intranet_ip()
- # _port = ip_port_dict.get(_ip).get("schedule")[0]
- # _url = _ip + ":" + _port + "/schedule"
- # data = {"interface_type": interface_type}
- # result = json.loads(request_post(_url, data, time_out=10)).get("data")
- # if judge_error_code(result):
- # return result
- # _ip, _port = result
- # log("from_schedule_interface " + _ip + " " + _port)
- # return _ip + ":" + _port
- # except requests.exceptions.ConnectionError as e:
- # log("from_schedule_interface ConnectionError")
- # return [-2]
- # except:
- # log("from_schedule_interface error!")
- # traceback.print_exc()
- # return [-1]
def interface_pool(interface_type, use_gunicorn=True):
    """Pick the "<ip>:<port>" address to call for ``interface_type``.

    Reads the "ip_port" (ip -> interface -> port list) and "ip_port_flag"
    (ip -> interface -> usage counter) entries of ``_global``.

    :param interface_type: interface name used as key in both mappings.
    :param use_gunicorn: when true, always return the first port configured
        for "http://127.0.0.1" without touching the counters. When false,
        choose the ip with the lowest counter-per-port ratio, rotate over its
        ports with the counter, and store the updated counter in ``_global``.
    :return: the address string, ``[-2]`` when no ip offers the interface,
        or ``[-1]`` on any other exception.
    """
    ip_port_flag = _global.get("ip_port_flag")
    ip_port_dict = _global.get("ip_port")
    try:
        if use_gunicorn:
            _ip = "http://127.0.0.1"
            _port = ip_port_dict.get(_ip).get(interface_type)[0]
            ip_port = _ip + ":" + str(_port)
            log(ip_port)
            return ip_port

        # Load balancing: choose the ip.
        interface_load_list = []
        for _ip in ip_port_flag:
            ports = ip_port_dict.get(_ip).get(interface_type)
            if ports:
                load_scale = ip_port_flag.get(_ip).get(interface_type) / len(ports)
                interface_load_list.append([_ip, load_scale])
        if not interface_load_list:
            raise NotFound
        # min() keeps the first entry on ties, like a stable sort would.
        _ip = min(interface_load_list, key=lambda x: x[-1])[0]

        # Load balancing: choose the port.
        port_list = ip_port_dict.get(_ip).get(interface_type)
        ip_type_cnt = ip_port_flag.get(_ip).get(interface_type)
        if ip_type_cnt == 0:
            # A zero counter starts the rotation at a random port.
            ip_type_cnt = random.randint(0, len(port_list)-1)
        _port = port_list[ip_type_cnt % len(port_list)]

        # Update the counter, wrapping at 10000.
        ip_port_flag[_ip][interface_type] = 0 if ip_type_cnt >= 10000 else ip_type_cnt + 1
        _global.update({"ip_port_flag": ip_port_flag})

        ip_port = _ip + ":" + str(_port)
        log(ip_port)
        return ip_port
    except NotFound:
        log("cannot read ip from config! checkout config")
        return [-2]
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return [-1]
def interface_pool_gunicorn(interface_type):
    """Pick the "<ip>:<port>" address to call for ``interface_type``.

    Reads the "ip_port" and "ip_port_flag" entries of ``_global``. For every
    ip in the flag mapping the port list is looked up through
    ``get_args_from_config`` (its first element is used as the port list) and
    the ip with the lowest counter-per-port ratio is selected. For "office"
    the port rotates with the counter (random on the first use); for every
    other interface the first port is used. The selected ip's counter is then
    advanced (wrapping at 10000) and written back to ``_global``.

    :param interface_type: interface name, e.g. "ocr", "otr", "office".
    :return: the address string, ``[-2]`` when the configuration is missing
        or no ip offers the interface, or ``[-1]`` on any other exception.
    """
    ip_port_flag_dict = _global.get("ip_port_flag")
    ip_port_dict = _global.get("ip_port")
    try:
        if ip_port_dict is None or ip_port_flag_dict is None:
            raise NotFound

        # Load balancing: choose an ip that offers this interface.
        # Start from +inf rather than a numeric cap: with a cap of 10000 an
        # ip with a single port could no longer be selected once its counter
        # reached 10000, and the counter was then never reset.
        min_cnt = float("inf")
        selected_cnt = 0
        _ip = None
        port_list = []
        for key in ip_port_flag_dict:
            temp_port_list = get_args_from_config(ip_port_dict, key, interface_type)
            if not temp_port_list:
                continue
            candidate_ports = temp_port_list[0]
            interface_cnt = ip_port_flag_dict.get(key).get(interface_type)
            if interface_cnt is None or not candidate_ports:
                continue
            load = interface_cnt / len(candidate_ports)
            if load < min_cnt:
                # Remember the counter of the ip actually selected; the loop
                # variable would otherwise hold the last ip's counter.
                _ip = key
                min_cnt = load
                selected_cnt = interface_cnt
                port_list = candidate_ports
        if _ip is None:
            log("no ip provides interface " + str(interface_type))
            raise NotFound

        # Choose the port.
        if interface_type == "office":
            # Random on the first use, modulo rotation afterwards.
            if min_cnt == 0:
                _port = port_list[random.randint(0, len(port_list)-1)]
                # Seed the counter from the last two characters of the port
                # (assumes ports are numeric strings - TODO confirm).
                ip_port_flag_dict[_ip][interface_type] = int(_port[-2:])
            else:
                _port = port_list[selected_cnt % len(port_list)]
        else:
            # Behind gunicorn the first port is always used.
            _port = port_list[0]

        # Update the counter, wrapping at 10000.
        if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
            ip_port_flag_dict[_ip][interface_type] = 0
        else:
            ip_port_flag_dict[_ip][interface_type] += 1
        _global.update({"ip_port_flag": ip_port_flag_dict})

        ip_port = _ip + ":" + str(_port)
        log(interface_type)
        log(ip_port)
        return ip_port
    except NotFound:
        log("ip_flag or ip_port_dict is None! checkout config")
        return [-2]
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return [-1]
def interface_pool_gunicorn_old(interface_type):
    """Older address picker, kept alongside ``interface_pool_gunicorn``.

    Reads "ip_flag" (a list of ``[ip_name, counter]`` entries), "ip_port_flag"
    and "ip_port" from ``_global``.

    * "office": uses the ip from ``get_using_ip()``, rotates over that ip's
      "MASTER" office ports with the per-ip counter (random on the first
      use) and advances the counter in "ip_port_flag".
    * anything else: takes the ip entry with the lowest counter (a random
      entry while all counters are 0). It uses the second port group from
      ``get_args_from_config`` when the entry name contains "master" and that
      group exists, else the first, and advances the counter in "ip_flag".

    :return: the "<ip>:<port>" string, ``[-2]`` when any configuration entry
        is missing, or ``[-1]`` on any other exception.
    """
    ip_flag_list = _global.get("ip_flag")
    ip_port_flag_dict = _global.get("ip_port_flag")
    ip_port_dict = _global.get("ip_port")
    try:
        if ip_flag_list is None or ip_port_dict is None or ip_port_flag_dict is None:
            raise NotFound

        if interface_type == "office":
            _ip = get_using_ip()
            # Choose the port.
            port_list = ip_port_dict.get(_ip).get("MASTER").get(interface_type)
            ip_type_cnt = ip_port_flag_dict.get(_ip).get(interface_type)
            if ip_type_cnt == 0:
                _port = port_list[random.randint(0, len(port_list)-1)]
            else:
                _port = port_list[ip_type_cnt % len(port_list)]
            # Update the counter, wrapping at 10000.
            if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
                ip_port_flag_dict[_ip][interface_type] = 0
            else:
                ip_port_flag_dict[_ip][interface_type] += 1
            _global.update({"ip_port_flag": ip_port_flag_dict})
        else:
            # Load balancing: choose the ip (lowest counter first).
            ip_flag_list.sort(key=lambda x: x[1])
            if ip_flag_list[-1][1] == 0:
                # Every counter is still 0: start at a random entry.
                ip_index = random.randint(0, len(ip_flag_list)-1)
            else:
                ip_index = 0
            _ip = ip_flag_list[ip_index][0]
            port_index = 1 if "master" in _ip else 0
            # Entry names may carry a "_suffix"; only the part before it is the ip.
            _ip = _ip.split("_")[0]

            # Choose the port; behind gunicorn the first port of the group is used.
            log("_ip " + _ip)
            log("interface_type " + interface_type)
            port_list = get_args_from_config(ip_port_dict, _ip, interface_type)
            log("port_list" + str(port_list))
            if port_index >= len(port_list):
                port_index = 0
            _port = port_list[port_index][0]

            # Update the counter, wrapping at 10000.
            if ip_flag_list[ip_index][1] >= 10000:
                ip_flag_list[ip_index][1] = 0
            else:
                ip_flag_list[ip_index][1] += 1
            _global.update({"ip_flag": ip_flag_list})

        ip_port = _ip + ":" + str(_port)
        log(ip_port)
        return ip_port
    except NotFound:
        log("ip_flag or ip_port_dict is None! checkout config")
        return [-2]
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return [-1]
- # def share_memory_pool(args_dict):
- # np_data = args_dict.get("inputs")
- # _type = args_dict.get("model_type")
- # args_dict.update({"sm_shape": np_data.shape, "sm_dtype": str(np_data.dtype)})
- #
- # if _type == 'ocr':
- # port_list = ocr_port_list
- # elif _type == 'otr':
- # port_list = otr_port_list
- # else:
- # log("type error! only support ocr otr")
- # raise Exception
- #
- # # 循环判断是否有空的share memory
- # empty_sm_list = None
- # sm_list_name = ""
- # while empty_sm_list is None:
- # for p in port_list:
- # sm_list_name = "sml_"+_type+"_"+str(p)
- # sm_list = get_share_memory_list(sm_list_name)
- # if sm_list[0] == "0":
- # lock.acquire(timeout=0.1)
- # if sm_list[0] == "0":
- # sm_list[0] = "1"
- # sm_list[-1] = "0"
- # empty_sm_list = sm_list
- # break
- # else:
- # continue
- # lock.release()
- #
- # log(str(os.getppid()) + " empty_sm_list " + sm_list_name)
- #
- # # numpy放入共享内存
- # _time = time.time()
- # release_share_memory(get_share_memory("psm_" + str(os.getpid())))
- # shm = to_share_memory(np_data)
- # log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
- #
- # # 参数放入共享内存列表
- # empty_sm_list[1] = args_dict.get("md5")
- # empty_sm_list[2] = args_dict.get("model_type")
- # empty_sm_list[3] = args_dict.get("predictor_type")
- # empty_sm_list[4] = args_dict.get("args")
- # empty_sm_list[5] = str(shm.name)
- # empty_sm_list[6] = str(args_dict.get("sm_shape"))
- # empty_sm_list[7] = args_dict.get("sm_dtype")
- # empty_sm_list[-1] = "1"
- # # log("empty_sm_list[7] " + empty_sm_list[7])
- # close_share_memory_list(empty_sm_list)
- #
- # # 循环判断是否完成
- # finish_sm_list = get_share_memory_list(sm_list_name)
- # while True:
- # if finish_sm_list[-1] == "0":
- # break
- #
- # # 读取共享内存
- # _time = time.time()
- # sm_name = finish_sm_list[5]
- # sm_shape = finish_sm_list[6]
- # sm_shape = eval(sm_shape)
- # sm_dtype = finish_sm_list[7]
- # gpu_time = finish_sm_list[8]
- # sm_dtype = get_np_type(sm_dtype)
- # outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
- # log(args_dict.get("model_type") + " " + args_dict.get("predictor_type") + " outputs " + str(outputs.shape))
- # log("data from share memory " + sm_name + " " + str(time.time()-_time))
- #
- # # 释放
- # release_share_memory(get_share_memory(sm_name))
- #
- # # 重置share memory list
- # finish_sm_list[-1] = "0"
- # finish_sm_list[0] = "0"
- #
- # close_share_memory_list(finish_sm_list)
- # return outputs, gpu_time
- # def interface_pool(interface_type):
- # try:
- # ip_port_dict = _global.get("ip_port")
- # ip_list = list(ip_port_dict.keys())
- # _ip = random.choice(ip_list)
- # if interface_type != 'office':
- # _port = ip_port_dict.get(_ip).get(interface_type)[0]
- # else:
- # _port = random.choice(ip_port_dict.get(_ip).get(interface_type))
- # log(_ip + ":" + _port)
- # return _ip + ":" + _port
- # except Exception as e:
- # traceback.print_exc()
- # return [-1]
- # def ip_pool(interface_type, _random=False):
- # ip_flag_name = interface_type + '_ip_flag'
- # ip_flag = globals().get(ip_flag_name)
- # if ip_flag is None:
- # if _random:
- # _r = random.randint(0, len(interface_ip_list)-1)
- # ip_flag = _r
- # globals().update({ip_flag_name: ip_flag})
- # ip_index = _r
- # else:
- # ip_flag = 0
- # globals().update({ip_flag_name: ip_flag})
- # ip_index = 0
- # else:
- # ip_index = ip_flag % len(interface_ip_list)
- # ip_flag += 1
- #
- # if ip_flag >= 10000:
- # ip_flag = 0
- # globals().update({ip_flag_name: ip_flag})
- #
- # log("ip_pool " + interface_type + " " + str(ip_flag) + " " + str(interface_ip_list[ip_index]))
- # return interface_ip_list[ip_index]
- #
- #
- # def port_pool(interface_type, _random=False):
- # port_flag_name = interface_type + '_port_flag'
- #
- # port_flag = globals().get(port_flag_name)
- # if port_flag is None:
- # if _random:
- # if interface_type == "ocr":
- # _r = random.randint(0, len(ocr_port_list)-1)
- # elif interface_type == "otr":
- # _r = random.randint(0, len(otr_port_list)-1)
- # else:
- # _r = random.randint(0, len(soffice_port_list)-1)
- # port_flag = _r
- # globals().update({port_flag_name: port_flag})
- # port_index = _r
- # else:
- # port_flag = 0
- # globals().update({port_flag_name: port_flag})
- # port_index = 0
- # else:
- # if interface_type == "ocr":
- # port_index = port_flag % len(ocr_port_list)
- # elif interface_type == "otr":
- # port_index = port_flag % len(otr_port_list)
- # else:
- # port_index = port_flag % len(soffice_port_list)
- # port_flag += 1
- #
- # if port_flag >= 10000:
- # port_flag = 0
- # globals().update({port_flag_name: port_flag})
- #
- # if interface_type == "ocr":
- # log("port_pool " + interface_type + " " + str(port_flag) + " " + ocr_port_list[port_index])
- # return ocr_port_list[port_index]
- # elif interface_type == "otr":
- # log("port_pool " + interface_type + " " + str(port_flag) + " " + otr_port_list[port_index])
- # return otr_port_list[port_index]
- # else:
- # log("port_pool " + interface_type + " " + str(port_flag) + " " + soffice_port_list[port_index])
- # return soffice_port_list[port_index]
if __name__ == "__main__":
    # Manual smoke test: initialise the shared globals, then ask the pool for
    # an "otr" address ten times and print each answer.
    from format_convert.utils import set_flask_global

    _global._init()
    set_flask_global()
    for i in range(10):
        picked = interface_pool_gunicorn("otr")
        print("result", picked)
|