123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- import base64
- import json
- import os
- import sys
- import time
- import traceback
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from format_convert.max_compute_config import max_compute
- import tensorflow as tf
- MAX_COMPUTE = max_compute
- if not MAX_COMPUTE:
- # tensorflow 内存设置
- try:
- gpus = tf.config.list_physical_devices('GPU')
- if len(gpus) > 0:
- tf.config.experimental.set_virtual_device_configuration(
- gpus[0],
- [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=512)])
- except:
- traceback.print_exc()
- # pass
- # gpus = tf.config.list_physical_devices('GPU')
- # for gpu in gpus: # 如果使用多块GPU时
- # tf.config.experimental.set_memory_growth(gpu, True)
- os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
- os.environ['CUDA_CACHE_DISABLE'] = str(0)
- gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
- sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
- from format_convert import _global
- from format_convert.utils import log, bytes2np, request_post, get_md5_from_bytes
- from flask import Flask, request
- from atc.predict import ModelRelationExtraction
- def atc(data, model):
- log("into atc_interface atc")
- try:
- # print("data", data)
- _, classification = model.evaluate(data)
- return {"classification": classification}
- except:
- traceback.print_exc()
- return {"classification": ""}
- # 接口配置
- app = Flask(__name__)
- @app.route('/atc', methods=['POST'])
- def _atc():
- _global._init()
- _global.update({"port": globals().get("port")})
- start_time = time.time()
- log("into atc_interface _atc")
- try:
- if not request.form:
- log("atc no data!")
- return json.dumps({"classification": str([-9])})
- data = request.form.get("data")
- log("atc_interface get data time" + str(time.time()-start_time))
- _md5 = request.form.get("md5")
- _global.update({"md5": _md5})
- atc_model = globals().get("global_atc_model")
- if atc_model is None:
- print("=========== init atc model ===========")
- atc_model = AtcModels().get_model()
- globals().update({"global_atc_model": atc_model})
- classification = atc(data, atc_model).get("classification")
- return json.dumps({"classification": classification})
- except:
- traceback.print_exc()
- return json.dumps({"classification": ""})
- finally:
- log("atc interface finish time " + str(time.time()-start_time))
- class AtcModels:
- def __init__(self):
- self.model = ModelRelationExtraction()
- def get_model(self):
- return self.model
- def test_atc_model(from_remote=False):
- atc_model = AtcModels().get_model()
- text = '''招标文件项目编号:SDGP370302202102000110项目名称:淄川经济开发区中心小学校园智能化采购项目采购人:山东淄川经
- 济开发区管理委员会采购代理机构:淄博正益招标有限公司发出日期:2021年8月目录第一章投标邀请7一、项目基本情况7二、申请人的资格要
- 求8三、获取招标文件8四、提交投标文件截止时间、开标时间和地点8五、公告期限9六、其他补充事宜9第二章投标人须知11一、总则161.采
- 购人、采购代理机构及投标人162.资金来源183.投标费用184.适用法律18二、招标文件185.招标文件构成186.招标文件的澄清与修改207.投
- 标截止时间的顺延20三、投标文件的编制208.编制要求209.投标范围及投标文件中标准和计量单位的使用2110.投标文件构成2211.投标报价241
- 2.电子版投标文件2513.投标保证金2614.投标有效期2615.投标文件的签署及规定26四、投标文件的递交2616.投标文件的递交2617.递交
- 投标文件的截止时间2718.投标文件的接收、修改与撤回27五、开标及评标2719.开标2720.资格审查2821.组建评标委员会2922.投标文件符
- 合性审查与澄清3023.投标偏离3224.投标无效3225.比较和评价3326.废标3527.保密要求36六、确定中标3628.中标候选人的确定原则及标
- 准3629.确定中标候选人和中标人3630.采购任务取消3631.中标通知书3632.签订合同3633.履约保证金3734.政府采购融资担保3735.预付
- 款3736.廉洁自律规定3737.人员回避3738.质疑与接收3739.项目其他相关费用3940.合同公示3941.验收4042.履约验收公示4043.招标文
- 件解释权40第三章货物需求41一、项目概述41
- '''
- _md5 = get_md5_from_bytes(text.encode('utf-8'))[0]
- _global._init()
- _global.update({"port": 18060, "md5": _md5})
- if from_remote:
- file_json = {"data": text, "md5": _md5}
- # _url = "http://192.168.2.102:17000/ocr"
- _url = "http://127.0.0.1:18060/atc"
- print(json.loads(request_post(_url, file_json)))
- else:
- result = atc(text, atc_model)
- # print(result)
- if type(result.get("angle")) == list:
- print(result)
- else:
- classification = result.get("classification")
- print("classification", classification)
- if __name__ == "__main__":
- test_atc_model()
|