import json
import logging
import os
import random
import sys
import threading
import time
import traceback

# Hide all GPUs from TensorFlow; this must be set before tensorflow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make the parent directory importable so `chinese_order` and `utils` resolve
# when this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

import tensorflow as tf  # noqa: E402  (must follow the CUDA_VISIBLE_DEVICES setting)
from flask import Flask, request  # noqa: E402

from chinese_order.model import lstm_phrase  # noqa: E402
from chinese_order.inference_phrase import recognize  # noqa: E402
from utils import request_post  # noqa: E402

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Graph (TF1-style) mode: the model is built inside this dedicated session/graph,
# and the same session is passed to `recognize` at request time.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())

package_dir = os.path.abspath(os.path.dirname(__file__))
# Passed to lstm_phrase as the tuple (sequence_len, vocabulary_len) --
# presumably the model's input shape; TODO confirm against lstm_phrase.
vocabulary_len = 5792
sequence_len = 6
model_path = package_dir + "/models/phrase_f1_0.85.h5"

# Lazily-built model shared by all requests; populated by _get_cho_model().
global_cho_model = None
# Guards the one-time model construction. Flask's app.run() serves requests on
# multiple threads by default, so two first requests could otherwise both build
# the model into the shared graph.
_cho_model_lock = threading.Lock()

# Interface configuration
app = Flask(__name__)


def _get_cho_model():
    """Return the shared phrase model, building it on first use (thread-safe)."""
    global global_cho_model
    if global_cho_model is None:
        with _cho_model_lock:
            # Re-check under the lock: another thread may have finished the build
            # while this one was waiting.
            if global_cho_model is None:
                logging.info("=========== init cho model ===========")
                global_cho_model = ChoModels().get_model()
    return global_cho_model


@app.route('/cho', methods=['POST'])
def cho():
    """Run phrase recognition on a list of characters posted to /cho.

    Expects a form field ``data`` containing a JSON-encoded list of characters
    (test_cho_model below sends a shuffled word in this format). The list is
    passed to ``recognize`` together with the shared model and session.

    Returns:
        A JSON string ``{"data": <recognize result>, "success": 1}`` on success,
        or ``{"data": "", "success": 0}`` when input is missing or any step fails.
    """
    start_time = time.time()
    logging.info("into cho_interface cho")
    try:
        # Receive request data
        data = request.form.get("data") if request.form else None
        if data is None:
            logging.info("cho no data!")
            return json.dumps({"data": "", "success": 0})
        logging.info("cho_interface get data time%s", time.time() - start_time)

        # Load (or reuse) the model
        cho_model = _get_cho_model()

        # Decode the JSON payload into a character list
        char_list = json.loads(data)

        # Predict
        decode_list = recognize(char_list, cho_model, sess)
        return json.dumps({"data": decode_list, "success": 1})
    except Exception:
        # Top-level request boundary: log the full traceback and report failure
        # to the client instead of letting Flask return a 500.
        logging.exception("cho interface failed")
        return json.dumps({"data": "", "success": 0})
    finally:
        logging.info("cho interface finish time %s", time.time() - start_time)


class ChoModels:
    """Builds the LSTM phrase model and loads its weights inside ``sess``'s graph."""

    def __init__(self):
        # Build under the module-level session/graph, presumably so the model's
        # variables live in the graph that `recognize(..., sess)` later runs.
        with sess.as_default(), sess.graph.as_default():
            self.model = lstm_phrase((sequence_len, vocabulary_len))
            self.model.load_weights(model_path)

    def get_model(self):
        """Return the loaded model."""
        return self.model


def test_cho_model(from_remote=True, url="http://127.0.0.1:17000/cho"):
    """Manual smoke test: POST a shuffled word to the /cho endpoint and print the result.

    Args:
        from_remote: when False, nothing is sent and nothing is printed.
        url: endpoint to call. NOTE(review): the default targets port 17000,
            while ``__main__`` below serves on port 17058; pass the URL
            explicitly when testing a server started from this file, and
            confirm which port is actually deployed.
    """
    text = list("巧克力")
    random.shuffle(text)
    if not from_remote:
        return
    data_json = {"data": json.dumps(text)}
    result = json.loads(request_post(url, data_json))
    if result.get("success"):
        print("char_list", text)
        print("decode_list", result.get("data"))
    else:
        print("failed!")


if __name__ == "__main__":
    # app.run() blocks; to smoke-test, call
    # test_cho_model(url="http://127.0.0.1:17058/cho") from another process.
    app.run(host='127.0.0.1', port=17058, debug=False)