# encoding=utf8
import copy
import inspect
import io
import logging
import os
import sys
import time
import requests
import numpy as np
from PIL import Image
sys.path.append(os.path.dirname(__file__) + "/../")
from pdfminer.layout import LTLine
import traceback
import cv2
from isr.pre_process import count_red_pixel
from format_convert.utils import judge_error_code, add_div, LineTable, get_table_html, get_logger, log, \
    memory_decorator, pil_resize
from format_convert.convert_need_interface import from_otr_interface, from_ocr_interface, from_gpu_interface_redis, \
    from_idc_interface, from_isr_interface
from format_convert.table_correct import get_rotated_image


def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, use_ocr=True):
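    """Full image pipeline: deskew, seal removal (isr), table line detection
    (otr), OCR, then table reconstruction.

    Returns a list of _Table / _Sentence objects on success, an error-code
    list (e.g. [-1], [-8]) on failure, or [] in some docx / unreadable-image
    cases.
    """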
    from format_convert.convert_tree import _Table, _Sentence

    def get_cluster(t_list, b_list, axis):
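        # Group bboxes whose coordinate bbox[axis][1] lies within `margin`
        # pixels of an existing cluster, then snap every member of a cluster
        # to the cluster's mean value so they align on the same line.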
        zip_list = list(zip(t_list, b_list))
        if len(zip_list) == 0:
            return t_list, b_list
        if len(zip_list[0]) > 0:
            zip_list.sort(key=lambda x: x[1][axis][1])
        cluster_list = []
        margin = 5
        for text, bbox in zip_list:
            _find = 0
            for cluster in cluster_list:
                if abs(cluster[1] - bbox[axis][1]) <= margin:
                    cluster[0].append([text, bbox])
                    cluster[1] = bbox[axis][1]
                    _find = 1
                    break
            if not _find:
                cluster_list.append([[[text, bbox]], bbox[axis][1]])
        new_text_list = []
        new_bbox_list = []
        for cluster in cluster_list:
            # print("=============convert_image")
            # print("cluster_list", cluster)
            center_y = 0
            for text, bbox in cluster[0]:
                center_y += bbox[axis][1]
            center_y = int(center_y / len(cluster[0]))
            for text, bbox in cluster[0]:
                bbox[axis][1] = center_y
                new_text_list.append(text)
                new_bbox_list.append(bbox)
            # print("cluster_list", cluster)
        return new_text_list, new_bbox_list

    def merge_textbox(textbox_list, in_objs):
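        # Merge horizontally adjacent text boxes that share (almost) the same
        # top and bottom coordinates and are not already assigned to a table,
        # concatenating their text left-to-right and widening the bbox.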
        delete_obj = []
        threshold = 5
        textbox_list.sort(key=lambda x: x.bbox[0])
        for k in range(len(textbox_list)):
            tb1 = textbox_list[k]
            if tb1 not in in_objs and tb1 not in delete_obj:
                for m in range(k+1, len(textbox_list)):
                    tb2 = textbox_list[m]
                    if tb2 in in_objs:
                        continue
                    if abs(tb1.bbox[1]-tb2.bbox[1]) <= threshold \
                            and abs(tb1.bbox[3]-tb2.bbox[3]) <= threshold:
                        if tb1.bbox[0] <= tb2.bbox[0]:
                            tb1.text = tb1.text + tb2.text
                        else:
                            tb1.text = tb2.text + tb1.text
                        tb1.bbox[0] = min(tb1.bbox[0], tb2.bbox[0])
                        tb1.bbox[2] = max(tb1.bbox[2], tb2.bbox[2])
                        delete_obj.append(tb2)
        for _obj in delete_obj:
            if _obj in textbox_list:
                textbox_list.remove(_obj)
        return textbox_list

    log("into image_preprocess")
    try:
        # Deskew the image and write it back to the original path
        # print("image_process", image_path)
        g_r_i = get_rotated_image(image_np, image_path)
        if judge_error_code(g_r_i):
            if is_from_docx:
                return []
            else:
                return g_r_i
        image_np = cv2.imread(image_path)
        image_np_copy = copy.deepcopy(image_np)
        if image_np is None:
            return []
        # if image_np is None:
        #     return []
        #
        # # Skew correction with the idc model
        # image_resize = pil_resize(image_np, 640, 640)
        # image_resize_path = image_path.split(".")[0] + "_resize_idc." + image_path.split(".")[-1]
        # cv2.imwrite(image_resize_path, image_resize)
        #
        # with open(image_resize_path, "rb") as f:
        #     image_bytes = f.read()
        # angle = from_idc_interface(image_bytes)
        # if judge_error_code(angle):
        #     if is_from_docx:
        #         return []
        #     else:
        #         return angle
        # # Rotate by the detected angle
        # image_pil = Image.fromarray(image_np)
        # image_np = np.array(image_pil.rotate(angle, expand=1))
        # # Write the corrected image to a new path
        # idc_path = image_path.split(".")[0] + "_idc." + image_path.split(".")[-1]
        # cv2.imwrite(idc_path, image_np)

        # Remove seals with the isr model
        _isr_time = time.time()
        if count_red_pixel(image_np):
            # Only run the model if the image has enough red pixels
            with open(image_path, "rb") as f:
                image_bytes = f.read()
            image_np = from_isr_interface(image_bytes)
            if judge_error_code(image_np):
                if is_from_docx:
                    return []
                else:
                    return image_np
            # [1] means no seal was detected; keep the original image
            if isinstance(image_np, list) and image_np == [1]:
                log("no seals detected!")
                image_np = image_np_copy
            else:
                isr_path = image_path.split(".")[0] + "_isr." + image_path.split(".")[-1]
                cv2.imwrite(isr_path, image_np)
        log("isr total time " + str(time.time()-_isr_time))

        # Detect table lines with the otr model; resize the image to the model
        # input size and write it to a separate path
        best_h, best_w = get_best_predict_size(image_np)
        # image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
        image_resize = pil_resize(image_np, best_h, best_w)
        image_resize_path = image_path.split(".")[0] + "_resize_otr." + image_path.split(".")[-1]
        cv2.imwrite(image_resize_path, image_resize)

        # Call the otr model interface
        with open(image_resize_path, "rb") as f:
            image_bytes = f.read()
        list_line = from_otr_interface(image_bytes, is_from_pdf)
        if judge_error_code(list_line):
            return list_line

        # # Preprocess
        # if is_from_pdf:
        #     prob = 0.2
        # else:
        #     prob = 0.5
        # with open(image_resize_path, "rb") as f:
        #     image_bytes = f.read()
        # img_new, inputs = table_preprocess(image_bytes, prob)
        # if type(img_new) is list and judge_error_code(img_new):
        #     return img_new
        # log("img_new.shape " + str(img_new.shape))
        #
        # # Call the model inference interface
        # _dict = {"inputs": inputs, "md5": _global.get("md5")}
        # result = from_gpu_interface(_dict, model_type="otr", predictor_type="")
        # if judge_error_code(result):
        #     logging.error("from_gpu_interface failed! " + str(result))
        #     raise requests.exceptions.RequestException
        #
        # pred = result.get("preds")
        # gpu_time = result.get("gpu_time")
        # log("otr model predict time " + str(gpu_time))
        #
        # # # Decompress the numpy array
        # # decompressed_array = io.BytesIO()
        # # decompressed_array.write(pred)
        # # decompressed_array.seek(0)
        # # pred = np.load(decompressed_array, allow_pickle=True)['arr_0']
        # # log("inputs.shape" + str(pred.shape))
        #
        # # Run inference through GPU shared memory
        # _dict = {"inputs": inputs, "md5": _global.get("md5")}
        # result = from_gpu_share_memory(_dict, model_type="otr", predictor_type="")
        # if judge_error_code(result):
        #     logging.error("from_gpu_interface failed! " + str(result))
        #     raise requests.exceptions.RequestException
        #
        # pred = result.get("preds")
        # gpu_time = result.get("gpu_time")
        # log("otr model predict time " + str(gpu_time))
        #
        # # Postprocess
        # list_line = table_postprocess(img_new, pred, prob)
        # log("len(list_line) " + str(len(list_line)))
        # if judge_error_code(list_line):
        #     return list_line

        # Scale the otr bboxes from the resized image back to the original size
        start_time = time.time()
        ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
        for i in range(len(list_line)):
            point = list_line[i]
            list_line[i] = [int(point[0]*ratio[1]), int(point[1]*ratio[0]),
                            int(point[2]*ratio[1]), int(point[3]*ratio[0])]
        log("otr resize bbox recover " + str(time.time()-start_time))

        # Resize oversized images before OCR to avoid running out of memory
        start_time = time.time()
        threshold = 3000
        if image_np.shape[0] >= threshold or image_np.shape[1] >= threshold:
            best_h, best_w = get_best_predict_size2(image_np, threshold)
            # image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
            image_resize = pil_resize(image_np, best_h, best_w)
            image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
            cv2.imwrite(image_resize_path, image_resize)
        log("ocr resize before " + str(time.time()-start_time))

        # Call the ocr model interface
        with open(image_resize_path, "rb") as f:
            image_bytes = f.read()
        text_list, bbox_list = from_ocr_interface(image_bytes, is_table=True)
        if judge_error_code(text_list):
            return text_list

        # # PaddleOCR handles preprocessing, model inference and postprocessing internally
        # paddle_ocr = PaddleOCR(use_angle_cls=True, lang="ch")
        # results = paddle_ocr.ocr(image_resize, det=True, rec=True, cls=True)
        # # Loop over the recognition results of each image
        # text_list = []
        # bbox_list = []
        # for line in results:
        #     # print("ocr_interface line", line)
        #     text_list.append(line[-1][0])
        #     bbox_list.append(line[0])
        # if len(text_list) == 0:
        #     return []

        # Scale the ocr bboxes back to the original image size
        ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
        for i in range(len(bbox_list)):
            point = bbox_list[i]
            bbox_list[i] = [[int(point[0][0]*ratio[1]), int(point[0][1]*ratio[0])],
                            [int(point[1][0]*ratio[1]), int(point[1][1]*ratio[0])],
                            [int(point[2][0]*ratio[1]), int(point[2][1]*ratio[0])],
                            [int(point[3][0]*ratio[1]), int(point[3][1]*ratio[0])]]
        # for _a, _b in zip(text_list, bbox_list):
        #     print("bbox1", _a, _b)

        # Build tables with the existing table-recognition method
        try:
            from format_convert.convert_tree import TableLine
            list_lines = []
            for line in list_line:
                list_lines.append(LTLine(1, (line[0], line[1]), (line[2], line[3])))
            from format_convert.convert_tree import TextBox
            list_text_boxes = []
            for i in range(len(bbox_list)):
                bbox = bbox_list[i]
                b_text = text_list[i]
                list_text_boxes.append(TextBox([bbox[0][0], bbox[0][1],
                                                bbox[2][0], bbox[2][1]], b_text))
            # for _textbox in list_text_boxes:
            #     print("==", _textbox.get_text())
            lt = LineTable()
            tables, obj_in_table, _ = lt.recognize_table(list_text_boxes, list_lines, False)

            # Merge text boxes on the same line
            list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)

            obj_list = []
            for table in tables:
                obj_list.append(_Table(table["table"], table["bbox"]))
            for text_box in list_text_boxes:
                if text_box not in obj_in_table:
                    obj_list.append(_Sentence(text_box.get_text(), text_box.bbox))
            return obj_list
        except:
            traceback.print_exc()
            return [-8]
    except Exception as e:
        log("image_preprocess error")
        traceback.print_exc()
        return [-1]


@memory_decorator
def picture2text(path, html=False):
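    """Read an image from `path`, run image_process on it, and return the
    extracted content wrapped in a list (optionally as HTML via add_div),
    or an error-code list such as [-1] / [-3]."""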
    log("into picture2text")
    try:
        # Detect tables in the image
        img = cv2.imread(path)
        if img is None:
            return [-3]
        text = image_process(img, path)
        if judge_error_code(text):
            return text
        if html:
            text = add_div(text)
        return [text]
    except Exception as e:
        log("picture2text error!")
        traceback.print_exc()
        return [-1]


def get_best_predict_size(image_np, times=64):
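    """Pick the multiples of `times` (capped at 1300) closest to the image's
    height and width, e.g. a 1000x1500 image maps to (1024, 1280)."""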
    sizes = []
    for i in range(1, 100):
        if i*times <= 1300:
            sizes.append(i*times)
    sizes.sort(reverse=True)
    min_len = 10000
    best_height = sizes[0]
    for height in sizes:
        if abs(image_np.shape[0] - height) < min_len:
            min_len = abs(image_np.shape[0] - height)
            best_height = height
    min_len = 10000
    best_width = sizes[0]
    for width in sizes:
        if abs(image_np.shape[1] - width) < min_len:
            min_len = abs(image_np.shape[1] - width)
            best_width = width
    return best_height, best_width


def get_best_predict_size2(image_np, threshold=3000):
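    """Scale the image dimensions so the longer side equals `threshold`,
    e.g. a 4000x3000 image with threshold=3000 becomes (3000, 2250)."""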
    h, w = image_np.shape[:2]
    scale = threshold / max(h, w)
    h = int(h * scale)
    w = int(w * scale)
    return h, w


class ImageConvert:
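    """Wrap a single image file as a _Document with one _Page holding one
    _Image, and expose get_html() for the converted result."""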
    def __init__(self, path, unique_type_dir):
        from format_convert.convert_tree import _Document
        self._doc = _Document(path)
        self.path = path
        self.unique_type_dir = unique_type_dir

    def init_package(self):
        # Initialize the package: read the raw image bytes
        try:
            with open(self.path, "rb") as f:
                self.image = f.read()
        except:
            log("cannot open image!")
            traceback.print_exc()
            self._doc.error_code = [-3]

    def convert(self):
        from format_convert.convert_tree import _Page, _Image
        self.init_package()
        if self._doc.error_code is not None:
            return
        _page = _Page(None, 0)
        _image = _Image(self.image, self.path)
        _page.add_child(_image)
        self._doc.add_child(_page)

    def get_html(self):
        try:
            self.convert()
        except:
            traceback.print_exc()
            self._doc.error_code = [-1]
        if self._doc.error_code is not None:
            return self._doc.error_code
        return self._doc.get_html()
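

# A minimal usage sketch, not part of the original module: it assumes the
# format_convert services (ocr/otr/isr interfaces) are reachable and that
# "example.png" is a hypothetical input path.
if __name__ == "__main__":
    result = picture2text("example.png", html=True)
    if judge_error_code(result):
        print("picture2text failed with error code:", result)
    else:
        print(result[0])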