# cho_interface.py
import json
import logging
import os
import random
import sys
import time
import traceback

# Force CPU-only execution; must be set BEFORE TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make the parent package importable when this file runs as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

import tensorflow as tf
from flask import Flask, request
from chinese_order.model import lstm_phrase
from chinese_order.inference_phrase import recognize
from utils import request_post

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# TF1-style graph execution: the model is built into and served from this
# single shared session (used by ChoModels.__init__ and recognize()).
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())

package_dir = os.path.abspath(os.path.dirname(__file__))
# Input shape passed to lstm_phrase: (sequence_len, vocabulary_len).
# Presumably fixed at training time to match the saved weights — confirm.
vocabulary_len = 5792
sequence_len = 6
model_path = package_dir + "/models/phrase_f1_0.85.h5"

# API configuration
app = Flask(__name__)
  24. @app.route('/cho', methods=['POST'])
  25. def cho():
  26. start_time = time.time()
  27. logging.info("into cho_interface cho")
  28. try:
  29. # 接收网络数据
  30. if not request.form:
  31. logging.info("cho no data!")
  32. return json.dumps({"data": "", "success": 0})
  33. data = request.form.get("data")
  34. logging.info("cho_interface get data time" + str(time.time()-start_time))
  35. # 加载模型
  36. cho_model = globals().get("global_cho_model")
  37. if cho_model is None:
  38. print("=========== init cho model ===========")
  39. cho_model = ChoModels().get_model()
  40. globals().update({"global_cho_model": cho_model})
  41. # 数据转换
  42. char_list = json.loads(data)
  43. # 预测
  44. decode_list = recognize(char_list, cho_model, sess)
  45. return json.dumps({"data": decode_list, "success": 1})
  46. except:
  47. traceback.print_exc()
  48. return json.dumps({"data": "", "success": 0})
  49. finally:
  50. logging.info("cho interface finish time " + str(time.time()-start_time))
  51. class ChoModels:
  52. def __init__(self):
  53. with sess.as_default():
  54. with sess.graph.as_default():
  55. self.model = lstm_phrase((sequence_len, vocabulary_len))
  56. self.model.load_weights(model_path)
  57. def get_model(self):
  58. return self.model
  59. def test_cho_model(from_remote=True):
  60. text = "巧克力"
  61. text = [x for x in text]
  62. random.shuffle(text)
  63. if from_remote:
  64. data_json = {"data": json.dumps(text)}
  65. _url = "http://127.0.0.1:17000/cho"
  66. result = json.loads(request_post(_url, data_json))
  67. if result.get("success"):
  68. decode_list = result.get("data")
  69. print("char_list", text)
  70. print("decode_list", decode_list)
  71. else:
  72. print("failed!")
if __name__ == "__main__":
    # Serve on localhost only.  NOTE(review): test_cho_model() defaults to
    # port 17000, not 17058 — verify which port is the intended deployment.
    app.run(host='127.0.0.1', port=17058, debug=False)
    # test_cho_model()