# test_ocr_interface.py
  1. import base64
  2. import copy
  3. import ctypes
  4. import gc
  5. import hashlib
  6. import inspect
  7. import multiprocessing
  8. import os
  9. import random
  10. import traceback
  11. from glob import glob, iglob
  12. import threading
  13. import time
  14. import urllib
  15. import psutil
  16. import requests
  17. import json
  18. import sys
  19. from multiprocessing import Process, Pool
  20. __dir__ = os.path.dirname(os.path.abspath(__file__))
  21. sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
  22. from format_convert.convert import convert
  23. from ocr.ocr_interface import ocr, OcrModels
  24. from otr.otr_interface import otr, OtrModels
  25. from format_convert.judge_platform import get_platform
  26. class myThread(threading.Thread):
  27. def __init__(self, threadName):
  28. threading.Thread.__init__(self)
  29. self.threadName = threadName
  30. def run(self):
  31. while True:
  32. start_time = time.time()
  33. test_convert()
  34. print(self.threadName, "finish!", time.time()-start_time)
  35. class myThread_appendix(threading.Thread):
  36. def __init__(self, threadName, _list):
  37. threading.Thread.__init__(self)
  38. self.threadName = threadName
  39. self._list = _list
  40. def run(self):
  41. start_time = time.time()
  42. test_appendix_downloaded(self._list)
  43. print(self.threadName, "finish!", time.time()-start_time)
  44. def test_ocr():
  45. with open("test_files/开标记录表3_page_0.png", "rb") as f:
  46. base64_data = base64.b64encode(f.read())
  47. # print(base64_data)
  48. url = local_url + ":15011" + '/ocr'
  49. # url = 'http://127.0.0.1:15013/ocr'
  50. r = requests.post(url, data=base64_data, timeout=2000)
  51. # print("test:", r.content.decode("utf-8"))
  52. def test_otr():
  53. with open("test_files/开标记录表3_page_0.png", "rb") as f:
  54. base64_data = base64.b64encode(f.read())
  55. # print(base64_data)
  56. url = local_url + ":15017" + '/otr'
  57. # url = 'http://127.0.0.1:15013/ocr'
  58. r = requests.post(url, data=base64_data, timeout=2000)
  59. # print("test:", r.content.decode("utf-8"))
  60. def test_convert():
  61. # path = "开标记录表3.pdf"
  62. # path = "test_files/开标记录表3_page_0.png"
  63. # path = "test_files/1.docx"
  64. # path = '光明食品(集团)有限公司2017年度经审计的合并及母公司财务报表.pdf'
  65. # path = '光明.pdf'
  66. # path = 'D:/BIDI_DOC/比地_文档/Oracle11g学生成绩管理系统.docx'
  67. # path = "C:\\Users\\Administrator\\Desktop\\1600825332753119.doc"
  68. # path = "temp/complex/8.png"
  69. # path = "合同备案.doc"
  70. # path = "1.png"
  71. # path = "1.pdf"
  72. # path = "(清单)衢州市第二人民医院二期工程电缆采购项目.xls"
  73. # path = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
  74. # "\\00fb43acbc7e11eb836000163e0ae709.png"
  75. # path = "D:\\BIDI_DOC\\比地_文档\\8a949486788ccc6d017969f189301d41.pdf"
  76. # path = "be8a17f2cc1b11eba26800163e0857b6.docx"
  77. # path = "江苏省通州中等专业学校春节物资采购公 告.docx"
  78. # path = "test_files/1.zip"
  79. # path = "C:\\Users\\Administrator\\Desktop\\33f52292cdad11ebb58300163e0857b6.zip"
  80. path = "C:\\Users\\Administrator\\Desktop\\Test_Interface\\1623392355541.zip"
  81. with open(path, "rb") as f:
  82. base64_data = base64.b64encode(f.read())
  83. # print(base64_data)
  84. url = _url + '/convert'
  85. # url = 'http://127.0.0.1:15014/convert'
  86. # headers = {'Content-Type': 'application/json'}
  87. headers = {
  88. 'Connection': 'keep-alive'
  89. }
  90. data = urllib.parse.urlencode({"file": base64_data, "type": path.split(".")[-1]}).encode('utf-8')
  91. req = urllib.request.Request(url, data=data, headers=headers)
  92. with urllib.request.urlopen(req) as response:
  93. _dict = eval(response.read().decode("utf-8"))
  94. result = _dict.get("result")
  95. is_success = _dict.get("is_success")
  96. print("is_success", is_success)
  97. print("len(result)", len(result))
  98. for i in range(len(result)):
  99. print("=================")
  100. print(result[i])
  101. print("-----------------")
  102. # print(len(eval(r.content.decode("utf-8")).get("result")))
  103. # print(r.content)
  104. def test_appendix_downloaded(_list):
  105. # 直接使用下载好的附件
  106. i = 0
  107. # for docid_file in glob("/mnt/html_files/*"):
  108. for docid_file in _list:
  109. if i % 100 == 0:
  110. print("Loop", i)
  111. # print(docid_file)
  112. for file_path in iglob(docid_file + "/*"):
  113. print(file_path)
  114. with open(file_path, "rb") as f:
  115. base64_data = base64.b64encode(f.read())
  116. url = _url + '/convert'
  117. # print(url)
  118. try:
  119. # headers = {
  120. # 'Connection': 'keep-alive'
  121. # }
  122. # data = urllib.parse.urlencode({"file": base64_data, "type": file_path.split(".")[-1]}).encode('utf-8')
  123. # req = urllib.request.Request(url, data=data, headers=headers)
  124. # with urllib.request.urlopen(req, timeout=2000) as response:
  125. # _dict = eval(response.read().decode("utf-8"))
  126. # timeout=2000
  127. r = requests.post(url, data={"file": base64_data,
  128. "type": file_path.split(".")[-1]}, timeout=2000)
  129. _dict = eval(r.content.decode("utf-8"))
  130. print("is_success:", _dict.get("is_success"))
  131. except Exception as e:
  132. print("docid " + str(docid_file) + " time out!", e)
  133. i += 1
