convert_need_interface.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. # encoding=utf8
  2. import base64
  3. import json
  4. import multiprocessing
  5. import os
  6. import pickle
  7. import random
  8. import sys
  9. import time
  10. import uuid
  11. import cv2
  12. import torch
  13. from werkzeug.exceptions import NotFound
  14. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  15. from botr.yolov8.yolo_interface import yolo
  16. from botr.yolov8.model import Predictor
  17. from atc.atc_interface import AtcModels, atc
  18. from idc.idc_interface import IdcModels, idc
  19. from isr.isr_interface import IsrModels, isr
  20. import traceback
  21. import requests
  22. from format_convert import _global
  23. from format_convert.utils import get_platform, get_sequential_data, judge_error_code, request_post, get_ip_port, \
  24. get_intranet_ip, get_logger, log, get_args_from_config, get_using_ip, np2bytes, set_flask_global
  25. from ocr.ocr_interface import ocr, OcrModels
  26. from otr.otr_interface import otr, OtrModels
  27. from format_convert.libreoffice_interface import office_convert
  28. import numpy as np
  29. from format_convert.max_compute_config import max_compute
  30. MAX_COMPUTE = max_compute
  31. if get_platform() == "Windows":
  32. FROM_REMOTE = False
  33. only_test_ocr = False
  34. if only_test_ocr:
  35. ip_port_flag = {}
  36. ip_port_dict = get_ip_port()
  37. for _k in ip_port_dict.keys():
  38. ip_port_flag.update({_k: {"ocr": 0,
  39. "otr": 0,
  40. "convert": 0,
  41. "office": 0
  42. }})
  43. _global.update({"ip_port_flag": ip_port_flag})
  44. ip_port_dict["http://127.0.0.1"]["ocr"] = ["17000"]
  45. ip_port_dict["http://127.0.0.1"]["otr"] = ["18000"]
  46. _global.update({"ip_port": ip_port_dict})
  47. else:
  48. FROM_REMOTE = True
  49. if MAX_COMPUTE:
  50. FROM_REMOTE = False
  51. lock = multiprocessing.RLock()
  52. # 连接redis数据库
  53. # redis_db = redis.StrictRedis(host='192.168.2.103', port='6379',
  54. # db=1, password='bidi123456', health_check_interval=300)
  55. redis_db = None
  56. def from_office_interface(src_path, dest_path, target_format, retry_times=1, from_remote=FROM_REMOTE):
  57. try:
  58. # Win10跳出超时装饰器
  59. # if get_platform() == "Windows":
  60. # # origin_office_convert = office_convert.__wrapped__
  61. # # file_path = origin_office_convert(src_path, dest_path, target_format, retry_times)
  62. # file_path = office_convert(src_path, dest_path, target_format, retry_times)
  63. # else:
  64. # # 将装饰器包装为一个类,否则多进程Pickle会报错 it's not the same object as xxx 问题,
  65. # # timeout_decorator_obj = my_timeout_decorator.TimeoutClass(office_convert, 180, TimeoutError)
  66. # # file_path = timeout_decorator_obj.run(src_path, dest_path, target_format, retry_times)
  67. #
  68. # file_path = office_convert(src_path, dest_path, target_format, retry_times)
  69. if from_remote:
  70. # 重试
  71. retry_times_1 = 1
  72. retry_times_2 = 2
  73. while retry_times_1 and retry_times_2:
  74. # _ip = ip_pool("soffice", _random=True)
  75. # _port = port_pool("soffice", _random=True)
  76. # _ip = interface_ip_list[0]
  77. # _port = "16002"
  78. # _ip, _port = interface_pool("soffice")
  79. # ip_port = from_schedule_interface("office")
  80. ip_port = interface_pool_gunicorn("office")
  81. if judge_error_code(ip_port):
  82. return ip_port
  83. _url = ip_port + "/soffice"
  84. with open(src_path, "rb") as f:
  85. file_bytes = f.read()
  86. base64_stream = base64.b64encode(file_bytes)
  87. start_time = time.time()
  88. log('office _url ' + str(_url))
  89. r = json.loads(request_post(_url, {"src_path": src_path,
  90. "dest_path": dest_path,
  91. "file": base64_stream,
  92. "target_format": target_format,
  93. "retry_times": retry_times}, time_out=25))
  94. log("get interface return")
  95. log("office use time " + str(time.time()-start_time))
  96. if type(r) == list:
  97. # 接口连不上换个端口重试
  98. if retry_times_1 <= 1:
  99. return r
  100. else:
  101. retry_times_1 -= 1
  102. log("retry post office_interface... left times " + str(retry_times_1))
  103. continue
  104. file_str = r.get("data")
  105. if judge_error_code(file_str):
  106. if retry_times_2 <= 1:
  107. return file_str
  108. else:
  109. retry_times_2 -= 1
  110. continue
  111. file_bytes = eval(file_str)
  112. uid1 = src_path.split(os.sep)[-1].split(".")[0]
  113. file_path = dest_path + uid1 + "." + target_format
  114. if not os.path.exists(os.path.dirname(file_path)):
  115. os.makedirs(os.path.dirname(file_path), mode=0o777)
  116. with open(file_path, "wb") as f:
  117. f.write(file_bytes)
  118. break
  119. else:
  120. file_path = office_convert(src_path, dest_path, target_format, retry_times)
  121. if judge_error_code(file_path):
  122. return file_path
  123. return file_path
  124. except TimeoutError:
  125. log("from_office_interface timeout error!")
  126. return [-5]
  127. except:
  128. log("from_office_interface error!")
  129. print("from_office_interface", traceback.print_exc())
  130. return [-1]
  131. def from_ocr_interface(image_stream, is_table=0, only_rec=0, from_remote=FROM_REMOTE):
  132. log("into from_ocr_interface")
  133. try:
  134. base64_stream = base64.b64encode(image_stream)
  135. # 调用接口
  136. try:
  137. if from_remote:
  138. retry_times_1 = 3
  139. # 重试
  140. while retry_times_1:
  141. ip_port = interface_pool_gunicorn("ocr")
  142. if judge_error_code(ip_port):
  143. return ip_port
  144. _url = ip_port + "/ocr"
  145. r = json.loads(request_post(_url, {"data": base64_stream,
  146. "md5": _global.get("md5"),
  147. "only_rec": only_rec
  148. },
  149. time_out=60))
  150. log("get ocr interface return")
  151. if type(r) == list:
  152. # 接口连不上换个端口重试
  153. if retry_times_1 <= 1:
  154. if is_table:
  155. return r, r
  156. else:
  157. return r
  158. else:
  159. retry_times_1 -= 1
  160. log("retry post ocr_interface... left times " + str(retry_times_1))
  161. continue
  162. if judge_error_code(r):
  163. return r
  164. break
  165. else:
  166. if globals().get("global_ocr_model") is None:
  167. print("=========== init ocr model ===========")
  168. globals().update({"global_ocr_model": OcrModels().get_model()})
  169. r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"), only_rec=only_rec)
  170. except TimeoutError:
  171. if is_table:
  172. return [-5], [-5]
  173. else:
  174. return [-5]
  175. except requests.exceptions.ConnectionError as e:
  176. if is_table:
  177. return [-2], [-2]
  178. else:
  179. return [-2]
  180. _dict = r
  181. text_list = eval(_dict.get("text"))
  182. bbox_list = eval(_dict.get("bbox"))
  183. if text_list is None:
  184. text_list = []
  185. if bbox_list is None:
  186. bbox_list = []
  187. if is_table:
  188. return text_list, bbox_list
  189. else:
  190. if text_list and bbox_list:
  191. text = get_sequential_data(text_list, bbox_list, html=True)
  192. if judge_error_code(text):
  193. return text
  194. else:
  195. text = ""
  196. return text
  197. except Exception as e:
  198. log("from_ocr_interface error!")
  199. log(str(traceback.print_exc()))
  200. traceback.print_exc()
  201. # print("from_ocr_interface", e, global_type)
  202. if is_table:
  203. return [-1], [-1]
  204. else:
  205. return [-1]
  206. def from_gpu_interface_redis(_dict, model_type, predictor_type):
  207. log("into from_gpu_interface")
  208. start_time = time.time()
  209. try:
  210. # 调用接口
  211. _uuid = uuid.uuid1().hex
  212. _dict.update({"predictor_type": predictor_type, "model_type": model_type,
  213. "uuid": _uuid})
  214. _time = time.time()
  215. log("pickle.dumps(_dict)" + str(_dict))
  216. redis_db.rpush("producer_"+model_type, pickle.dumps(_dict))
  217. log("producer_" + model_type + " len " + str(redis_db.llen("producer_" + model_type)))
  218. log("to producer_" + model_type + " time " + str(time.time()-_time))
  219. _time = time.time()
  220. time_out = 300
  221. while True:
  222. time.sleep(0.2)
  223. if time.time() - _time > time_out:
  224. raise Exception
  225. if redis_db.hexists("consumer_"+model_type, _uuid):
  226. time1 = time.time()
  227. result = redis_db.hget("consumer_"+model_type, _uuid)
  228. log("from consumer_"+model_type + " time " + str(time.time()-time1))
  229. break
  230. result = pickle.loads(result)
  231. log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
  232. return result
  233. except Exception as e:
  234. log("from_gpu_interface error!")
  235. log("from_gpu_interface failed " + str(time.time()-start_time))
  236. traceback.print_exc()
  237. return [-2]
  238. def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE):
  239. log("into from_otr_interface")
  240. try:
  241. base64_stream = base64.b64encode(image_stream)
  242. # 调用接口
  243. try:
  244. if from_remote:
  245. log("from remote")
  246. retry_times_1 = 3
  247. # 重试
  248. while retry_times_1:
  249. # _ip = ip_pool("otr", _random=True)
  250. # _port = port_pool("otr", _random=True)
  251. # if _ip == interface_ip_list[1]:
  252. # _port = otr_port_list[0]
  253. ip_port = interface_pool_gunicorn("otr")
  254. # ip_port = from_schedule_interface("otr")
  255. if judge_error_code(ip_port):
  256. return ip_port
  257. _url = ip_port + "/otr"
  258. r = json.loads(request_post(_url, {"data": base64_stream,
  259. "is_from_pdf": is_from_pdf,
  260. "md5": _global.get("md5")}, time_out=60))
  261. log("get interface return")
  262. if type(r) == list:
  263. # 接口连不上换个端口重试
  264. if retry_times_1 <= 1:
  265. return r
  266. else:
  267. retry_times_1 -= 1
  268. log("retry post otr_interface... left times " + str(retry_times_1))
  269. continue
  270. if judge_error_code(r):
  271. return r
  272. break
  273. else:
  274. log("from local")
  275. log("otr_model " + str(globals().get("global_otr_model")))
  276. if globals().get("global_otr_model") is None:
  277. print("=========== init otr model ===========")
  278. globals().update({"global_otr_model": OtrModels().get_model()})
  279. log("init finish")
  280. r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"), is_from_pdf=is_from_pdf)
  281. # r = otr(data=base64_stream, otr_model=None, is_from_pdf=is_from_pdf)
  282. except TimeoutError:
  283. return [-5]
  284. except requests.exceptions.ConnectionError as e:
  285. log("from_otr_interface")
  286. print("from_otr_interface", traceback.print_exc())
  287. return [-2]
  288. # 处理结果
  289. _dict = r
  290. list_line = eval(_dict.get("list_line"))
  291. return list_line
  292. except Exception as e:
  293. log("from_otr_interface error!")
  294. print("from_otr_interface", traceback.print_exc())
  295. return [-1]
  296. def from_isr_interface(image_stream, from_remote=FROM_REMOTE):
  297. log("into from_isr_interface")
  298. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  299. start_time = time.time()
  300. try:
  301. base64_stream = base64.b64encode(image_stream)
  302. # 调用接口
  303. try:
  304. if from_remote:
  305. retry_times_1 = 3
  306. # 重试
  307. while retry_times_1:
  308. ip_port = interface_pool_gunicorn("isr")
  309. if judge_error_code(ip_port):
  310. return ip_port
  311. _url = ip_port + "/isr"
  312. r = json.loads(request_post(_url, {"data": base64_stream,
  313. "md5": _global.get("md5")},
  314. time_out=60))
  315. log("get interface return")
  316. if type(r) == list:
  317. # 接口连不上换个端口重试
  318. if retry_times_1 <= 1:
  319. return r
  320. else:
  321. retry_times_1 -= 1
  322. log("retry post isr_interface... left times " + str(retry_times_1))
  323. continue
  324. if judge_error_code(r):
  325. return r
  326. break
  327. else:
  328. if globals().get("global_isr_model") is None:
  329. print("=========== init isr model ===========")
  330. isr_yolo_model, isr_model = IsrModels().get_model()
  331. globals().update({"global_isr_yolo_model": isr_yolo_model})
  332. globals().update({"global_isr_model": isr_model})
  333. r = isr(data=base64_stream,
  334. isr_yolo_model=globals().get("global_isr_yolo_model"),
  335. isr_model=globals().get("global_isr_model"))
  336. except TimeoutError:
  337. return [-5]
  338. except requests.exceptions.ConnectionError as e:
  339. return [-2]
  340. _dict = r
  341. if from_remote:
  342. image_string = _dict.get("image")
  343. if judge_error_code(image_string):
  344. return image_string
  345. # [1]代表检测不到印章,直接返回
  346. if isinstance(image_string, list) and image_string == [1]:
  347. return image_string
  348. image_base64 = image_string.encode("utf-8")
  349. image_bytes = base64.b64decode(image_base64)
  350. buffer = np.frombuffer(image_bytes, dtype=np.uint8)
  351. image_np = cv2.imdecode(buffer, 1)
  352. else:
  353. image_np = _dict.get("image")
  354. log("from_isr_interface cost time " + str(time.time()-start_time))
  355. return image_np
  356. except Exception as e:
  357. log("from_isr_interface error!")
  358. traceback.print_exc()
  359. return [-11]
  360. finally:
  361. # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  362. pass
  363. def from_idc_interface(image_stream, from_remote=FROM_REMOTE):
  364. log("into from_idc_interface")
  365. start_time = time.time()
  366. try:
  367. base64_stream = base64.b64encode(image_stream)
  368. # 调用接口
  369. try:
  370. if from_remote:
  371. retry_times_1 = 3
  372. # 重试
  373. while retry_times_1:
  374. ip_port = interface_pool_gunicorn("idc")
  375. if judge_error_code(ip_port):
  376. return ip_port
  377. _url = ip_port + "/idc"
  378. r = json.loads(request_post(_url, {"data": base64_stream,
  379. "md5": _global.get("md5")},
  380. time_out=60))
  381. log("get interface return")
  382. if type(r) == list:
  383. # 接口连不上换个端口重试
  384. if retry_times_1 <= 1:
  385. return r
  386. else:
  387. retry_times_1 -= 1
  388. log("retry post idc_interface... left times " + str(retry_times_1))
  389. continue
  390. if judge_error_code(r):
  391. return r
  392. break
  393. else:
  394. if globals().get("global_idc_model") is None:
  395. print("=========== init idc model ===========")
  396. idc_model = IdcModels().get_model()
  397. globals().update({"global_idc_model": idc_model})
  398. r = idc(data=base64_stream,
  399. model=globals().get("global_idc_model"))
  400. except TimeoutError:
  401. return [-5]
  402. except requests.exceptions.ConnectionError as e:
  403. return [-2]
  404. _dict = r
  405. angle = _dict.get("angle")
  406. log("from_idc_interface cost time " + str(time.time()-start_time))
  407. return angle
  408. except Exception as e:
  409. log("from_idc_interface error!")
  410. traceback.print_exc()
  411. return [-11]
  412. def from_atc_interface(text, from_remote=FROM_REMOTE):
  413. log("into from_atc_interface")
  414. start_time = time.time()
  415. try:
  416. # 调用接口
  417. try:
  418. if from_remote:
  419. retry_times_1 = 3
  420. # 重试
  421. while retry_times_1:
  422. ip_port = interface_pool_gunicorn("atc")
  423. if judge_error_code(ip_port):
  424. return ip_port
  425. _url = ip_port + "/atc"
  426. r = json.loads(request_post(_url, {"data": text,
  427. "md5": _global.get("md5")},
  428. time_out=60))
  429. log("get interface return")
  430. if type(r) == list:
  431. # 接口连不上换个端口重试
  432. if retry_times_1 <= 1:
  433. return r
  434. else:
  435. retry_times_1 -= 1
  436. log("retry post atc_interface... left times " + str(retry_times_1))
  437. continue
  438. if judge_error_code(r):
  439. return r
  440. break
  441. else:
  442. if globals().get("global_atc_model") is None:
  443. print("=========== init atc model ===========")
  444. atc_model = AtcModels().get_model()
  445. globals().update({"global_atc_model": atc_model})
  446. r = atc(data=text,
  447. model=globals().get("global_atc_model"))
  448. except TimeoutError:
  449. return [-5]
  450. except requests.exceptions.ConnectionError as e:
  451. return [-2]
  452. _dict = r
  453. classification = _dict.get("classification")
  454. log("from_atc_interface cost time " + str(time.time()-start_time))
  455. return classification
  456. except Exception as e:
  457. log("from_atc_interface error!")
  458. traceback.print_exc()
  459. return [-11]
  460. def from_yolo_interface(image_stream, from_remote=FROM_REMOTE):
  461. log("into from_yolo_interface")
  462. start_time = time.time()
  463. try:
  464. base64_stream = base64.b64encode(image_stream)
  465. # 调用接口
  466. try:
  467. if from_remote:
  468. retry_times_1 = 3
  469. # 重试
  470. while retry_times_1:
  471. ip_port = interface_pool_gunicorn("yolo")
  472. if judge_error_code(ip_port):
  473. return ip_port
  474. _url = ip_port + "/yolo"
  475. log('yolo _url ' + _url)
  476. r = json.loads(request_post(_url, {"data": base64_stream,
  477. "md5": _global.get("md5")},
  478. time_out=60))
  479. log("get interface return")
  480. if type(r) == list:
  481. # 接口连不上换个端口重试
  482. if retry_times_1 <= 1:
  483. return r
  484. else:
  485. retry_times_1 -= 1
  486. log("retry post yolo_interface... left times " + str(retry_times_1))
  487. continue
  488. if judge_error_code(r):
  489. return r
  490. break
  491. else:
  492. if globals().get("global_yolo_predictor") is None:
  493. print("=========== init yolo model ===========")
  494. ROOT = os.path.abspath(os.path.dirname(__file__)) + '/../'
  495. model_path = ROOT + 'botr/yolov8/weights.pt'
  496. image_size = 640
  497. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  498. yolo_predictor = Predictor(image_size, device, model_path)
  499. globals().update({"global_yolo_predictor": yolo_predictor})
  500. r = yolo(data=base64_stream,
  501. predictor=globals().get("global_yolo_predictor"))
  502. except TimeoutError:
  503. return [-5]
  504. except requests.exceptions.ConnectionError as e:
  505. return [-2]
  506. _dict = r
  507. b_table_list = _dict.get("b_table_list")
  508. log("from_yolo_interface cost time " + str(time.time()-start_time))
  509. return b_table_list
  510. except Exception as e:
  511. log("from_yolo_interface error!")
  512. traceback.print_exc()
  513. return [-11]
  514. def interface_pool_gunicorn(interface_type):
  515. # if get_platform() == 'Windows':
  516. # set_flask_global()
  517. ip_port_flag_dict = _global.get("ip_port_flag")
  518. ip_port_dict = _global.get("ip_port")
  519. try:
  520. if ip_port_dict is None or ip_port_flag_dict is None:
  521. print('_global', _global.get_dict())
  522. raise NotFound
  523. # 负载均衡, 选取有该接口的ip
  524. min_cnt = 10000.
  525. interface_cnt = 0
  526. _ip = None
  527. port_list = []
  528. for key in ip_port_flag_dict.keys():
  529. temp_port_list = get_args_from_config(ip_port_dict, key, interface_type)
  530. # print('temp_port_list', temp_port_list)
  531. if not temp_port_list:
  532. continue
  533. # 该ip下的该接口总数量(可能有多gpu接口)
  534. _port_list, _port_num_list, _ = temp_port_list[0]
  535. # print('_port_num_list', _port_num_list)
  536. total_port_num = sum(_port_num_list)
  537. if total_port_num == 0:
  538. continue
  539. interface_cnt = ip_port_flag_dict.get(key).get(interface_type)
  540. if interface_cnt is not None and interface_cnt / total_port_num < min_cnt:
  541. _ip = key
  542. min_cnt = interface_cnt / len(temp_port_list[0])
  543. # 选定ip,设置gpu的接口候选比例
  544. gpu_port_list = []
  545. for k in range(len(_port_list)):
  546. gpu_port_list += [_port_list[k]] * _port_num_list[k]
  547. port_list = gpu_port_list
  548. # port_list = temp_port_list[0]
  549. # 选取端口
  550. if interface_type == "office":
  551. if len(port_list) == 0:
  552. raise ConnectionError
  553. port_list = [str(port_list[k] + k) for k in range(len(port_list))]
  554. # 刚开始随机,后续求余
  555. if min_cnt == 0:
  556. _port = port_list[random.randint(0, len(port_list)-1)]
  557. ip_port_flag_dict[_ip][interface_type] = int(_port[-2:])
  558. else:
  559. _port = port_list[interface_cnt % len(port_list)]
  560. else:
  561. # 使用gunicorn则随机选
  562. _port = random.choice(port_list)
  563. # 更新flag
  564. if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
  565. ip_port_flag_dict[_ip][interface_type] = 0
  566. else:
  567. ip_port_flag_dict[_ip][interface_type] += 1
  568. _global.update({"ip_port_flag": ip_port_flag_dict})
  569. ip_port = _ip + ":" + str(_port)
  570. log(interface_type)
  571. log(ip_port)
  572. return ip_port
  573. except NotFound:
  574. log("ip_port or ip_port_dict is None! checkout config")
  575. return [-2]
  576. except ConnectionError:
  577. log('no office interface running!')
  578. return [-15]
  579. except:
  580. traceback.print_exc()
  581. return [-1]
  582. if __name__ == "__main__":
  583. _global._init()
  584. set_flask_global()
  585. _img = cv2.imread(r"C:/Users/Administrator/Desktop/test_b_table/error11.png")
  586. _img_bytes = np2bytes(_img)
  587. b_list = from_yolo_interface(_img_bytes, from_remote=True)
  588. for l in b_list:
  589. for b in l:
  590. cv2.rectangle(_img, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (0, 0, 255), 2)
  591. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  592. cv2.imshow('img', _img)
  593. cv2.waitKey(0)