import base64 import json import os import sys import time import traceback os.environ["CUDA_VISIBLE_DEVICES"] = "-1" sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") from config.max_compute_config import MAX_COMPUTE import tensorflow as tf if not MAX_COMPUTE: # tensorflow 内存设置 try: gpus = tf.config.list_physical_devices('GPU') if len(gpus) > 0: tf.config.experimental.set_virtual_device_configuration( gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=512)]) except: traceback.print_exc() # pass # gpus = tf.config.list_physical_devices('GPU') # for gpu in gpus: # 如果使用多块GPU时 # tf.config.experimental.set_memory_growth(gpu, True) os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648) os.environ['CUDA_CACHE_DISABLE'] = str(0) gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05) sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) from format_convert import _global from format_convert.utils import log, bytes2np, request_post, get_md5_from_bytes from flask import Flask, request from atc.predict import ModelRelationExtraction def atc(data, model): log("into atc_interface atc") try: # print("data", data) _, classification = model.evaluate(data) return {"classification": classification} except: traceback.print_exc() return {"classification": ""} # 接口配置 app = Flask(__name__) @app.route('/atc', methods=['POST']) def _atc(): _global._init() _global.update({"port": globals().get("port")}) start_time = time.time() log("into atc_interface _atc") try: if not request.form: log("atc no data!") return json.dumps({"classification": str([-9])}) data = request.form.get("data") log("atc_interface get data time" + str(time.time()-start_time)) _md5 = request.form.get("md5") _global.update({"md5": _md5}) atc_model = globals().get("global_atc_model") if atc_model is None: print("=========== init atc model ===========") atc_model = AtcModels().get_model() globals().update({"global_atc_model": atc_model}) classification = atc(data, atc_model).get("classification") return json.dumps({"classification": classification}) except: traceback.print_exc() return json.dumps({"classification": ""}) finally: log("atc interface finish time " + str(time.time()-start_time)) class AtcModels: def __init__(self): self.model = ModelRelationExtraction() def get_model(self): return self.model def test_atc_model(from_remote=False): atc_model = AtcModels().get_model() text = '''招标文件项目编号:SDGP370302202102000110项目名称:淄川经济开发区中心小学校园智能化采购项目采购人:山东淄川经 济开发区管理委员会采购代理机构:淄博正益招标有限公司发出日期:2021年8月目录第一章投标邀请7一、项目基本情况7二、申请人的资格要 求8三、获取招标文件8四、提交投标文件截止时间、开标时间和地点8五、公告期限9六、其他补充事宜9第二章投标人须知11一、总则161.采 购人、采购代理机构及投标人162.资金来源183.投标费用184.适用法律18二、招标文件185.招标文件构成186.招标文件的澄清与修改207.投 标截止时间的顺延20三、投标文件的编制208.编制要求209.投标范围及投标文件中标准和计量单位的使用2110.投标文件构成2211.投标报价241 2.电子版投标文件2513.投标保证金2614.投标有效期2615.投标文件的签署及规定26四、投标文件的递交2616.投标文件的递交2617.递交 投标文件的截止时间2718.投标文件的接收、修改与撤回27五、开标及评标2719.开标2720.资格审查2821.组建评标委员会2922.投标文件符 合性审查与澄清3023.投标偏离3224.投标无效3225.比较和评价3326.废标3527.保密要求36六、确定中标3628.中标候选人的确定原则及标 准3629.确定中标候选人和中标人3630.采购任务取消3631.中标通知书3632.签订合同3633.履约保证金3734.政府采购融资担保3735.预付 款3736.廉洁自律规定3737.人员回避3738.质疑与接收3739.项目其他相关费用3940.合同公示3941.验收4042.履约验收公示4043.招标文 件解释权40第三章货物需求41一、项目概述41 ''' _md5 = get_md5_from_bytes(text.encode('utf-8'))[0] _global._init() _global.update({"port": 18060, "md5": _md5}) if from_remote: file_json = {"data": text, "md5": _md5} # _url = "http://192.168.2.102:17000/ocr" _url = "http://127.0.0.1:18060/atc" print(json.loads(request_post(_url, file_json))) else: result = atc(text, atc_model) # print(result) if type(result.get("angle")) == list: print(result) else: classification = result.get("classification") print("classification", classification) if __name__ == "__main__": test_atc_model()