def test_convert_maxcompute():
    """Exercise convert() in-process (MaxCompute/ODPS style) and watch memory.

    Loads the OCR and OTR models once, converts each sample file, prints GC
    statistics around every conversion (leak hunting), then drops the models
    and reports the process RSS. Any exception is printed, not re-raised.
    """
    try:
        # Model handles are created once and shared across all files.
        ocr_model = OcrModels().get_model()
        otr_model = OtrModels().get_model()
        path_list = []
        path_suffix = "未命名4.pdf"
        if get_platform() == "Windows":
            path_prefix = "C:\\Users\\Administrator\\Desktop\\Test_ODPS\\"
            # path_prefix = "C:\\Users\\Administrator\\Desktop\\"
            path_list.append(path_prefix + path_suffix)
        else:
            # Non-Windows: sample file is expected in the working directory.
            path_list.append(path_suffix)
        result_list = []
        for path in path_list:
            with open(path, "rb") as f:
                base64_data = base64.b64encode(f.read())
            data = {"file": base64_data, "type": path.split(".")[-1]}
            result_dict = convert(data, ocr_model, otr_model)
            # GC statistics before/after an explicit collection, to spot
            # uncollectable garbage produced by a conversion.
            print("garbage object num:%d" % (len(gc.garbage)))
            _unreachable = gc.collect()
            print("unreachable object num:%d" % (_unreachable))
            print("garbage object num:%d" % (len(gc.garbage)))
            result_list.append(result_dict)
        for result_dict in result_list:
            result = result_dict.get("result_text")
            is_success = result_dict.get("is_success")
            for i in range(len(result)):
                print("=================", "is_success", is_success, i, "in", len(result))
                # (Original had commented-out code here writing each result
                # fragment to result.html for manual inspection.)
        # Drop the model references explicitly so gc can reclaim them.
        del otr_model
        del ocr_model
        gc.collect()
    except Exception as e:
        print(e)
    # Resident set size in MB, printed even on failure.
    usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    print("memory 2", str(usage))
  182. def getMDFFromFile(path):
  183. _length = 0
  184. try:
  185. _md5 = hashlib.md5()
  186. with open(path, "rb") as ff:
  187. while True:
  188. data = ff.read(4096)
  189. if not data:
  190. break
  191. _length += len(data)
  192. _md5.update(data)
  193. return _md5.hexdigest(), _length
  194. except Exception as e:
  195. traceback.print_exc()
  196. return None, _length
  197. def get_base64():
  198. path = "C:\\Users\\Administrator\\Desktop\\Test_ODPS\\1623430252934.doc"
  199. with open(path, "rb") as f:
  200. base64_data = base64.b64encode(f.read())
  201. print("------------")
  202. print(base64_data)
  203. print('------------')
  204. print(getMDFFromFile(path))
def test_init_model():
    """Load the OCR/OTR models inside a worker thread, then force-stop it.

    The thread constructor performs the (expensive) model initialization;
    run() only produces a dummy random result. After join() the thread can
    be killed via a CPython C-API exception injection if it is still alive.
    """
    class MyThread(threading.Thread):
        def __init__(self):
            super(MyThread, self).__init__()
            # Model loading happens in the *calling* thread, at construction.
            self.ocr_model = OcrModels().get_model()
            self.otr_model = OtrModels().get_model()

        def run(self):
            # Dummy payload so join()/get_result() have something to return.
            self.result = random.randint(1, 10)

        def get_result(self):
            return self.result

        def _async_raise(self, tid, exctype):
            """raises the exception, performs cleanup if needed"""
            # CPython-specific: asynchronously raise exctype inside the
            # thread identified by tid via PyThreadState_SetAsyncExc.
            tid = ctypes.c_long(tid)
            if not inspect.isclass(exctype):
                exctype = type(exctype)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
            if res == 0:
                raise ValueError("invalid thread id")
            elif res != 1:
                # More than one thread affected: tid was bogus — undo it.
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")

        def stop_thread(self, tid):
            # Ask the interpreter to raise SystemExit inside the thread.
            self._async_raise(tid, SystemExit)

    class GetModel:
        def __init__(self):
            # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # print("memory 2", str(usage))
            return

        def process(self):
            thread = MyThread()
            thread.start()
            thread.join()
            result = thread.get_result()
            print(result)
            # join() has returned, so is_alive() is normally False; this is
            # a safety net to kill a stuck thread.
            if thread.is_alive():
                thread.stop_thread(thread.ident)
            # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # print("memory 3", str(usage))

    m = GetModel()
    m.process()
