123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
import json
import logging
import os
import random
import sys
import time
import traceback

# Force TensorFlow onto CPU: hide all CUDA devices before TF is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make the parent directory importable so the sibling packages resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

import tensorflow as tf
from flask import Flask, request

from chinese_order.model import lstm_phrase
from chinese_order.inference_phrase import recognize
from utils import request_post

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# TF1-style execution: one shared session/graph for model loading and inference.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())

package_dir = os.path.abspath(os.path.dirname(__file__))
# Must match the trained checkpoint's vocabulary size and input length.
vocabulary_len = 5792
sequence_len = 6
model_path = package_dir + "/models/phrase_f1_0.85.h5"

# HTTP interface configuration
app = Flask(__name__)
@app.route('/cho', methods=['POST'])
def cho():
    """POST /cho: reorder a shuffled list of Chinese characters.

    Expects a form field ``data`` holding a JSON-encoded list of characters.
    Returns JSON ``{"data": <decoded list>, "success": 1}`` on success, or
    ``{"data": "", "success": 0}`` on missing input or any error.
    """
    start_time = time.time()
    logging.info("into cho_interface cho")
    try:
        # Receive the request payload.
        if not request.form:
            logging.info("cho no data!")
            return json.dumps({"data": "", "success": 0})
        data = request.form.get("data")
        logging.info("cho_interface get data time%s", time.time() - start_time)
        # Lazily build the model once and cache it at module scope so later
        # requests reuse the already-loaded weights.
        cho_model = globals().get("global_cho_model")
        if cho_model is None:
            logging.info("=========== init cho model ===========")
            cho_model = ChoModels().get_model()
            globals()["global_cho_model"] = cho_model
        # Decode the input and run inference inside the shared session.
        char_list = json.loads(data)
        decode_list = recognize(char_list, cho_model, sess)
        return json.dumps({"data": decode_list, "success": 1})
    except Exception:
        # Narrowed from a bare ``except:`` (which also swallowed SystemExit /
        # KeyboardInterrupt); log the full traceback through the logger.
        logging.exception("cho interface failed")
        return json.dumps({"data": "", "success": 0})
    finally:
        logging.info("cho interface finish time %s", time.time() - start_time)
class ChoModels:
    """Builds the phrase-ordering LSTM and restores its trained weights.

    Construction happens inside the module-level ``sess`` and its graph so
    that subsequent inference calls through the same session can find the
    model's variables.
    """

    def __init__(self):
        # Enter the shared session and its graph before building/loading.
        with sess.as_default(), sess.graph.as_default():
            self.model = lstm_phrase((sequence_len, vocabulary_len))
            self.model.load_weights(model_path)

    def get_model(self):
        """Return the loaded Keras model."""
        return self.model
def test_cho_model(from_remote=True):
    """Shuffle a sample word's characters and ask the /cho endpoint to reorder them.

    Does nothing when ``from_remote`` is falsy (only the remote path is
    implemented).
    """
    chars = list("巧克力")
    random.shuffle(chars)
    if not from_remote:
        return
    payload = {"data": json.dumps(chars)}
    response = json.loads(request_post("http://127.0.0.1:17000/cho", payload))
    if response.get("success"):
        print("char_list", chars)
        print("decode_list", response.get("data"))
    else:
        print("failed!")
if __name__ == "__main__":
    # Serve the ordering API locally. NOTE(review): this listens on 17058 while
    # the test client above posts to 17000 — presumably a proxy sits in between;
    # TODO confirm the intended port.
    app.run(host='127.0.0.1', port=17058, debug=False)
    # test_cho_model()
|