table_line_new.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057
  1. import copy
  2. import time
  3. import traceback
  4. import numpy as np
  5. import cv2
  6. import matplotlib.pyplot as plt
  7. from format_convert.utils import log, pil_resize
  def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
      """Detect the ruling lines of the table(s) in an image.

      The image is resized and run through ``model``. The predicted per-pixel
      map is turned into axis-aligned segments, which are then cleaned up:
      short segments dropped, staggered pieces merged, lines without
      crossings removed, the page split into table areas, and each area's
      outer border, corners and inner gaps repaired.

      Args:
          img: input image, resized with ``pil_resize(img, h, w)``.
          model: object whose ``predict`` takes a batch of images; element 0
              of its output is the map consumed by ``points2lines`` (two
              channels per pixel: horizontal score, vertical score).
          size: ``(w, h)`` target size.
          prob: score threshold for a pixel to count as a line pixel.
          is_test: when truthy, intermediate results are displayed via
              ``show`` (blocking GUI windows).

      Returns:
          List of ``[x0, y0, x1, y1]`` segments (all rows, then all columns)
          in the coordinates of the resized image, or ``[]`` when no table
          with both row and column lines was found.
      """
      log("into table_line, prob is " + str(prob))
      # Resize to the model input size.
      w, h = size
      img_new = pil_resize(img, h, w)
      img_show = copy.deepcopy(img_new)
      image_hw = (img_new.shape[0], img_new.shape[1])

      # Predict.
      start_time = time.time()
      pred = model.predict(np.array([img_new]))
      pred = pred[0]
      log("otr model predict time " + str(time.time() - start_time))
      show(pred, title='pred', prob=prob, mode=1, is_test=is_test)

      # Everything below is post-processing. Time it as a whole for the final
      # log; the per-step start_time is reset along the way.
      postprocess_start = time.time()

      # Prediction map -> line segments.
      start_time = time.time()
      line_list = points2lines(pred, False, prob=prob)
      log("points2lines " + str(time.time() - start_time))
      if not line_list:
          return []
      show(line_list, title="points2lines", mode=2, is_test=is_test)

      # Drop short segments.
      start_time = time.time()
      line_list = delete_short_lines(line_list, img_new.shape)
      show(line_list, title="delete_short_lines", mode=2, is_test=is_test)
      log("delete_short_lines " + str(time.time() - start_time))

      # Split into vertical and horizontal segments; anything else is dropped.
      start_time = time.time()
      row_line_list = []
      col_line_list = []
      for line in line_list:
          if line[0] == line[2]:
              col_line_list.append(line)
          elif line[1] == line[3]:
              row_line_list.append(line)
      log("divide rows and cols " + str(time.time() - start_time))
      # Both kinds are needed to form a table.
      if not row_line_list or not col_line_list:
          return []

      # Merge staggered pieces of the same line.
      start_time = time.time()
      row_line_list = merge_line(row_line_list, axis=0)
      col_line_list = merge_line(col_line_list, axis=1)
      show(row_line_list + col_line_list, title="merge_line", mode=2, is_test=is_test)
      log("merge_line " + str(time.time() - start_time))

      # Crossing points.
      cross_points = get_points(row_line_list, col_line_list, image_hw)
      if not cross_points:
          return []
      # Drop lines without crossings. One pass does not remove them all,
      # hence the second pass.
      row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
      cross_points = get_points(row_line_list, col_line_list, image_hw)
      row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
      if not row_line_list or not col_line_list:
          return []
      # NOTE(review): cross_points still reflects the lines as they were
      # before the second delete_single_lines pass -- confirm whether it
      # should be recomputed before splitting into areas.

      # Split the page into table areas.
      start_time = time.time()
      split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
      area_row_line_list, area_col_line_list, area_point_list = get_split_area(
          split_y, row_line_list, col_line_list, cross_points)
      log("get_split_area " + str(time.time() - start_time))

      # Repair the outer border of every area (timed as a whole).
      start_time = time.time()
      need_split_flag = False
      for i, sub_point_list in enumerate(area_point_list):
          sub_row_line_list = area_row_line_list[i]
          sub_col_line_list = area_col_line_list[i]
          new_rows, new_cols, long_rows, long_cols = fix_outline(
              img_new, sub_row_line_list, sub_col_line_list, sub_point_list)
          if not (new_rows or new_cols):
              continue
          # Use the segments that were stretched to reach the new border ...
          if long_rows:
              sub_row_line_list = long_rows
          if long_cols:
              sub_col_line_list = long_cols
          # ... plus the new border segments themselves.
          if new_rows:
              sub_row_line_list += new_rows
          if new_cols:
              sub_col_line_list += new_cols
          need_split_flag = True
          area_row_line_list[i] = sub_row_line_list
          area_col_line_list[i] = sub_col_line_list
      row_line_list = [y for x in area_row_line_list for y in x]
      col_line_list = [y for x in area_col_line_list for y in x]
      if need_split_flag:
          # Borders changed: recompute crossings and areas.
          cross_points = get_points(row_line_list, col_line_list, image_hw)
          split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
          area_row_line_list, area_col_line_list, area_point_list = get_split_area(
              split_y, row_line_list, col_line_list, cross_points)
      show(cross_points, title="get_points", img=img_show, mode=4, is_test=is_test)
      show(split_lines, title="split_lines", img=img_show, mode=3, is_test=is_test)
      show(row_line_list + col_line_list, title="fix_outline", mode=2, is_test=is_test)
      log("fix_outline " + str(time.time() - start_time))

      # Per area: close the corners, re-add the border, fill inner gaps,
      # merge, and trim to standard lines.
      for i, sub_point_list in enumerate(area_point_list):
          sub_row_line_list = area_row_line_list[i]
          sub_col_line_list = area_col_line_list[i]
          # Check the 4 corner crossings of the outline.
          sub_row_line_list, sub_col_line_list = fix_4_points(sub_point_list, sub_row_line_list, sub_col_line_list)
          # Add the 4 border lines once more.
          sub_point_list = get_points(sub_row_line_list, sub_col_line_list, image_hw)
          sub_row_line_list, sub_col_line_list = add_outline(sub_point_list, sub_row_line_list, sub_col_line_list)
          # Fill in missing inner line pieces.
          start_time = time.time()
          sub_row_line_list, sub_col_line_list = fix_inner(sub_row_line_list, sub_col_line_list, sub_point_list)
          log("fix_inner " + str(time.time() - start_time))
          show(sub_row_line_list + sub_col_line_list, title="fix_inner1", mode=2, is_test=is_test)
          # Merge staggered pieces again.
          start_time = time.time()
          sub_row_line_list = merge_line(sub_row_line_list, axis=0)
          sub_col_line_list = merge_line(sub_col_line_list, axis=1)
          log("merge_line " + str(time.time() - start_time))
          show(sub_row_line_list + sub_col_line_list, title="merge_line", mode=2, is_test=is_test)
          # Crossings after the inner repair (only displayed).
          cross_points = get_points(sub_row_line_list, sub_col_line_list, image_hw)
          show(cross_points, title="get_points3", img=img_show, mode=4, is_test=is_test)
          # Remove overhanging ends to get the final lines.
          area_row_line_list[i], area_col_line_list[i] = get_standard_lines(sub_row_line_list, sub_col_line_list)
          show(area_row_line_list[i] + area_col_line_list[i], title="get_standard_lines", mode=2, is_test=is_test)
      row_line_list = [y for x in area_row_line_list for y in x]
      col_line_list = [y for x in area_col_line_list for y in x]
      line_list = row_line_list + col_line_list
      # Show the final lines.
      show(line_list, title="all", img=img_show, mode=5, is_test=is_test)
      log("otr postprocess table_line " + str(time.time() - postprocess_start))
      return line_list
  def show(pred_or_lines, title='', prob=0.2, img=None, mode=1, is_test=0):
      """Display an intermediate pipeline result for debugging.

      Does nothing unless ``is_test`` is truthy.

      Modes:
          1: ``pred_or_lines`` is a 2-D grid of per-pixel scores. A cell is
             coloured (0, 0, 255) when its element 0 exceeds ``prob``, else
             (255, 0, 0) when element 1 exceeds ``prob``, else white; shown
             with matplotlib.
          2: list of ``[x0, y0, x1, y1]`` segments, plotted with matplotlib.
          3: list of ``((x0, y0), (x1, y1))`` segments drawn onto ``img`` in
             colour (0, 0, 255), thickness 2.
          4: list of points drawn onto ``img`` as small (0, 255, 0) circles.
          5: list of ``[x0, y0, x1, y1]`` segments drawn onto ``img`` in
             colour (0, 255, 0), thickness 2.

      Modes 3-5 draw on ``img`` in place, open an OpenCV window and block in
      ``cv2.waitKey(0)``. Any other mode does nothing.
      """
      if not is_test:
          return

      if mode == 1:
          plt.figure()
          plt.title(title)
          grid = []
          for pixel_row in pred_or_lines:
              colour_row = []
              for score in pixel_row:
                  if score[0] > prob:
                      colour_row.append((0, 0, 255))
                  elif score[1] > prob:
                      colour_row.append((255, 0, 0))
                  else:
                      colour_row.append((255, 255, 255))
              grid.append(colour_row)
          plt.imshow(np.array(grid))
          plt.show()
          return

      if mode == 2:
          plt.figure()
          plt.title(title)
          for x0, y0, x1, y1 in pred_or_lines:
              plt.plot([x0, x1], [y0, y1])
          plt.show()
          return

      if mode not in (3, 4, 5):
          return

      # OpenCV modes: draw everything, then show one blocking window.
      for item in pred_or_lines:
          if mode == 3:
              (x0, y0), (x1, y1) = item[0], item[1]
              cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 0, 255), 2)
          elif mode == 4:
              coords = [int(v) for v in item]
              cv2.circle(img, (coords[0], coords[1]), 1, (0, 255, 0), 2)
          else:
              x0, y0, x1, y1 = item
              cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 255, 0), 2)
      cv2.namedWindow(title, cv2.WINDOW_NORMAL)
      cv2.imshow(title, img)
      cv2.waitKey(0)
  def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=8, padding=3, min_len=10,
                   cell_width=13):
      """Turn a two-channel line-score map into horizontal and vertical segments.

      ``pred`` is indexed ``pred[row][col]``; every cell unpacks into
      ``(h_score, v_score)``, and ``pred[..., 0]`` / ``pred[..., 1]`` are
      used. A NumPy array of shape (height, width, 2) is therefore expected.

      Rows are scanned for runs of ``h_score > prob`` and columns for runs of
      ``v_score > prob``. Inside a run the scan jumps ``line_width`` pixels;
      outside a run it moves so that it always ends ``line_width // 2`` past
      the last step-back position. The map is therefore sampled, not visited
      pixel by pixel (``line_width`` should be >= 2, otherwise the scan cannot
      advance). Rows/columns with fewer than ``min_len`` line pixels are
      skipped. The runs are padded, clustered with ``lines_cluster`` and
      collapsed onto their centre lines.

      Args:
          pred: score map as described above.
          sourceP_LB: when True, rows are visited bottom-up in the horizontal
              scan, and vertical-segment y values are recorded as
              ``height - 1 - row``. NOTE(review): horizontal-segment y values
              are not flipped, so the two kinds use different y conventions
              in this mode -- confirm the intent.
          prob: score threshold for a line pixel.
          line_width: scan step, also passed to ``lines_cluster``.
          padding, cell_width: unused; kept for backward compatibility.
          min_len: minimum number of line pixels for a row/column to be
              scanned.

      Returns:
          List of ``[x0, y0, x1, y1]`` segments: horizontal ones
          (``y0 == y1``) followed by vertical ones (``x0 == x1``).
      """
      _time = time.time()
      log("starting points2lines")
      height = len(pred)
      width = len(pred[0])
      _step = line_width
      h_lines = []
      v_lines = []

      # ---- horizontal runs, one row at a time ----
      _sum = list(np.sum(np.array((pred[..., 0] > prob)).astype(int), axis=1))
      for h_index in range(height):
          h_i = height - 1 - h_index if sourceP_LB else h_index
          # Skip rows with too few line pixels. Test the row that is actually
          # scanned below (h_i), which differs from h_index when sourceP_LB.
          if _sum[h_i] < min_len:
              continue
          w_index = -1
          _start = None
          last_back = 0
          while True:
              if w_index >= width:
                  if _start is not None:
                      h_lines.append({"bbox": [_start, h_i, w_index - 1, h_i]})
                  break
              # -1 is only the scan's starting position; reading
              # pred[h_i][-1] would wrap around to the last column.
              if w_index >= 0:
                  _h, _v = pred[h_i][w_index]
                  is_line = _h > prob
              else:
                  is_line = False
              if is_line:
                  if _start is None:
                      _start = w_index
                  w_index += _step
              else:
                  if _start is not None:
                      h_lines.append({"bbox": [_start, h_i, w_index - 1, h_i]})
                      _start = None
                  w_index -= _step // 2
                  if w_index <= last_back:
                      w_index = last_back + _step // 2
                  last_back = w_index
      log("starting points2lines 1")

      # ---- vertical runs, one column at a time ----
      _sum = list(np.sum(np.array((pred[..., 1] > prob)).astype(int), axis=0))
      for w_index in range(width):
          if _sum[w_index] < min_len:
              continue
          h_index = -1
          _start = None
          last_back = 0
          last_h = None
          while True:
              if h_index >= height:
                  if _start is not None:
                      v_lines.append({"bbox": [w_index, _start, w_index, last_h]})
                  break
              h_i = height - 1 - h_index if sourceP_LB else h_index
              # Same wrap-around guard as above for the starting row -1.
              if h_index >= 0:
                  _h, _v = pred[h_index][w_index]
                  is_line = _v > prob
              else:
                  is_line = False
              if is_line:
                  if _start is None:
                      _start = h_i
                  h_index += _step
              else:
                  if _start is not None:
                      # last_h is the previous sample, the last one on the line.
                      v_lines.append({"bbox": [w_index, _start, w_index, last_h]})
                      _start = None
                  h_index -= _step // 2
                  if h_index <= last_back:
                      h_index = last_back + _step // 2
                  last_back = h_index
              last_h = h_i
      log("starting points2lines 2")

      # Widen every run by 2 px at both ends (clamped at 0) and put it on its
      # centre line before clustering.
      for _line in h_lines:
          b = _line["bbox"]
          mid_y = (b[1] + b[3]) / 2
          _line["bbox"] = [max(b[0] - 2, 0), mid_y, b[2] + 2, mid_y]
      for _line in v_lines:
          b = _line["bbox"]
          mid_x = (b[0] + b[2]) / 2
          _line["bbox"] = [mid_x, max(b[1] - 2, 0), mid_x, b[3] + 2]
      h_lines = lines_cluster(h_lines, line_width=line_width)
      v_lines = lines_cluster(v_lines, line_width=line_width)

      # Collapse each cluster's box onto its centre line, widened by 1 px.
      list_line = []
      for _line in h_lines:
          b = _line["bbox"]
          mid_y = (b[1] + b[3]) / 2
          list_line.append([max(b[0] - 1, 0), mid_y, b[2] + 1, mid_y])
      for _line in v_lines:
          b = _line["bbox"]
          mid_x = (b[0] + b[2]) / 2
          list_line.append([mid_x, max(b[1] - 1, 0), mid_x, b[3] + 1])
      log("points2lines cost %.2fs" % (time.time() - _time))
      return list_line
  def lines_cluster(list_lines, line_width):
      """Merge line entries whose padded boxes touch, until nothing changes.

      Each entry is a dict with a ``"bbox"`` ``[x0, y0, x1, y1]``. In one pass,
      every entry is compared against the clusters built so far, newest
      first. Both boxes are padded by ``line_width // 2`` (lower bounds
      clamped at 0). On the first one with ``getIOU(...) > 0``, that
      cluster's bbox is replaced by the box enclosing both, and the entry is
      absorbed. Otherwise the entry starts a new cluster. Passes repeat until
      a pass yields the same number of clusters as the previous one.

      The input dicts are reused as the clusters, and their ``"bbox"`` values
      are overwritten in place.

      Returns:
          The list of cluster dicts from the last pass.
      """
      pad = line_width // 2

      def padded(box):
          return [max(0, box[0] - pad), max(0, box[1] - pad), box[2] + pad, box[3] + pad]

      initial_count = len(list_lines)
      previous_count = 0
      while True:
          clusters = []
          for entry in list_lines:
              box = entry["bbox"]
              grown = padded(box)
              for cluster in reversed(clusters):
                  other = cluster["bbox"]
                  if getIOU(grown, padded(other)) > 0:
                      cluster["bbox"] = [
                          min(box[0], box[2], other[0], other[2]),
                          min(box[1], box[3], other[1], other[3]),
                          max(box[0], box[2], other[0], other[2]),
                          max(box[1], box[3], other[1], other[3]),
                      ]
                      break
              else:
                  clusters.append(entry)
          current_count = len(clusters)
          if current_count == previous_count:
              break
          previous_count = current_count
          list_lines = clusters
      log("cluster lines from %d to %d" % (initial_count, len(list_lines)))
      return clusters
  def getIOU(bbox0, bbox1):
      """Overlap score of two axis-aligned boxes ``[x0, y0, x1, y1]``.

      Along each axis, ``width``/``height`` is the length of the combined
      span minus the sum of the two box lengths. That is minus the overlap
      length when the boxes overlap, zero when they just touch, and the gap
      when they are apart.

      Returns:
          0 when the boxes are separated on either axis. Otherwise
          ``overlap_area / smaller_box_area + 0.1``; the +0.1 keeps merely
          touching boxes strictly positive. When the smaller box has zero
          area (a degenerate, line-like box) the ratio is undefined, so 0.1
          is returned instead of raising ``ZeroDivisionError``.
      """
      width = abs(max(bbox0[2], bbox1[2]) - min(bbox0[0], bbox1[0])) - (
          abs(bbox0[2] - bbox0[0]) + abs(bbox1[2] - bbox1[0]))
      height = abs(max(bbox0[3], bbox1[3]) - min(bbox0[1], bbox1[1])) - (
          abs(bbox0[3] - bbox0[1]) + abs(bbox1[3] - bbox1[1]))
      if width > 0 or height > 0:
          return 0
      min_area = min(abs((bbox0[2] - bbox0[0]) * (bbox0[3] - bbox0[1])),
                     abs((bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])))
      if min_area == 0:
          return 0.1
      return abs(width * height / min_area) + 0.1
  def delete_short_lines(list_lines, image_shape, scale=100):
      """Drop segments that are too short to be table lines.

      A segment with ``x0 == x1`` is treated as vertical and measured by
      ``|y1 - y0|``; any other segment is measured by ``|x1 - x0|``. The
      minimum lengths are ``max(5, int(image_shape[i] / scale))``.

      NOTE(review): horizontal segments use ``image_shape[0]`` and vertical
      ones ``image_shape[1]``. For a NumPy ``(rows, cols, ...)`` shape this
      pairing looks swapped -- confirm before changing, thresholds may be
      tuned.

      Returns:
          New list with the segments that are long enough, in input order.
      """
      min_horizontal = max(5, int(image_shape[0] / scale))
      min_vertical = max(5, int(image_shape[1] / scale))

      def long_enough(seg):
          if seg[0] == seg[2]:
              return abs(seg[3] - seg[1]) >= min_vertical
          return abs(seg[2] - seg[0]) >= min_horizontal

      return [seg for seg in list_lines if long_enough(seg)]
  def delete_single_lines(row_line_list, col_line_list, point_list):
      """Keep only lines that at least two crossing points line up with.

      A column is kept when at least 2 points share its x (``col[0]``), and a
      row when at least 2 points share its y (``row[1]``). Whether the point
      actually falls within the segment's extent is not checked.

      Returns:
          ``(kept_rows, kept_cols)`` as new lists, in input order.
      """
      min_point_cnt = 2

      def enough_points(coord, axis):
          hits = 0
          for p in point_list:
              if p[axis] == coord:
                  hits += 1
                  if hits >= min_point_cnt:
                      return True
          return False

      kept_cols = [col for col in col_line_list if enough_points(col[0], 0)]
      kept_rows = [row for row in row_line_list if enough_points(row[1], 1)]
      return kept_rows, kept_cols
  def merge_line(lines, axis, threshold=5):
      """
      Merge a line that the model predicted as several staggered pieces.

      Segments whose cross-axis coordinate is within ``threshold`` of a
      group's first segment, and whose start or end falls inside the group's
      current extent, join that group. Every group then becomes one segment
      at the (int-truncated) mean cross-axis coordinate, spanning from the
      smallest start to the largest end. Segments equal to an
      already-grouped one are skipped, so exact duplicates collapse.

      :param lines: list of [x0, y0, x1, y1] segments; sorted in place
      :param axis: 0 for horizontal lines, 1 for vertical lines
      :param threshold: max pixel offset between two pieces of one line
      :return: list of merged segments
      """
      # Sort by start coordinate. Every segment not yet grouped then starts
      # at or after the current group's first segment.
      lines.sort(key=lambda x: (x[axis], x[1 - axis]))
      # Track grouped segments by value (same semantics as list equality),
      # giving O(1) membership instead of a list scan.
      used = set()
      groups = []
      for line1 in lines:
          key1 = tuple(line1)
          if key1 in used:
              continue
          used.add(key1)
          group = [line1]
          # Running extent of the group along the line direction.
          lo = line1[axis]
          hi = line1[axis + 2]
          for line2 in lines:
              key2 = tuple(line2)
              if key2 in used:
                  continue
              if line1[1 - axis] - threshold <= line2[1 - axis] <= line1[1 - axis] + threshold:
                  # Do the two pieces overlap along the line direction?
                  if lo <= line2[axis] <= hi or lo <= line2[axis + 2] <= hi:
                      group.append(line2)
                      used.add(key2)
                      lo = min(lo, line2[axis])
                      hi = max(hi, line2[axis + 2])
          groups.append(group)

      result_lines = []
      for group in groups:
          axis_average = int(sum(seg[1 - axis] for seg in group) / len(group))
          axis_start = min(seg[axis] for seg in group)
          axis_end = max(seg[axis + 2] for seg in group)
          if axis:
              result_lines.append([axis_average, axis_start, axis_average, axis_end])
          else:
              result_lines.append([axis_start, axis_average, axis_end, axis_average])
      return result_lines
  def get_points(row_lines, col_lines, image_size):
      """Return the pixels where row and column segments cross.

      Rows (stretched 5 px at both x ends) and columns (stretched 5 px at
      both y ends) are rasterised, 1 px thick, into two blank ``uint8`` masks
      of shape ``image_size``. The crossings are the pixels set in both
      masks.

      Returns:
          List of ``(x, y)`` tuples sorted by x, then y.
      """
      pad = 5
      row_mask = np.zeros(image_size, np.uint8)
      col_mask = np.zeros(image_size, np.uint8)
      for row in row_lines:
          start = (int(row[0] - pad), int(row[1]))
          end = (int(row[2] + pad), int(row[3]))
          cv2.line(row_mask, start, end, (255, 255, 255), 1)
      for col in col_lines:
          start = (int(col[0]), int(col[1] - pad))
          end = (int(col[2]), int(col[3] + pad))
          cv2.line(col_mask, start, end, (255, 255, 255), 1)
      ys, xs = np.nonzero(np.bitwise_and(row_mask, col_mask))
      return sorted(zip(xs, ys))
  def fix_outline(image, row_line_list, col_line_list, point_list, scale=25):
      """Add a missing outer border line when the outer lines overshoot.

      If both outermost columns reach past the top (bottom) row by at least
      a tolerance, a new row is added at the tip of whichever column reaches
      further. Column ends within ``box_height`` of it are then pulled onto
      it. Rows reaching past the leftmost/rightmost column are handled the
      same way, adding columns and pulling row ends within ``box_width``.

      ``box_height`` is the most frequent gap (>= 10 px) between consecutive
      rows, ties going to the smaller gap. ``box_width`` is the gap between
      the first two columns. The tolerance is ``y_min_len`` / ``x_min_len``
      (``max(10, image.shape[i] // scale)``), or 2/3 of the box size when the
      box is smaller than twice that.

      Args:
          image: array; only ``image.shape[0]`` / ``shape[1]`` are used.
          row_line_list, col_line_list: ``[x0, y0, x1, y1]`` segments. Both
              are sorted in place and their elements may be modified in place.
          point_list: unused.
          scale: divisor for the image-size based minimum lengths.

      Returns:
          ``(new_rows, new_cols, longer_rows, longer_cols)``: the added
          border segments plus shallow copies of the (possibly stretched)
          input lists. With one row or one column or fewer, returns
          ``([], [], row_line_list, col_line_list)`` unchanged.
      """
      log("into fix_outline")
      # NOTE(review): x_min_len comes from shape[0] and y_min_len from
      # shape[1]; for a (rows, cols) image this looks swapped -- confirm.
      x_min_len = max(10, int(image.shape[0] / scale))
      y_min_len = max(10, int(image.shape[1] / scale))
      if len(row_line_list) <= 1 or len(col_line_list) <= 1:
          return [], [], row_line_list, col_line_list

      # Outermost predicted lines; they may stick out past the table.
      row_line_list.sort(key=lambda x: (x[1], x[0]))
      col_line_list.sort(key=lambda x: x[0])
      up_line, bottom_line = row_line_list[0], row_line_list[-1]
      left_line, right_line = col_line_list[0], col_line_list[-1]

      # Typical cell height: most frequent row gap of at least 10 px.
      gap_counts = {}
      for upper, lower in zip(row_line_list, row_line_list[1:]):
          gap = abs(int(upper[3] - lower[3]))
          if gap >= 10:
              gap_counts[gap] = gap_counts.get(gap, 0) + 1
      if gap_counts:
          box_height = max(gap_counts, key=lambda g: (gap_counts[g], -g))
      else:
          box_height = y_min_len
      box_width = abs(col_line_list[1][2] - col_line_list[0][2])

      # How far the outer lines must overshoot before a border is added.
      fix_h_len = y_min_len if box_height >= 2 * y_min_len else box_height * 2 / 3
      fix_w_len = x_min_len if box_width >= 2 * x_min_len else box_width * 2 / 3

      def snap(lines, idx, target, reach, pick):
          # Pull coordinate ``idx`` of every segment within ``reach`` of the
          # new border onto it (in place). ``pick`` is min or max.
          for seg in lines:
              if abs(target - seg[idx]) <= reach:
                  seg[idx] = pick([target, seg[idx]])

      new_row_lines = []
      new_col_lines = []

      # Top: both outer columns start well above the first row.
      over_left = up_line[1] - left_line[1]
      over_right = up_line[1] - right_line[1]
      if over_left >= fix_h_len and over_right >= fix_h_len:
          border_y = left_line[1] if over_left >= over_right else right_line[1]
          new_row_lines.append([left_line[0], border_y, right_line[0], border_y])
          snap(col_line_list, 1, border_y, box_height, min)

      # Bottom: both outer columns end well below the last row.
      over_left = left_line[3] - bottom_line[3]
      over_right = right_line[3] - bottom_line[3]
      if over_left >= fix_h_len and over_right >= fix_h_len:
          border_y = left_line[3] if over_left >= over_right else right_line[3]
          new_row_lines.append([left_line[2], border_y, right_line[2], border_y])
          snap(col_line_list, 3, border_y, box_height, max)

      # Left: top and bottom rows start well left of the first column.
      over_up = left_line[0] - up_line[0]
      over_bottom = left_line[0] - bottom_line[0]
      if over_up >= fix_w_len and over_bottom >= fix_w_len:
          border_x = up_line[0] if over_up >= over_bottom else bottom_line[0]
          new_col_lines.append([border_x, up_line[1], border_x, bottom_line[1]])
          snap(row_line_list, 0, border_x, box_width, min)

      # Right: top and bottom rows end well right of the last column.
      over_up = up_line[2] - right_line[2]
      over_bottom = bottom_line[2] - right_line[2]
      if over_up >= fix_w_len and over_bottom >= fix_w_len:
          border_x = up_line[2] if over_up >= over_bottom else bottom_line[2]
          new_col_lines.append([border_x, up_line[3], border_x, bottom_line[3]])
          snap(row_line_list, 2, border_x, box_width, max)

      return new_row_lines, new_col_lines, list(row_line_list), list(col_line_list)
  def fix_inner(row_line_list, col_line_list, point_list):
      """Extend inner segments that stop short of the next perpendicular line.

      Columns are processed first, using the rows as targets. Rows are
      processed next, using the already-updated columns as targets. For
      each end of a segment, the nearest perpendicular line is found that
      spans that end's coordinate and lies outside the segment's extent. The
      end is extended to that line when the dangling piece (closest crossing
      point on the segment -> end point) is between 1/3 and 1x the gap
      (that crossing point -> target line). At most one end per segment is
      extended.

      Args:
          row_line_list: horizontal ``[x0, y0, x1, y1]`` segments.
          col_line_list: vertical ``[x0, y0, x1, y1]`` segments.
          point_list: crossing points ``(x, y)``.

      Returns:
          ``(row_line_list, col_line_list)``. Extended segments are replaced
          in place in the input lists. If anything raises, the traceback is
          printed and deep copies of the inputs as they were on entry are
          returned (best effort).
      """
      def fix(fix_lines, assist_lines, split_points, axis):
          # Return [line, new_end_point] pairs. ``axis`` is the coordinate
          # index along fix_lines: 1 for columns (ends compared on y), 0 for
          # rows (ends compared on x). assist_lines are the perpendicular
          # targets.
          new_line_point_list = []
          for line1 in fix_lines:
              min_assist_line = [[], []]
              min_distance = [float("inf"), float("inf")]
              if_find = [0, 0]
              # Crossing points lying on line1 (within 2 px). line1's own end
              # points are not necessarily crossing points.
              fix_line_points = []
              for point in split_points:
                  if abs(point[1 - axis] - line1[1 - axis]) <= 2:
                      if line1[axis] <= point[axis] <= line1[axis + 2]:
                          fix_line_points.append(point)
              # For each end point, the nearest assist line that spans it and
              # does not cross line1 itself.
              line1_point = [line1[:2], line1[2:]]
              for i in range(2):
                  point = line1_point[i]
                  for line2 in assist_lines:
                      if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
                          # NOTE(review): the lower bound compares line1 with
                          # point (always true for an axis-aligned line1); by
                          # symmetry with the branch below it was probably
                          # meant to be line2[1 - axis] -- confirm.
                          if line1[1 - axis] <= point[1 - axis] <= line2[1 - axis + 2]:
                              # The end point already sits on an assist line;
                              # stop scanning (if_find is not read later).
                              if_find[i] = 1
                              break
                      else:
                          if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1 - axis] <= point[1 - axis] <= \
                                  line2[1 - axis + 2]:
                              if line1[axis] <= line2[axis] <= line1[axis + 2]:
                                  # Crosses line1 itself: not a target.
                                  continue
                              # Record the distance from this end point, the
                              # same quantity compared above. Measuring from
                              # line1's start makes the end-point search
                              # prefer farther lines.
                              min_distance[i] = abs(point[axis] - line2[axis])
                              min_assist_line[i] = line2
              if len(min_assist_line[0]) == 0 and len(min_assist_line[1]) == 0:
                  continue
              # For each target, the crossing point on line1 closest to it.
              min_distance = [float("inf"), float("inf")]
              min_col_point = [[], []]
              for i in range(2):
                  if min_assist_line[i]:
                      for point in fix_line_points:
                          if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]:
                              min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
                              min_col_point[i] = point
              if len(min_col_point[0]) == 0 and len(min_col_point[1]) == 0:
                  continue
              # Extend only if the dangling piece (crossing point -> end point)
              # is between 1/3 and 1x the gap (crossing point -> target).
              if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]:
                  # Both ends picked the same target: extend on its side.
                  if min_assist_line[0][axis] < line1_point[0][axis]:
                      bbox_len = abs(min_col_point[0][axis] - min_assist_line[0][axis])
                      line_distance = abs(min_col_point[0][axis] - line1_point[0][axis])
                      if bbox_len / 3 <= line_distance <= bbox_len:
                          if axis == 1:
                              add_point = (line1_point[0][1 - axis], min_assist_line[0][axis])
                          else:
                              add_point = (min_assist_line[0][axis], line1_point[0][1 - axis])
                          new_line_point_list.append([line1, add_point])
                  elif min_assist_line[1][axis] > line1_point[1][axis]:
                      bbox_len = abs(min_col_point[1][axis] - min_assist_line[1][axis])
                      line_distance = abs(min_col_point[1][axis] - line1_point[1][axis])
                      if bbox_len / 3 <= line_distance <= bbox_len:
                          if axis == 1:
                              add_point = (line1_point[1][1 - axis], min_assist_line[1][axis])
                          else:
                              add_point = (min_assist_line[1][axis], line1_point[1][1 - axis])
                          new_line_point_list.append([line1, add_point])
              else:
                  for i in range(2):
                      if min_col_point[i]:
                          bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis])
                          line_distance = abs(min_col_point[i][axis] - line1_point[i][axis])
                          if bbox_len / 3 <= line_distance <= bbox_len:
                              if axis == 1:
                                  add_point = (line1_point[i][1 - axis], min_assist_line[i][axis])
                              else:
                                  add_point = (min_assist_line[i][axis], line1_point[i][1 - axis])
                              new_line_point_list.append([line1, add_point])
          return new_line_point_list

      row_line_list_copy = copy.deepcopy(row_line_list)
      col_line_list_copy = copy.deepcopy(col_line_list)
      try:
          # Columns first, extended to reach rows ...
          new_point_list = fix(col_line_list, row_line_list, point_list, axis=1)
          for line, new_point in new_point_list:
              # A line already replaced (other end extended) is skipped.
              if line in col_line_list:
                  index = col_line_list.index(line)
                  point1 = line[:2]
                  point2 = line[2:]
                  if new_point[1] >= point2[1]:
                      col_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
                  elif new_point[1] <= point1[1]:
                      col_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
          # ... then rows, extended to reach the updated columns.
          new_point_list = fix(row_line_list, col_line_list, point_list, axis=0)
          for line, new_point in new_point_list:
              if line in row_line_list:
                  index = row_line_list.index(line)
                  point1 = line[:2]
                  point2 = line[2:]
                  if new_point[0] >= point2[0]:
                      row_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
                  elif new_point[0] <= point1[0]:
                      row_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
          return row_line_list, col_line_list
      except Exception:
          # Best effort: on failure fall back to the lines as given.
          traceback.print_exc()
          return row_line_list_copy, col_line_list_copy
  721. def fix_4_points(cross_points, row_line_list, col_line_list):
  722. if not (len(row_line_list) >= 2 and len(col_line_list) >= 2):
  723. return row_line_list, col_line_list
  724. cross_points.sort(key=lambda x: (x[0], x[1]))
  725. left_up_p = cross_points[0]
  726. right_down_p = cross_points[-1]
  727. cross_points.sort(key=lambda x: (-x[0], x[1]))
  728. right_up_p = cross_points[0]
  729. left_down_p = cross_points[-1]
  730. # print('left_up_p', left_up_p, 'left_down_p', left_down_p)
  731. # print('right_up_p', right_up_p, 'right_down_p', right_down_p)
  732. min_x = min(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
  733. max_x = max(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
  734. min_y = min(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
  735. max_y = max(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
  736. if left_up_p[0] != min_x or left_up_p[1] != min_y:
  737. log('轮廓左上角交点有问题')
  738. row_line_list.append([min_x, min_y, max_x, min_y])
  739. col_line_list.append([min_x, min_y, min_x, max_y])
  740. if left_down_p[0] != min_x or left_down_p[1] != max_y:
  741. log('轮廓左下角交点有问题')
  742. row_line_list.append([min_x, max_y, max_x, max_y])
  743. col_line_list.append([min_x, min_y, min_x, max_y])
  744. if right_up_p[0] != max_x or right_up_p[1] != min_y:
  745. log('轮廓右上角交点有问题')
  746. row_line_list.append([min_x, max_y, max_x, max_y])
  747. col_line_list.append([max_x, min_y, max_x, max_y])
  748. if right_down_p[0] != max_x or right_down_p[1] != max_y:
  749. log('轮廓右下角交点有问题')
  750. row_line_list.append([min_x, max_y, max_x, max_y])
  751. col_line_list.append([max_x, min_y, max_x, max_y])
  752. return row_line_list, col_line_list
  753. def get_split_line(points, col_lines, image_np, threshold=5):
  754. # 线贴着边缘无法得到split_y,导致无法分区
  755. for _col in col_lines:
  756. if _col[3] >= image_np.shape[0] - 5:
  757. _col[3] = image_np.shape[0] - 6
  758. if _col[1] <= 0 + 5:
  759. _col[1] = 6
  760. # print("get_split_line", image_np.shape)
  761. points.sort(key=lambda x: (x[1], x[0]))
  762. # 遍历y坐标,并判断y坐标与上一个y坐标是否存在连接线
  763. i = 0
  764. split_line_y = []
  765. for point in points:
  766. # 从已分开的线下面开始判断
  767. if split_line_y:
  768. if point[1] <= split_line_y[-1] + threshold:
  769. last_y = point[1]
  770. continue
  771. if last_y <= split_line_y[-1] + threshold:
  772. last_y = point[1]
  773. continue
  774. if i == 0:
  775. last_y = point[1]
  776. i += 1
  777. continue
  778. current_line = (last_y, point[1])
  779. split_flag = 1
  780. for col in col_lines:
  781. # 只要找到一条col包含就不是分割线
  782. if current_line[0] >= col[1] - 3 and current_line[1] <= col[3] + 3:
  783. split_flag = 0
  784. break
  785. if split_flag:
  786. split_line_y.append(current_line[0] + 5)
  787. split_line_y.append(current_line[1] - 5)
  788. last_y = point[1]
  789. # 加上收尾分割线
  790. points.sort(key=lambda x: (x[1], x[0]))
  791. y_min = points[0][1]
  792. y_max = points[-1][1]
  793. if y_min - threshold < 0:
  794. split_line_y.append(0)
  795. else:
  796. split_line_y.append(y_min - threshold)
  797. if y_max + threshold > image_np.shape[0]:
  798. split_line_y.append(image_np.shape[0])
  799. else:
  800. split_line_y.append(y_max + threshold)
  801. split_line_y = list(set(split_line_y))
  802. # 剔除两条相隔太近分割线
  803. temp_split_line_y = []
  804. split_line_y.sort(key=lambda x: x)
  805. last_y = -20
  806. for y in split_line_y:
  807. if y - last_y >= 20:
  808. temp_split_line_y.append(y)
  809. last_y = y
  810. split_line_y = temp_split_line_y
  811. # 生成分割线
  812. split_line = []
  813. for y in split_line_y:
  814. split_line.append([(0, y), (image_np.shape[1], y)])
  815. split_line.append([(0, 0), (image_np.shape[1], 0)])
  816. split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])])
  817. split_line.sort(key=lambda x: x[0][1])
  818. return split_line, split_line_y
  819. def get_split_area(split_y, row_line_list, col_line_list, cross_points):
  820. # 分割线纵坐标
  821. if len(split_y) < 2:
  822. return [], [], []
  823. split_y.sort(key=lambda x: x)
  824. # new_split_y = []
  825. # for i in range(1, len(split_y), 2):
  826. # new_split_y.append(int((split_y[i] + split_y[i - 1]) / 2))
  827. area_row_line_list = []
  828. area_col_line_list = []
  829. area_point_list = []
  830. for i in range(1, len(split_y)):
  831. y = split_y[i]
  832. last_y = split_y[i - 1]
  833. split_row = []
  834. for row in row_line_list:
  835. if last_y <= row[3] <= y:
  836. split_row.append(row)
  837. split_col = []
  838. for col in col_line_list:
  839. if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]:
  840. split_col.append(col)
  841. split_point = []
  842. for point in cross_points:
  843. if last_y <= point[1] <= y:
  844. split_point.append(point)
  845. # 满足条件才能形成表格区域
  846. if len(split_row) >= 2 and len(split_col) >= 2 and len(split_point) >= 4:
  847. # print('len(split_row), len(split_col), len(split_point)', len(split_row), len(split_col), len(split_point))
  848. area_row_line_list.append(split_row)
  849. area_col_line_list.append(split_col)
  850. area_point_list.append(split_point)
  851. return area_row_line_list, area_col_line_list, area_point_list
  852. def get_standard_lines(row_line_list, col_line_list):
  853. new_row_line_list = []
  854. for row in row_line_list:
  855. w1 = row[0]
  856. w2 = row[2]
  857. # 横线的两个顶点分别找到最近的竖线
  858. min_distance = [10000, 10000]
  859. min_dis_w = [None, None]
  860. for col in col_line_list:
  861. if abs(col[0] - w1) < min_distance[0]:
  862. min_distance[0] = abs(col[0] - w1)
  863. min_dis_w[0] = col[0]
  864. if abs(col[0] - w2) < min_distance[1]:
  865. min_distance[1] = abs(col[0] - w2)
  866. min_dis_w[1] = col[0]
  867. if min_dis_w[0] is not None:
  868. row[0] = min_dis_w[0]
  869. if min_dis_w[1] is not None:
  870. row[2] = min_dis_w[1]
  871. new_row_line_list.append(row)
  872. new_col_line_list = []
  873. for col in col_line_list:
  874. h1 = col[1]
  875. h2 = col[3]
  876. # 横线的两个顶点分别找到最近的竖线
  877. min_distance = [10000, 10000]
  878. min_dis_w = [None, None]
  879. for row in row_line_list:
  880. if abs(row[1] - h1) < min_distance[0]:
  881. min_distance[0] = abs(row[1] - h1)
  882. min_dis_w[0] = row[1]
  883. if abs(row[1] - h2) < min_distance[1]:
  884. min_distance[1] = abs(row[1] - h2)
  885. min_dis_w[1] = row[1]
  886. if min_dis_w[0] is not None:
  887. col[1] = min_dis_w[0]
  888. if min_dis_w[1] is not None:
  889. col[3] = min_dis_w[1]
  890. new_col_line_list.append(col)
  891. # all_line_list = []
  892. # # 横线竖线两个维度
  893. # for i in range(2):
  894. # axis = i
  895. # cross_points.sort(key=lambda x: (x[axis], x[1-axis]))
  896. # current_axis = cross_points[0][axis]
  897. # points = []
  898. # line_list = []
  899. # for p in cross_points:
  900. # if p[axis] == current_axis:
  901. # points.append(p)
  902. # else:
  903. # if points:
  904. # line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
  905. # points = [p]
  906. # current_axis = p[axis]
  907. # if points:
  908. # line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
  909. # all_line_list.append(line_list)
  910. # new_col_line_list, new_row_line_list = all_line_list
  911. return new_col_line_list, new_row_line_list
  912. def add_outline(cross_points, row_line_list, col_line_list):
  913. cross_points.sort(key=lambda x: (x[0], x[1]))
  914. left_up_p = cross_points[0]
  915. right_down_p = cross_points[-1]
  916. row_line_list.append([left_up_p[0], left_up_p[1], right_down_p[0], left_up_p[1]])
  917. row_line_list.append([left_up_p[0], right_down_p[1], right_down_p[0], right_down_p[1]])
  918. col_line_list.append([left_up_p[0], left_up_p[1], left_up_p[0], right_down_p[1]])
  919. col_line_list.append([right_down_p[0], left_up_p[1], right_down_p[0], right_down_p[1]])
  920. return row_line_list, col_line_list