atc_interface.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import base64
  2. import json
  3. import os
  4. import sys
  5. import time
  6. import traceback
  7. from glob import glob
  8. import numpy as np
  9. import cv2
  10. import tensorflow as tf
  11. from PIL import Image
  12. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  13. MAX_COMPUTE = False
  14. if not MAX_COMPUTE:
  15. # tensorflow 内存设置
  16. try:
  17. gpus = tf.config.list_physical_devices('GPU')
  18. if len(gpus) > 0:
  19. tf.config.experimental.set_virtual_device_configuration(
  20. gpus[0],
  21. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=512)])
  22. except:
  23. traceback.print_exc()
  24. # pass
  25. # gpus = tf.config.list_physical_devices('GPU')
  26. # for gpu in gpus: # 如果使用多块GPU时
  27. # tf.config.experimental.set_memory_growth(gpu, True)
  28. os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
  29. os.environ['CUDA_CACHE_DISABLE'] = str(0)
  30. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
  31. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  32. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  33. from format_convert import _global
  34. from format_convert.utils import log, bytes2np, request_post, get_md5_from_bytes
  35. from flask import Flask, request
  36. from atc.predict import ModelRelationExtraction
  37. def atc(data, model):
  38. log("into atc_interface atc")
  39. try:
  40. # print("data", data)
  41. _, classification = model.evaluate(data)
  42. return {"classification": classification}
  43. except TimeoutError:
  44. return {"classification": [-5]}
  45. except:
  46. traceback.print_exc()
  47. return {"classification": [-1]}
  48. # 接口配置
  49. app = Flask(__name__)
  50. @app.route('/atc', methods=['POST'])
  51. def _atc():
  52. _global._init()
  53. _global.update({"port": globals().get("port")})
  54. start_time = time.time()
  55. log("into atc_interface _atc")
  56. try:
  57. if not request.form:
  58. log("atc no data!")
  59. return json.dumps({"classification": str([-9])})
  60. data = request.form.get("data")
  61. log("atc_interface get data time" + str(time.time()-start_time))
  62. _md5 = request.form.get("md5")
  63. _global.update({"md5": _md5})
  64. atc_model = globals().get("global_atc_model")
  65. if atc_model is None:
  66. print("=========== init atc model ===========")
  67. atc_model = AtcModels().get_model()
  68. globals().update({"global_atc_model": atc_model})
  69. classification = atc(data, atc_model).get("classification")
  70. return json.dumps({"classification": classification})
  71. except TimeoutError:
  72. return json.dumps({"classification": str([-5])})
  73. except:
  74. traceback.print_exc()
  75. return json.dumps({"classification": str([-1])})
  76. finally:
  77. log("atc interface finish time " + str(time.time()-start_time))
  78. class AtcModels:
  79. def __init__(self):
  80. self.model = ModelRelationExtraction()
  81. def get_model(self):
  82. return self.model
  83. def test_atc_model(from_remote=False):
  84. atc_model = AtcModels().get_model()
  85. text = '''招标文件项目编号:SDGP370302202102000110项目名称:淄川经济开发区中心小学校园智能化采购项目采购人:山东淄川经
  86. 济开发区管理委员会采购代理机构:淄博正益招标有限公司发出日期:2021年8月目录第一章投标邀请7一、项目基本情况7二、申请人的资格要
  87. 求8三、获取招标文件8四、提交投标文件截止时间、开标时间和地点8五、公告期限9六、其他补充事宜9第二章投标人须知11一、总则161.采
  88. 购人、采购代理机构及投标人162.资金来源183.投标费用184.适用法律18二、招标文件185.招标文件构成186.招标文件的澄清与修改207.投
  89. 标截止时间的顺延20三、投标文件的编制208.编制要求209.投标范围及投标文件中标准和计量单位的使用2110.投标文件构成2211.投标报价241
  90. 2.电子版投标文件2513.投标保证金2614.投标有效期2615.投标文件的签署及规定26四、投标文件的递交2616.投标文件的递交2617.递交
  91. 投标文件的截止时间2718.投标文件的接收、修改与撤回27五、开标及评标2719.开标2720.资格审查2821.组建评标委员会2922.投标文件符
  92. 合性审查与澄清3023.投标偏离3224.投标无效3225.比较和评价3326.废标3527.保密要求36六、确定中标3628.中标候选人的确定原则及标
  93. 准3629.确定中标候选人和中标人3630.采购任务取消3631.中标通知书3632.签订合同3633.履约保证金3734.政府采购融资担保3735.预付
  94. 款3736.廉洁自律规定3737.人员回避3738.质疑与接收3739.项目其他相关费用3940.合同公示3941.验收4042.履约验收公示4043.招标文
  95. 件解释权40第三章货物需求41一、项目概述41
  96. '''
  97. _md5 = get_md5_from_bytes(text.encode('utf-8'))[0]
  98. _global._init()
  99. _global.update({"port": 18060, "md5": _md5})
  100. if from_remote:
  101. file_json = {"data": text, "md5": _md5}
  102. # _url = "http://192.168.2.102:17000/ocr"
  103. _url = "http://127.0.0.1:18060/atc"
  104. print(json.loads(request_post(_url, file_json)))
  105. else:
  106. result = atc(text, atc_model)
  107. # print(result)
  108. if type(result.get("angle")) == list:
  109. print(result)
  110. else:
  111. classification = result.get("classification")
  112. print("classification", classification)
  113. if __name__ == "__main__":
  114. test_atc_model()