atc_interface.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import base64
  2. import json
  3. import os
  4. import sys
  5. import time
  6. import traceback
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from format_convert.max_compute_config import max_compute
  10. import tensorflow as tf
  11. MAX_COMPUTE = max_compute
  12. if not MAX_COMPUTE:
  13. # tensorflow 内存设置
  14. try:
  15. gpus = tf.config.list_physical_devices('GPU')
  16. if len(gpus) > 0:
  17. tf.config.experimental.set_virtual_device_configuration(
  18. gpus[0],
  19. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=512)])
  20. except:
  21. traceback.print_exc()
  22. # pass
  23. # gpus = tf.config.list_physical_devices('GPU')
  24. # for gpu in gpus: # 如果使用多块GPU时
  25. # tf.config.experimental.set_memory_growth(gpu, True)
  26. os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
  27. os.environ['CUDA_CACHE_DISABLE'] = str(0)
  28. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
  29. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  30. from format_convert import _global
  31. from format_convert.utils import log, bytes2np, request_post, get_md5_from_bytes
  32. from flask import Flask, request
  33. from atc.predict import ModelRelationExtraction
  34. def atc(data, model):
  35. log("into atc_interface atc")
  36. try:
  37. # print("data", data)
  38. _, classification = model.evaluate(data)
  39. return {"classification": classification}
  40. except:
  41. traceback.print_exc()
  42. return {"classification": ""}
  43. # 接口配置
  44. app = Flask(__name__)
  45. @app.route('/atc', methods=['POST'])
  46. def _atc():
  47. _global._init()
  48. _global.update({"port": globals().get("port")})
  49. start_time = time.time()
  50. log("into atc_interface _atc")
  51. try:
  52. if not request.form:
  53. log("atc no data!")
  54. return json.dumps({"classification": str([-9])})
  55. data = request.form.get("data")
  56. log("atc_interface get data time" + str(time.time()-start_time))
  57. _md5 = request.form.get("md5")
  58. _global.update({"md5": _md5})
  59. atc_model = globals().get("global_atc_model")
  60. if atc_model is None:
  61. print("=========== init atc model ===========")
  62. atc_model = AtcModels().get_model()
  63. globals().update({"global_atc_model": atc_model})
  64. classification = atc(data, atc_model).get("classification")
  65. return json.dumps({"classification": classification})
  66. except:
  67. traceback.print_exc()
  68. return json.dumps({"classification": ""})
  69. finally:
  70. log("atc interface finish time " + str(time.time()-start_time))
  71. class AtcModels:
  72. def __init__(self):
  73. self.model = ModelRelationExtraction()
  74. def get_model(self):
  75. return self.model
  76. def test_atc_model(from_remote=False):
  77. atc_model = AtcModels().get_model()
  78. text = '''招标文件项目编号:SDGP370302202102000110项目名称:淄川经济开发区中心小学校园智能化采购项目采购人:山东淄川经
  79. 济开发区管理委员会采购代理机构:淄博正益招标有限公司发出日期:2021年8月目录第一章投标邀请7一、项目基本情况7二、申请人的资格要
  80. 求8三、获取招标文件8四、提交投标文件截止时间、开标时间和地点8五、公告期限9六、其他补充事宜9第二章投标人须知11一、总则161.采
  81. 购人、采购代理机构及投标人162.资金来源183.投标费用184.适用法律18二、招标文件185.招标文件构成186.招标文件的澄清与修改207.投
  82. 标截止时间的顺延20三、投标文件的编制208.编制要求209.投标范围及投标文件中标准和计量单位的使用2110.投标文件构成2211.投标报价241
  83. 2.电子版投标文件2513.投标保证金2614.投标有效期2615.投标文件的签署及规定26四、投标文件的递交2616.投标文件的递交2617.递交
  84. 投标文件的截止时间2718.投标文件的接收、修改与撤回27五、开标及评标2719.开标2720.资格审查2821.组建评标委员会2922.投标文件符
  85. 合性审查与澄清3023.投标偏离3224.投标无效3225.比较和评价3326.废标3527.保密要求36六、确定中标3628.中标候选人的确定原则及标
  86. 准3629.确定中标候选人和中标人3630.采购任务取消3631.中标通知书3632.签订合同3633.履约保证金3734.政府采购融资担保3735.预付
  87. 款3736.廉洁自律规定3737.人员回避3738.质疑与接收3739.项目其他相关费用3940.合同公示3941.验收4042.履约验收公示4043.招标文
  88. 件解释权40第三章货物需求41一、项目概述41
  89. '''
  90. _md5 = get_md5_from_bytes(text.encode('utf-8'))[0]
  91. _global._init()
  92. _global.update({"port": 18060, "md5": _md5})
  93. if from_remote:
  94. file_json = {"data": text, "md5": _md5}
  95. # _url = "http://192.168.2.102:17000/ocr"
  96. _url = "http://127.0.0.1:18060/atc"
  97. print(json.loads(request_post(_url, file_json)))
  98. else:
  99. result = atc(text, atc_model)
  100. # print(result)
  101. if type(result.get("angle")) == list:
  102. print(result)
  103. else:
  104. classification = result.get("classification")
  105. print("classification", classification)
  106. if __name__ == "__main__":
  107. test_atc_model()