# Use the spawn start method to copy the process, otherwise the model hangs.
# multiprocessing.set_start_method('spawn', force=True)
# Module-level slots that TestProcess.child_process_1 fills with the loaded
# models via globals().update(...), so child_process_2 can reach them.
ocr_model = ""
otr_model = ""
class TestProcess:
    """Drives convert() through nested child processes, batching the inputs.

    process() accumulates data until process_num items are queued, then runs
    child_process_1 in a child process; that child loads the models and runs
    each conversion in its own grandchild process (child_process_2).
    """

    def __init__(self):
        super(TestProcess, self).__init__()
        self.process_num = 2     # batch size that triggers a child process
        self.data_list = []      # pending conversion payloads
        self.result_list = []    # results (only visible inside the child)
        self.current_data = ""   # item handed to the current grandchild
        self.result_num = 0

    def child_process_1(self):
        # Initialize the models (stored in module globals so that
        # child_process_2, forked from this process, can reach them).
        globals().update({"ocr_model": OcrModels().get_model()})
        globals().update({"otr_model": OtrModels().get_model()})
        # Convert each batched item in its own grandchild process.
        for data in self.data_list:
            self.current_data = data
            # self.child_process_2()
            p = Process(target=self.child_process_2)
            p.start()
            p.join()
            if p.is_alive():
                print("p.close")
                p.close()
        # Reset the batch.
        self.data_list = []
        # Drop the previous models so gc can reclaim them.
        global ocr_model, otr_model
        del ocr_model
        del otr_model
        gc.collect()

    def child_process_2(self):
        global ocr_model, otr_model
        result = convert(self.current_data, ocr_model, otr_model)
        print("result", result.get("is_success"))
        # NOTE(review): this method runs in a separate process, so these
        # mutations do not propagate back to the parent's instance.
        self.result_list.append(result)
        print("len(self.result_list)======================", len(self.result_list))
        self.result_num += 1

    def process(self, path_list):
        for path in path_list:
            with open(path, "rb") as f:
                base64_data = base64.b64encode(f.read())
            data = {"file": base64_data, "type": path.split(".")[-1]}
            self.data_list.append(data)
            # Run a batch once process_num items have accumulated
            # (the original note said 10; the code uses process_num = 2).
            if len(self.data_list) == self.process_num:
                p = Process(target=self.child_process_1)
                p.start()
                p.join()
                p.close()
                print("init data_list result_list!")
                self.data_list = []
        # NOTE(review): always 0 here — result_num is only incremented in
        # the grandchild process, never in this one.
        print("self.result_num", self.result_num)
  300. def test_convert_process():
  301. t = TestProcess()
  302. t.process(["1623430252934.doc", "1623430252934.doc"])
  303. usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
  304. print("----- memory info start - test_convert_process" + " - " + str(usage) + " GB")
  305. # t.process(["1.docx", "1.docx"])
  306. # usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
  307. # print("----- memory info start - test_convert_process" + " - " + str(usage) + " GB")
# Service endpoints used by the tests above; _url selects the active target
# for the /convert requests (currently the local service on port 15015).
gpu_url = "http://192.168.2.101"
memory_url = "http://47.97.90.190"
local_url = "http://127.0.0.1"
production_url = "http://47.98.57.0"
_url = local_url + ":15015"
if __name__ == '__main__':
    # Manual driver: uncomment the scenario to exercise. Only the in-process
    # MaxCompute-style conversion is active by default.
    # test_convert()
    # test_convert_process()
    test_convert_maxcompute()
    # test_init_model()
    # test_ocr()
    # test_otr()
    # test_appendix_downloaded()
    # get_base64()
    # print(getMDFFromFile("C:\\Users\\Administrator\\Desktop\\Test_ODPS\\1624900794475.docx"))

    # Multi-threaded invocation #####################################
    # Splits the downloaded appendix directories evenly across
    # threads_num myThread_appendix workers (last thread takes the rest).
    # threads_num = 30
    # thread_list = []
    # glob_list = glob("html_files/*")
    # sub_num = int(len(glob_list) / threads_num)
    # print(len(glob_list), sub_num)
    #
    # for i in range(threads_num):
    #     if i == threads_num - 1:
    #         _list = glob_list[i*sub_num:]
    #     else:
    #         _list = glob_list[i*sub_num:(i+1)*sub_num]
    #     print(i*sub_num, len(_list))
    #
    #     thread = myThread_appendix("Thread-"+str(i), _list)
    #     thread_list.append(thread)
    #
    # for thread in thread_list:
    #     thread.start()
    # for thread in thread_list:
    #     thread.join()