# table_line_new.py
  1. import copy
  2. import time
  3. import traceback
  4. import numpy as np
  5. import cv2
  6. import matplotlib.pyplot as plt
  7. from format_convert.utils import log, pil_resize, memory_decorator
def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
    """Predict table lines on an image and post-process them into clean segments.

    Pipeline: resize -> model prediction -> points2lines -> remove short
    lines -> split into rows/columns -> merge offset lines -> intersection
    points -> remove lines without enough intersections -> split into table
    areas -> repair outline -> repair inner lines -> standardise lines.

    :param img: input image; only passed to ``pil_resize``.
    :param model: object exposing ``predict``; it is called with a batch of one
        image and only element 0 of the result is used.
    :param size: unpacked as ``w, h`` and forwarded as ``pil_resize(img, h, w)``.
    :param prob: probability threshold forwarded to ``points2lines`` and ``show``.
    :param is_test: truthy enables the debug displays done by ``show``.
    :return: list of ``[x0, y0, x1, y1]`` lines (rows followed by columns), or
        ``[]`` as soon as one stage yields no lines / no intersections.
        No rescaling back to the original image size happens in this function.
    """
    log("into table_line, prob is " + str(prob))

    # Resize
    w, h = size
    img_new = pil_resize(img, h, w)
    # Separate copy used as drawing canvas by the cv2-based debug displays.
    img_show = copy.deepcopy(img_new)

    # Predict
    start_time = time.time()
    pred = model.predict(np.array([img_new]))
    pred = pred[0]
    log("otr model predict time " + str(time.time() - start_time))

    # Debug display of the raw prediction
    show(pred, title='pred', prob=prob, mode=1, is_test=is_test)

    # Build line segments from the predicted points
    start_time = time.time()
    line_list = points2lines(pred, False, prob=prob)
    log("points2lines " + str(time.time() - start_time))
    if not line_list:
        return []
    show(line_list, title="points2lines", mode=2, is_test=is_test)

    # Remove short lines
    start_time = time.time()
    line_list = delete_short_lines(line_list, img_new.shape)
    show(line_list, title="delete_short_lines", mode=2, is_test=is_test)
    log("delete_short_lines " + str(time.time() - start_time))

    # Split into horizontal (row) and vertical (col) lines; a line that is
    # neither exactly vertical nor exactly horizontal is dropped here.
    start_time = time.time()
    row_line_list = []
    col_line_list = []
    for line in line_list:
        if line[0] == line[2]:
            col_line_list.append(line)
        elif line[1] == line[3]:
            row_line_list.append(line)
    log("divide rows and cols " + str(time.time() - start_time))

    # Both kinds of line must exist, otherwise give up
    if not row_line_list or not col_line_list:
        return []

    # Merge lines that were predicted slightly offset from each other
    start_time = time.time()
    row_line_list = merge_line(row_line_list, axis=0)
    col_line_list = merge_line(col_line_list, axis=1)
    show(row_line_list + col_line_list, title="merge_line", mode=2, is_test=is_test)
    log("merge_line " + str(time.time() - start_time))

    # Compute intersection points
    cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
    if not cross_points:
        return []

    # Delete lines without intersections; it has to be done twice to be thorough
    row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
    cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
    row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
    if not row_line_list or not col_line_list:
        return []

    # Separator lines between multiple tables; get one area per table
    start_time = time.time()
    split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
    area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)
    log("get_split_area " + str(time.time() - start_time))

    # Loop over the areas
    need_split_flag = False
    for i in range(len(area_point_list)):
        sub_row_line_list = area_row_line_list[i]
        sub_col_line_list = area_col_line_list[i]
        sub_point_list = area_point_list[i]

        # Repair the outline
        start_time = time.time()
        new_rows, new_cols, long_rows, long_cols = fix_outline(img_new,
                                                               sub_row_line_list,
                                                               sub_col_line_list,
                                                               sub_point_list)

        # If any line was added
        if new_rows or new_cols:
            # Use the lines that were extended to reach the added lines
            if long_rows:
                sub_row_line_list = long_rows
            if long_cols:
                sub_col_line_list = long_cols
            # Append the newly added lines
            if new_rows:
                sub_row_line_list += new_rows
            if new_cols:
                sub_col_line_list += new_cols

            # Areas have to be recomputed once all areas are processed.
            need_split_flag = True
            area_row_line_list[i] = sub_row_line_list
            area_col_line_list[i] = sub_col_line_list

    # Flatten the per-area lists back into global row / column lists
    row_line_list = [y for x in area_row_line_list for y in x]
    col_line_list = [y for x in area_col_line_list for y in x]

    if need_split_flag:
        # Recompute intersections and areas after the outline repair
        cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
        split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
        area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)

    show(cross_points, title="get_points", img=img_show, mode=4, is_test=is_test)
    show(split_lines, title="split_lines", img=img_show, mode=3, is_test=is_test)
    show(row_line_list + col_line_list, title="fix_outline", mode=2, is_test=is_test)
    log("fix_outline " + str(time.time() - start_time))

    # Loop over the areas
    for i in range(len(area_point_list)):
        sub_row_line_list = area_row_line_list[i]
        sub_col_line_list = area_col_line_list[i]
        sub_point_list = area_point_list[i]

        # Validate the 4 corner intersections of the outline
        sub_row_line_list, sub_col_line_list = fix_4_points(sub_point_list, sub_row_line_list, sub_col_line_list)

        # Add the four border lines once more
        sub_point_list = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
        sub_row_line_list, sub_col_line_list = add_outline(sub_point_list, sub_row_line_list, sub_col_line_list)

        # Repair missing inner lines
        start_time = time.time()
        sub_row_line_list, sub_col_line_list = fix_inner(sub_row_line_list, sub_col_line_list, sub_point_list)
        log("fix_inner " + str(time.time() - start_time))
        show(sub_row_line_list + sub_col_line_list, title="fix_inner1", mode=2, is_test=is_test)

        # Merge offset lines
        start_time = time.time()
        sub_row_line_list = merge_line(sub_row_line_list, axis=0)
        sub_col_line_list = merge_line(sub_col_line_list, axis=1)
        log("merge_line " + str(time.time() - start_time))
        show(sub_row_line_list + sub_col_line_list, title="merge_line", mode=2, is_test=is_test)

        # Recompute intersections after the inner repair (only used for the
        # debug display below)
        start_time = time.time()
        cross_points = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
        show(cross_points, title="get_points3", img=img_show, mode=4, is_test=is_test)

        # Remove line overhangs and obtain the standardised lines
        area_row_line_list[i], area_col_line_list[i] = get_standard_lines(sub_row_line_list, sub_col_line_list)
        show(area_row_line_list[i] + area_col_line_list[i], title="get_standard_lines", mode=2, is_test=is_test)

    row_line_list = [y for x in area_row_line_list for y in x]
    col_line_list = [y for x in area_col_line_list for y in x]
    line_list = row_line_list + col_line_list

    # Display the processed lines
    show(line_list, title="all", img=img_show, mode=5, is_test=is_test)
    # NOTE(review): start_time was last reset inside the loop above, so this
    # does not measure the whole post-process - confirm that is intended.
    log("otr postprocess table_line " + str(time.time() - start_time))
    return line_list
@memory_decorator
def table_line_pdf_post_process(line_list, page_w, page_h, is_test=0):
    """Post-process lines that already exist as coordinates (PDF page) into table lines.

    Same stages as ``table_line`` after prediction, except that there is no
    short-line / single-line removal and no ``get_standard_lines`` step.

    :param line_list: list of ``[x0, y0, x1, y1]`` lines; every entry is
        replaced in place by its ``int``-converted copy.
    :param page_w: page width; a blank canvas of ``int(page_w + 1)`` columns is built.
    :param page_h: page height; a blank canvas of ``int(page_h + 1)`` rows is built.
    :param is_test: truthy enables the debug displays and prints oblique lines.
    :return: list of lines (rows followed by columns), or ``[]`` when rows,
        columns or intersection points are missing.
    """
    log('into table_line_pdf_post_process')

    # Convert all coordinates to int (mutates the caller's list).
    for i, line in enumerate(line_list):
        line_list[i] = [int(x) for x in line]

    # White canvas with the page size; the helpers below take an image / its shape.
    img_new = np.full([int(page_h+1), int(page_w+1), 3], 255, dtype=np.uint8)
    img_show = copy.deepcopy(img_new)

    show(line_list, title="table_line_pdf start", mode=2, is_test=is_test)

    # Split into horizontal (row) and vertical (col) lines
    start_time = time.time()
    row_line_list = []
    col_line_list = []
    for line in line_list:
        # There may be oblique lines; they are dropped (printed in test mode)
        if line[0] == line[2]:
            col_line_list.append(line)
        elif line[1] == line[3]:
            row_line_list.append(line)
        else:
            if is_test:
                print(line)
    show(row_line_list + col_line_list, title="divide", mode=2, is_test=is_test)

    # Both kinds of line must exist, otherwise give up
    if not row_line_list or not col_line_list:
        return []

    # Merge lines
    row_line_list = merge_line(row_line_list, axis=0)
    col_line_list = merge_line(col_line_list, axis=1)
    show(row_line_list + col_line_list, title="merge", mode=2, is_test=is_test)

    # Compute intersection points
    cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
    if not cross_points:
        return []
    show(cross_points, title="get_points", img=img_show, mode=4, is_test=is_test)

    # Separator lines between multiple tables; get one area per table
    start_time = time.time()
    split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
    area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)
    show(split_lines, title="split_lines", img=img_show, mode=3, is_test=is_test)

    # Loop over the areas
    need_split_flag = False
    for i in range(len(area_point_list)):
        sub_row_line_list = area_row_line_list[i]
        sub_col_line_list = area_col_line_list[i]
        sub_point_list = area_point_list[i]

        # Repair the outline
        start_time = time.time()
        new_rows, new_cols, long_rows, long_cols = fix_outline(img_new,
                                                               sub_row_line_list,
                                                               sub_col_line_list,
                                                               sub_point_list)

        # If any line was added
        if new_rows or new_cols:
            # Use the lines that were extended to reach the added lines
            if long_rows:
                sub_row_line_list = long_rows
            if long_cols:
                sub_col_line_list = long_cols
            # Append the newly added lines
            if new_rows:
                sub_row_line_list += new_rows
            if new_cols:
                sub_col_line_list += new_cols

            # Areas have to be recomputed once all areas are processed.
            need_split_flag = True
            area_row_line_list[i] = sub_row_line_list
            area_col_line_list[i] = sub_col_line_list

    # Flatten the per-area lists back into global row / column lists
    row_line_list = [y for x in area_row_line_list for y in x]
    col_line_list = [y for x in area_col_line_list for y in x]

    if need_split_flag:
        # Recompute intersections and areas after the outline repair
        cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
        split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
        area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)

    # Loop over the areas
    for i in range(len(area_point_list)):
        sub_row_line_list = area_row_line_list[i]
        sub_col_line_list = area_col_line_list[i]
        sub_point_list = area_point_list[i]

        # Validate the 4 corner intersections of the outline
        sub_row_line_list, sub_col_line_list = fix_4_points(sub_point_list, sub_row_line_list, sub_col_line_list)

        # Add the four border lines once more
        sub_point_list = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
        sub_row_line_list, sub_col_line_list = add_outline(sub_point_list, sub_row_line_list, sub_col_line_list)

        # Repair missing inner lines
        start_time = time.time()
        sub_row_line_list, sub_col_line_list = fix_inner(sub_row_line_list, sub_col_line_list, sub_point_list)
        show(sub_row_line_list + sub_col_line_list, title="fix_inner1", mode=2, is_test=is_test)

        # Recompute intersections after the inner repair
        start_time = time.time()
        cross_points = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
        show(cross_points, title="get_points3", img=img_show, mode=4, is_test=is_test)
        area_point_list[i] = cross_points

        # Merge lines
        area_row_line_list[i] = merge_line(sub_row_line_list, axis=0)
        area_col_line_list[i] = merge_line(sub_col_line_list, axis=1)

    row_line_list = [y for x in area_row_line_list for y in x]
    col_line_list = [y for x in area_col_line_list for y in x]
    line_list = row_line_list + col_line_list

    # Display the processed lines
    show(line_list, title="all", img=img_show, mode=5, is_test=is_test)
    # NOTE(review): start_time was last reset inside the loop above, so this
    # is not the total cost of the function - confirm that is intended.
    log("table_line_pdf cost: " + str(time.time() - start_time))
    return line_list
  254. def show(pred_or_lines, title='', prob=0.2, img=None, mode=1, is_test=0):
  255. if not is_test:
  256. return
  257. if mode == 1:
  258. plt.figure()
  259. plt.title(title)
  260. _array = []
  261. for _h in range(len(pred_or_lines)):
  262. _line = []
  263. for _w in range(len(pred_or_lines[_h])):
  264. _prob = pred_or_lines[_h][_w]
  265. if _prob[0] > prob:
  266. _line.append((0, 0, 255))
  267. elif _prob[1] > prob:
  268. _line.append((255, 0, 0))
  269. else:
  270. _line.append((255, 255, 255))
  271. _array.append(_line)
  272. # plt.axis('off')
  273. plt.imshow(np.array(_array))
  274. plt.show()
  275. elif mode == 2:
  276. plt.figure()
  277. plt.title(title)
  278. for _line in pred_or_lines:
  279. x0, y0, x1, y1 = _line
  280. plt.plot([x0, x1], [y0, y1])
  281. plt.show()
  282. elif mode == 3:
  283. for _line in pred_or_lines:
  284. x0, y0 = _line[0]
  285. x1, y1 = _line[1]
  286. cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 0, 255), 2)
  287. cv2.namedWindow(title, cv2.WINDOW_NORMAL)
  288. cv2.imshow(title, img)
  289. cv2.waitKey(0)
  290. elif mode == 4:
  291. for point in pred_or_lines:
  292. point = [int(x) for x in point]
  293. cv2.circle(img, (point[0], point[1]), 1, (0, 255, 0), 2)
  294. cv2.namedWindow(title, cv2.WINDOW_NORMAL)
  295. cv2.imshow(title, img)
  296. cv2.waitKey(0)
  297. elif mode == 5:
  298. for _line in pred_or_lines:
  299. x0, y0, x1, y1 = _line
  300. cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 255, 0), 2)
  301. cv2.namedWindow(title, cv2.WINDOW_NORMAL)
  302. cv2.imshow(title, img)
  303. cv2.waitKey(0)
def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=8, padding=3, min_len=10,
                 cell_width=13):
    """Convert a per-pixel two-channel probability map into line segments.

    Rows are scanned with channel 0 to build horizontal segments and columns
    are scanned with channel 1 to build vertical segments. Each scan jumps
    forward by ``line_width`` while on a line and backs off by half a step when
    it falls off one. The raw segments are then padded, clustered with
    ``lines_cluster`` and collapsed onto their centre line.

    :param pred: prediction map indexed as ``pred[row][col] -> (h, v)``; must
        support ``pred[..., 0]`` (numpy-style indexing).
    :param sourceP_LB: when truthy, row indices are flipped
        (``height - 1 - index``) - presumably "source origin is bottom-left";
        TODO confirm.
    :param prob: probability threshold above which a pixel counts as a line pixel.
    :param line_width: scan step, also forwarded to ``lines_cluster``.
    :param padding: unused in this function.
    :param min_len: rows / columns with fewer than this many line pixels are skipped.
    :param cell_width: unused in this function.
    :return: list of ``[x0, y0, x1, y1]``; horizontal lines first, then vertical.
    """
    _time = time.time()
    log("starting points2lines")
    height = len(pred)
    width = len(pred[0])

    # Number of pixels above prob per row, for channel 0.
    _sum = list(np.sum(np.array((pred[..., 0] > prob)).astype(int), axis=1))

    h_index = -1
    h_lines = []
    v_lines = []
    _step = line_width

    # ---- horizontal scan: one pass per row ----
    while 1:
        h_index += 1
        if h_index >= height:
            break
        w_index = -1
        if sourceP_LB:
            h_i = height - 1 - h_index
        else:
            h_i = h_index
        _start = None
        # NOTE(review): the row filter uses h_index while the pixels are read
        # at h_i; these differ when sourceP_LB is truthy - confirm intended.
        if _sum[h_index] < min_len:
            continue
        last_back = 0
        while 1:
            if w_index >= width:
                # Reached the end of the row: close a segment that is still open.
                if _start is not None:
                    _end = w_index - 1
                    _bbox = [_start, h_i, _end, h_i]
                    _dict = {"bbox": _bbox}
                    h_lines.append(_dict)
                    _start = None
                break
            # NOTE(review): on the first iteration w_index is -1, which indexes
            # the LAST column - confirm intended.
            _h, _v = pred[h_i][w_index]
            if _h > prob:
                # On a line: open a segment if needed and jump a full step.
                if _start is None:
                    _start = w_index
                w_index += _step
            else:
                # Off the line: close the open segment, then back off half a
                # step, but never to or before the previous back-off position
                # (this guarantees forward progress).
                if _start is not None:
                    _end = w_index - 1
                    _bbox = [_start, h_i, _end, h_i]
                    _dict = {"bbox": _bbox}
                    h_lines.append(_dict)
                    _start = None
                w_index -= _step // 2
                if w_index <= last_back:
                    w_index = last_back + _step // 2
                last_back = w_index

    log("starting points2lines 1")

    # ---- vertical scan: one pass per column ----
    w_index = -1
    # Number of pixels above prob per column, for channel 1.
    _sum = list(np.sum(np.array((pred[..., 1] > prob)).astype(int), axis=0))
    _step = line_width
    while 1:
        w_index += 1
        if w_index >= width:
            break
        if _sum[w_index] < min_len:
            continue
        h_index = -1
        _start = None
        last_back = 0
        # list_test / list_lineprob are only filled, never read in this function.
        list_test = []
        list_lineprob = []
        while 1:
            if h_index >= height:
                # Reached the end of the column: close a segment that is still open.
                if _start is not None:
                    _end = last_h
                    _bbox = [w_index, _start, w_index, _end]
                    _dict = {"bbox": _bbox}
                    v_lines.append(_dict)
                    _start = None
                    list_test.append(_dict)
                break
            if sourceP_LB:
                h_i = height - 1 - h_index
            else:
                h_i = h_index
            # NOTE(review): the pixel is read at h_index while the segment
            # coordinates use h_i; also h_index starts at -1 (last row) -
            # confirm intended.
            _h, _v = pred[h_index][w_index]
            list_lineprob.append((h_index, _v))
            if _v > prob:
                if _start is None:
                    _start = h_i
                h_index += _step
            else:
                if _start is not None:
                    # The segment ends at the row visited in the previous iteration.
                    _end = last_h
                    _bbox = [w_index, _start, w_index, _end]
                    _dict = {"bbox": _bbox}
                    v_lines.append(_dict)
                    _start = None
                    list_test.append(_dict)
                h_index -= _step // 2
                if h_index <= last_back:
                    h_index = last_back + _step // 2
                last_back = h_index
            # Remember the row visited in this iteration.
            last_h = h_i

    log("starting points2lines 2")

    # Lengthen every raw segment by 2 px at both ends (start clamped at 0) and
    # put it on its centre line before clustering.
    for _line in h_lines:
        _bbox = _line["bbox"]
        _bbox = [max(_bbox[0] - 2, 0), (_bbox[1] + _bbox[3]) / 2, _bbox[2] + 2, (_bbox[1] + _bbox[3]) / 2]
        _line["bbox"] = _bbox
    for _line in v_lines:
        _bbox = _line["bbox"]
        _bbox = [(_bbox[0] + _bbox[2]) / 2, max(_bbox[1] - 2, 0), (_bbox[0] + _bbox[2]) / 2, _bbox[3] + 2]
        _line["bbox"] = _bbox

    h_lines = lines_cluster(h_lines, line_width=line_width)
    v_lines = lines_cluster(v_lines, line_width=line_width)

    # Collapse each clustered box onto its centre line, with 1 px extra length
    # at both ends (start clamped at 0).
    list_line = []
    for _line in h_lines:
        _bbox = _line["bbox"]
        _bbox = [max(_bbox[0] - 1, 0), (_bbox[1] + _bbox[3]) / 2, _bbox[2] + 1, (_bbox[1] + _bbox[3]) / 2]
        list_line.append(_bbox)
    for _line in v_lines:
        _bbox = _line["bbox"]
        _bbox = [(_bbox[0] + _bbox[2]) / 2, max(_bbox[1] - 1, 0), (_bbox[0] + _bbox[2]) / 2, _bbox[3] + 1]
        list_line.append(_bbox)

    log("points2lines cost %.2fs" % (time.time() - _time))
    return list_line
  434. def lines_cluster(list_lines, line_width):
  435. after_len = 0
  436. prelength = len(list_lines)
  437. append_width = line_width // 2
  438. while 1:
  439. c_lines = []
  440. first_len = after_len
  441. for _line in list_lines:
  442. bbox = _line["bbox"]
  443. _find = False
  444. for c_l_i in range(len(c_lines)):
  445. c_l = c_lines[len(c_lines) - c_l_i - 1]
  446. bbox1 = c_l["bbox"]
  447. bboxa = [max(0, bbox[0] - append_width), max(0, bbox[1] - append_width), bbox[2] + append_width,
  448. bbox[3] + append_width]
  449. bboxb = [max(0, bbox1[0] - append_width), max(0, bbox1[1] - append_width), bbox1[2] + append_width,
  450. bbox1[3] + append_width]
  451. _iou = getIOU(bboxa, bboxb)
  452. if _iou > 0:
  453. new_bbox = [min(bbox[0], bbox[2], bbox1[0], bbox1[2]), min(bbox[1], bbox[3], bbox1[1], bbox1[3]),
  454. max(bbox[0], bbox[2], bbox1[0], bbox1[2]), max(bbox[1], bbox[3], bbox1[1], bbox1[3])]
  455. _find = True
  456. c_l["bbox"] = new_bbox
  457. break
  458. if not _find:
  459. c_lines.append(_line)
  460. after_len = len(c_lines)
  461. if first_len == after_len:
  462. break
  463. list_lines = c_lines
  464. log("cluster lines from %d to %d" % (prelength, len(list_lines)))
  465. return c_lines
  466. def getIOU(bbox0, bbox1):
  467. width = abs(max(bbox0[2], bbox1[2]) - min(bbox0[0], bbox1[0])) - (
  468. abs(bbox0[2] - bbox0[0]) + abs(bbox1[2] - bbox1[0]))
  469. height = abs(max(bbox0[3], bbox1[3]) - min(bbox0[1], bbox1[1])) - (
  470. abs(bbox0[3] - bbox0[1]) + abs(bbox1[3] - bbox1[1]))
  471. if width <= 0 and height <= 0:
  472. iou = abs(width * height / min(abs((bbox0[2] - bbox0[0]) * (bbox0[3] - bbox0[1])),
  473. abs((bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))))
  474. # print("getIOU", iou)
  475. return iou + 0.1
  476. return 0
  477. def delete_short_lines(list_lines, image_shape, scale=100):
  478. # 排除太短的线
  479. x_min_len = max(5, int(image_shape[0] / scale))
  480. y_min_len = max(5, int(image_shape[1] / scale))
  481. new_list_lines = []
  482. for line in list_lines:
  483. if line[0] == line[2]:
  484. if abs(line[3] - line[1]) >= y_min_len:
  485. # print("y_min_len", abs(line[3] - line[1]), y_min_len)
  486. new_list_lines.append(line)
  487. else:
  488. if abs(line[2] - line[0]) >= x_min_len:
  489. # print("x_min_len", abs(line[2] - line[0]), x_min_len)
  490. new_list_lines.append(line)
  491. return new_list_lines
  492. def delete_single_lines(row_line_list, col_line_list, point_list):
  493. new_col_line_list = []
  494. min_point_cnt = 2
  495. for line in col_line_list:
  496. p_cnt = 0
  497. for p in point_list:
  498. # if line[0] == p[0] and line[1] <= p[1] <= line[3]:
  499. if line[0] == p[0]:
  500. p_cnt += 1
  501. if p_cnt >= min_point_cnt:
  502. new_col_line_list.append(line)
  503. break
  504. new_row_line_list = []
  505. for line in row_line_list:
  506. p_cnt = 0
  507. for p in point_list:
  508. # if line[1] == p[1] and line[0] <= p[0] <= line[2]:
  509. if line[1] == p[1]:
  510. p_cnt += 1
  511. if p_cnt >= min_point_cnt:
  512. new_row_line_list.append(line)
  513. break
  514. return new_row_line_list, new_col_line_list
  515. @memory_decorator
  516. def merge_line(lines, axis, threshold=5):
  517. """
  518. 解决模型预测一条直线错开成多条直线,合并成一条直线
  519. :param lines: 线条列表
  520. :param axis: 0:横线 1:竖线
  521. :param threshold: 两条线间像素差阈值
  522. :return: 合并后的线条列表
  523. """
  524. # 任意一条line获取该合并的line,横线往下找,竖线往右找
  525. start_time = time.time()
  526. lines.sort(key=lambda x: (x[axis], x[1 - axis]))
  527. merged_lines = []
  528. used_lines = []
  529. for line1 in lines:
  530. if line1 in used_lines:
  531. continue
  532. merged_line = [line1]
  533. used_lines.append(line1)
  534. for line2 in lines:
  535. if line2 in used_lines:
  536. continue
  537. if line1[1 - axis] - threshold <= line2[1 - axis] <= line1[1 - axis] + threshold:
  538. # 计算基准长度
  539. min_axis = 10000
  540. max_axis = 0
  541. for line3 in merged_line:
  542. if line3[axis] < min_axis:
  543. min_axis = line3[axis]
  544. if line3[axis + 2] > max_axis:
  545. max_axis = line3[axis + 2]
  546. # 判断两条线有无交集
  547. if min_axis <= line2[axis] <= max_axis \
  548. or min_axis <= line2[axis + 2] <= max_axis:
  549. merged_line.append(line2)
  550. used_lines.append(line2)
  551. if merged_line:
  552. merged_lines.append(merged_line)
  553. # 合并line
  554. result_lines = []
  555. for merged_line in merged_lines:
  556. # 获取line宽的平均值
  557. axis_average = 0
  558. for line in merged_line:
  559. axis_average += line[1 - axis]
  560. axis_average = int(axis_average / len(merged_line))
  561. # 获取最长line两端
  562. merged_line.sort(key=lambda x: (x[axis]))
  563. axis_start = merged_line[0][axis]
  564. merged_line.sort(key=lambda x: (x[axis + 2]))
  565. axis_end = merged_line[-1][axis + 2]
  566. if axis:
  567. result_lines.append([axis_average, axis_start, axis_average, axis_end])
  568. else:
  569. result_lines.append([axis_start, axis_average, axis_end, axis_average])
  570. log('merge_line2 cost: ' + str(time.time()-start_time))
  571. return result_lines
  572. def get_points(row_lines, col_lines, image_size, threshold=5):
  573. # 创建空图
  574. row_img = np.zeros(image_size, np.uint8)
  575. col_img = np.zeros(image_size, np.uint8)
  576. # 画线
  577. # threshold = 5
  578. for row in row_lines:
  579. cv2.line(row_img, (int(row[0] - threshold), int(row[1])), (int(row[2] + threshold), int(row[3])), (255, 255, 255), 1)
  580. for col in col_lines:
  581. cv2.line(col_img, (int(col[0]), int(col[1] - threshold)), (int(col[2]), int(col[3] + threshold)), (255, 255, 255), 1)
  582. # cv2.imshow('get_points', row_img+col_img)
  583. # cv2.waitKey(0)
  584. # 求出交点
  585. point_img = np.bitwise_and(row_img, col_img)
  586. # cv2.imwrite("get_points.jpg", row_img+col_img)
  587. # cv2.imshow("get_points", row_img+col_img)
  588. # cv2.waitKey(0)
  589. # 识别黑白图中的白色交叉点,将横纵坐标取出
  590. ys, xs = np.where(point_img > 0)
  591. points = []
  592. for i in range(len(xs)):
  593. points.append((xs[i], ys[i]))
  594. points.sort(key=lambda x: (x[0], x[1]))
  595. return points
def fix_outline(image, row_line_list, col_line_list, point_list, scale=25):
    """Add missing outer border lines of a table when the existing lines overhang.

    If both outer columns stick out beyond the first / last row (or both
    outer rows stick out beyond the first / last column) by more than a
    threshold, a new border line is created at the overhanging end and the
    nearby perpendicular lines are extended to reach it.

    Side effects: ``row_line_list`` and ``col_line_list`` are sorted in place
    and the end points of their lines are modified in place, so the lines must
    be mutable sequences.

    :param image: only ``image.shape[0]`` and ``image.shape[1]`` are read.
    :param row_line_list: horizontal lines ``[x0, y0, x1, y1]`` of one table area.
    :param col_line_list: vertical lines ``[x0, y0, x1, y1]`` of one table area.
    :param point_list: unused in this function.
    :param scale: divisor used to derive the minimum lengths from the image shape.
    :return: ``(new_row_lines, new_col_lines, all_longer_row_lines,
        all_longer_col_lines)``: the added border lines and the (possibly
        extended) input lines. With at most one row or one column the result
        is ``([], [], row_line_list, col_line_list)``.
    """
    log("into fix_outline")
    # NOTE(review): x_min_len is derived from shape[0] and y_min_len from
    # shape[1]; for a (height, width, ...) shape that looks swapped - confirm.
    x_min_len = max(10, int(image.shape[0] / scale))
    y_min_len = max(10, int(image.shape[1] / scale))

    if len(row_line_list) <= 1 or len(col_line_list) <= 1:
        return [], [], row_line_list, col_line_list

    # Take the 4 outermost predicted lines (they may overhang the table)
    row_line_list.sort(key=lambda x: (x[1], x[0]))
    up_line = row_line_list[0]
    bottom_line = row_line_list[-1]
    col_line_list.sort(key=lambda x: x[0])
    left_line = col_line_list[0]
    right_line = col_line_list[-1]

    # Estimate the height and width of a single cell.
    # (After the early return above there are always more than one row and
    # column, so the two outer `else` branches below cannot be reached.)
    if len(row_line_list) > 1:
        # Histogram of the gaps (>= 10) between consecutive rows
        height_dict = {}
        for j in range(len(row_line_list)):
            if j + 1 > len(row_line_list) - 1:
                break
            height = abs(int(row_line_list[j][3] - row_line_list[j + 1][3]))
            if height >= 10:
                if height in height_dict.keys():
                    height_dict[height] = height_dict[height] + 1
                else:
                    height_dict[height] = 1
        height_list = [[x, height_dict[x]] for x in height_dict.keys()]
        if height_list:
            # Most frequent gap first; ties are won by the smaller gap
            height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True)
            box_height = height_list[0][0]
        else:
            box_height = y_min_len
    else:
        box_height = y_min_len
    if len(col_line_list) > 1:
        # Distance between the two left-most columns
        box_width = abs(col_line_list[1][2] - col_line_list[0][2])
    else:
        box_width = x_min_len

    # Minimum overhang required before a border line is added
    if box_height >= 2 * y_min_len:
        fix_h_len = y_min_len
    else:
        fix_h_len = box_height * 2 / 3
    if box_width >= 2 * x_min_len:
        fix_w_len = x_min_len
    else:
        fix_w_len = box_width * 2 / 3

    # Check the overhang lengths; add a line when the overhang is long enough
    new_row_lines = []
    new_col_lines = []
    all_longer_row_lines = []
    all_longer_col_lines = []

    # Rows for the overhang of the left and right columns.
    # Start side: both outer columns begin at least fix_h_len before the first row.
    if up_line[1] - left_line[1] >= fix_h_len and up_line[1] - right_line[1] >= fix_h_len:
        if up_line[1] - left_line[1] >= up_line[1] - right_line[1]:
            # The left column overhangs at least as much: use its start y
            new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]])
            new_col_y = left_line[1]
            # A row was added: connect the other, shorter columns to it
            for j in range(len(col_line_list)):
                col = col_line_list[j]
                if abs(new_col_y - col[1]) <= box_height:
                    col_line_list[j][1] = min([new_col_y, col[1]])
        else:
            new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]])
            new_col_y = right_line[1]
            # A row was added: connect the other, shorter columns to it
            for j in range(len(col_line_list)):
                col = col_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_col_y - col[1]) <= box_height:
                    col_line_list[j][1] = min([new_col_y, col[1]])
    # End side: both outer columns end at least fix_h_len after the last row.
    if left_line[3] - bottom_line[3] >= fix_h_len and right_line[3] - bottom_line[3] >= fix_h_len:
        if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]:
            new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]])
            new_col_y = left_line[3]
            # A row was added: connect the other, shorter columns to it
            for j in range(len(col_line_list)):
                col = col_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_col_y - col[3]) <= box_height:
                    col_line_list[j][3] = max([new_col_y, col[3]])
        else:
            new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]])
            new_col_y = right_line[3]
            # A row was added: connect the other, shorter columns to it
            for j in range(len(col_line_list)):
                col = col_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_col_y - col[3]) <= box_height:
                    col_line_list[j][3] = max([new_col_y, col[3]])

    # Columns for the overhang of the first and last rows.
    # Start side: both outer rows begin at least fix_w_len before the first column.
    if left_line[0] - up_line[0] >= fix_w_len and left_line[0] - bottom_line[0] >= fix_w_len:
        if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]:
            new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]])
            new_row_x = up_line[0]
            # A column was added: connect the other, shorter rows to it
            for j in range(len(row_line_list)):
                row = row_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_row_x - row[0]) <= box_width:
                    row_line_list[j][0] = min([new_row_x, row[0]])
        else:
            new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]])
            new_row_x = bottom_line[0]
            # A column was added: connect the other, shorter rows to it
            for j in range(len(row_line_list)):
                row = row_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_row_x - row[0]) <= box_width:
                    row_line_list[j][0] = min([new_row_x, row[0]])
    # End side: both outer rows end at least fix_w_len after the last column.
    if up_line[2] - right_line[2] >= fix_w_len and bottom_line[2] - right_line[2] >= fix_w_len:
        if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]:
            new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]])
            new_row_x = up_line[2]
            # A column was added: connect the other, shorter rows to it
            for j in range(len(row_line_list)):
                row = row_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_row_x - row[2]) <= box_width:
                    row_line_list[j][2] = max([new_row_x, row[2]])
        else:
            new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]])
            new_row_x = bottom_line[2]
            # A column was added: connect the other, shorter rows to it
            for j in range(len(row_line_list)):
                row = row_line_list[j]
                # ... but only when the distance is not too large
                if abs(new_row_x - row[2]) <= box_width:
                    row_line_list[j][2] = max([new_row_x, row[2]])

    # The input lines, including any in-place extensions made above
    all_longer_row_lines += row_line_list
    all_longer_col_lines += col_line_list
    return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
  732. def fix_inner(row_line_list, col_line_list, point_list):
  733. def fix(fix_lines, assist_lines, split_points, axis):
  734. new_line_point_list = []
  735. delete_line_point_list = []
  736. for line1 in fix_lines:
  737. min_assist_line = [[], []]
  738. min_distance = [1000, 1000]
  739. if_find = [0, 0]
  740. # 获取fix_line中的所有col point,里面可能不包括两个顶点,col point是交点,顶点可能不是交点
  741. fix_line_points = []
  742. for point in split_points:
  743. if abs(point[1 - axis] - line1[1 - axis]) <= 2:
  744. if line1[axis] <= point[axis] <= line1[axis + 2]:
  745. fix_line_points.append(point)
  746. # 找出离两个顶点最近的assist_line, 并且assist_line与fix_line不相交
  747. line1_point = [line1[:2], line1[2:]]
  748. for i in range(2):
  749. point = line1_point[i]
  750. for line2 in assist_lines:
  751. if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
  752. if line1[1 - axis] <= point[1 - axis] <= line2[1 - axis + 2]:
  753. # print("line1, match line2", line1, line2)
  754. if_find[i] = 1
  755. break
  756. else:
  757. if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1 - axis] <= point[1 - axis] <= \
  758. line2[1 - axis + 2]:
  759. if line1[axis] <= line2[axis] <= line1[axis + 2]:
  760. continue
  761. min_distance[i] = abs(point[axis] - line2[axis])
  762. min_assist_line[i] = line2
  763. if len(min_assist_line[0]) == 0 and len(min_assist_line[1]) == 0:
  764. continue
  765. # 找出离assist_line最近的交点
  766. min_distance = [1000, 1000]
  767. min_col_point = [[], []]
  768. for i in range(2):
  769. # print("顶点", i, line1_point[i])
  770. if min_assist_line[i]:
  771. for point in fix_line_points:
  772. if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]:
  773. min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
  774. min_col_point[i] = point
  775. if len(min_col_point[0]) == 0 and len(min_col_point[1]) == 0:
  776. continue
  777. # print('line1', line1)
  778. # print("min_col_point", min_col_point)
  779. # print("min_assist_line", min_assist_line)
  780. # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3
  781. # print("line1_point", line1_point)
  782. if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]:
  783. if min_assist_line[0][axis] < line1_point[0][axis]:
  784. bbox_len = abs(min_col_point[0][axis] - min_assist_line[0][axis])
  785. line_distance = abs(min_col_point[0][axis] - line1_point[0][axis])
  786. if bbox_len / 3 <= line_distance <= bbox_len:
  787. if axis == 1:
  788. add_point = (line1_point[0][1 - axis], min_assist_line[0][axis])
  789. else:
  790. add_point = (min_assist_line[0][axis], line1_point[0][1 - axis])
  791. new_line_point_list.append([line1, add_point])
  792. elif min_assist_line[1][axis] > line1_point[1][axis]:
  793. bbox_len = abs(min_col_point[1][axis] - min_assist_line[1][axis])
  794. line_distance = abs(min_col_point[1][axis] - line1_point[1][axis])
  795. if bbox_len / 3 <= line_distance <= bbox_len:
  796. if axis == 1:
  797. add_point = (line1_point[1][1 - axis], min_assist_line[1][axis])
  798. else:
  799. add_point = (min_assist_line[1][axis], line1_point[1][1 - axis])
  800. new_line_point_list.append([line1, add_point])
  801. else:
  802. for i in range(2):
  803. if min_col_point[i]:
  804. bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis])
  805. line_distance = abs(min_col_point[i][axis] - line1_point[i][axis])
  806. # print("bbox_len, line_distance", bbox_len, line_distance)
  807. if bbox_len / 3 <= line_distance <= bbox_len:
  808. if axis == 1:
  809. add_point = (line1_point[i][1 - axis], min_assist_line[i][axis])
  810. else:
  811. add_point = (min_assist_line[i][axis], line1_point[i][1 - axis])
  812. new_line_point_list.append([line1, add_point])
  813. return new_line_point_list
  814. row_line_list_copy = copy.deepcopy(row_line_list)
  815. col_line_list_copy = copy.deepcopy(col_line_list)
  816. try:
  817. new_point_list = fix(col_line_list, row_line_list, point_list, axis=1)
  818. for line, new_point in new_point_list:
  819. if line in col_line_list:
  820. index = col_line_list.index(line)
  821. point1 = line[:2]
  822. point2 = line[2:]
  823. if new_point[1] >= point2[1]:
  824. col_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
  825. elif new_point[1] <= point1[1]:
  826. col_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
  827. new_point_list = fix(row_line_list, col_line_list, point_list, axis=0)
  828. for line, new_point in new_point_list:
  829. if line in row_line_list:
  830. index = row_line_list.index(line)
  831. point1 = line[:2]
  832. point2 = line[2:]
  833. if new_point[0] >= point2[0]:
  834. row_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
  835. elif new_point[0] <= point1[0]:
  836. row_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
  837. return row_line_list, col_line_list
  838. except:
  839. traceback.print_exc()
  840. return row_line_list_copy, col_line_list_copy
  841. def fix_4_points(cross_points, row_line_list, col_line_list):
  842. if not (len(row_line_list) >= 2 and len(col_line_list) >= 2):
  843. return row_line_list, col_line_list
  844. cross_points.sort(key=lambda x: (x[0], x[1]))
  845. left_up_p = cross_points[0]
  846. right_down_p = cross_points[-1]
  847. cross_points.sort(key=lambda x: (-x[0], x[1]))
  848. right_up_p = cross_points[0]
  849. left_down_p = cross_points[-1]
  850. # print('left_up_p', left_up_p, 'left_down_p', left_down_p)
  851. # print('right_up_p', right_up_p, 'right_down_p', right_down_p)
  852. min_x = min(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
  853. max_x = max(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
  854. min_y = min(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
  855. max_y = max(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
  856. if left_up_p[0] != min_x or left_up_p[1] != min_y:
  857. log('轮廓左上角交点有问题')
  858. row_line_list.append([min_x, min_y, max_x, min_y])
  859. col_line_list.append([min_x, min_y, min_x, max_y])
  860. if left_down_p[0] != min_x or left_down_p[1] != max_y:
  861. log('轮廓左下角交点有问题')
  862. row_line_list.append([min_x, max_y, max_x, max_y])
  863. col_line_list.append([min_x, min_y, min_x, max_y])
  864. if right_up_p[0] != max_x or right_up_p[1] != min_y:
  865. log('轮廓右上角交点有问题')
  866. row_line_list.append([min_x, max_y, max_x, max_y])
  867. col_line_list.append([max_x, min_y, max_x, max_y])
  868. if right_down_p[0] != max_x or right_down_p[1] != max_y:
  869. log('轮廓右下角交点有问题')
  870. row_line_list.append([min_x, max_y, max_x, max_y])
  871. col_line_list.append([max_x, min_y, max_x, max_y])
  872. return row_line_list, col_line_list
  873. def get_split_line(points, col_lines, image_np, threshold=5):
  874. # 线贴着边缘无法得到split_y,导致无法分区
  875. for _col in col_lines:
  876. if _col[3] >= image_np.shape[0] - 5:
  877. _col[3] = image_np.shape[0] - 6
  878. if _col[1] <= 0 + 5:
  879. _col[1] = 6
  880. # print("get_split_line", image_np.shape)
  881. points.sort(key=lambda x: (x[1], x[0]))
  882. # 遍历y坐标,并判断y坐标与上一个y坐标是否存在连接线
  883. i = 0
  884. split_line_y = []
  885. for point in points:
  886. # 从已分开的线下面开始判断
  887. if split_line_y:
  888. if point[1] <= split_line_y[-1] + threshold:
  889. last_y = point[1]
  890. continue
  891. if last_y <= split_line_y[-1] + threshold:
  892. last_y = point[1]
  893. continue
  894. if i == 0:
  895. last_y = point[1]
  896. i += 1
  897. continue
  898. current_line = (last_y, point[1])
  899. split_flag = 1
  900. for col in col_lines:
  901. # 只要找到一条col包含就不是分割线
  902. if current_line[0] >= col[1] - 3 and current_line[1] <= col[3] + 3:
  903. split_flag = 0
  904. break
  905. if split_flag:
  906. split_line_y.append(current_line[0] + 5)
  907. split_line_y.append(current_line[1] - 5)
  908. last_y = point[1]
  909. # 加上收尾分割线
  910. points.sort(key=lambda x: (x[1], x[0]))
  911. y_min = points[0][1]
  912. y_max = points[-1][1]
  913. if y_min - threshold < 0:
  914. split_line_y.append(0)
  915. else:
  916. split_line_y.append(y_min - threshold)
  917. if y_max + threshold > image_np.shape[0]:
  918. split_line_y.append(image_np.shape[0])
  919. else:
  920. split_line_y.append(y_max + threshold)
  921. split_line_y = list(set(split_line_y))
  922. # 剔除两条相隔太近分割线
  923. temp_split_line_y = []
  924. split_line_y.sort(key=lambda x: x)
  925. last_y = -20
  926. for y in split_line_y:
  927. if y - last_y >= 20:
  928. temp_split_line_y.append(y)
  929. last_y = y
  930. split_line_y = temp_split_line_y
  931. # 生成分割线
  932. split_line = []
  933. for y in split_line_y:
  934. split_line.append([(0, y), (image_np.shape[1], y)])
  935. split_line.append([(0, 0), (image_np.shape[1], 0)])
  936. split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])])
  937. split_line.sort(key=lambda x: x[0][1])
  938. return split_line, split_line_y
  939. def get_split_area(split_y, row_line_list, col_line_list, cross_points):
  940. # 分割线纵坐标
  941. if len(split_y) < 2:
  942. return [], [], []
  943. split_y.sort(key=lambda x: x)
  944. # new_split_y = []
  945. # for i in range(1, len(split_y), 2):
  946. # new_split_y.append(int((split_y[i] + split_y[i - 1]) / 2))
  947. area_row_line_list = []
  948. area_col_line_list = []
  949. area_point_list = []
  950. for i in range(1, len(split_y)):
  951. y = split_y[i]
  952. last_y = split_y[i - 1]
  953. split_row = []
  954. for row in row_line_list:
  955. if last_y <= row[3] <= y:
  956. split_row.append(row)
  957. split_col = []
  958. for col in col_line_list:
  959. if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]:
  960. split_col.append(col)
  961. split_point = []
  962. for point in cross_points:
  963. if last_y <= point[1] <= y:
  964. split_point.append(point)
  965. # 满足条件才能形成表格区域
  966. if len(split_row) >= 2 and len(split_col) >= 2 and len(split_point) >= 4:
  967. # print('len(split_row), len(split_col), len(split_point)', len(split_row), len(split_col), len(split_point))
  968. area_row_line_list.append(split_row)
  969. area_col_line_list.append(split_col)
  970. area_point_list.append(split_point)
  971. return area_row_line_list, area_col_line_list, area_point_list
  972. def get_standard_lines(row_line_list, col_line_list):
  973. new_row_line_list = []
  974. for row in row_line_list:
  975. w1 = row[0]
  976. w2 = row[2]
  977. # 横线的两个顶点分别找到最近的竖线
  978. min_distance = [10000, 10000]
  979. min_dis_w = [None, None]
  980. for col in col_line_list:
  981. if abs(col[0] - w1) < min_distance[0]:
  982. min_distance[0] = abs(col[0] - w1)
  983. min_dis_w[0] = col[0]
  984. if abs(col[0] - w2) < min_distance[1]:
  985. min_distance[1] = abs(col[0] - w2)
  986. min_dis_w[1] = col[0]
  987. if min_dis_w[0] is not None:
  988. row[0] = min_dis_w[0]
  989. if min_dis_w[1] is not None:
  990. row[2] = min_dis_w[1]
  991. new_row_line_list.append(row)
  992. new_col_line_list = []
  993. for col in col_line_list:
  994. h1 = col[1]
  995. h2 = col[3]
  996. # 横线的两个顶点分别找到最近的竖线
  997. min_distance = [10000, 10000]
  998. min_dis_w = [None, None]
  999. for row in row_line_list:
  1000. if abs(row[1] - h1) < min_distance[0]:
  1001. min_distance[0] = abs(row[1] - h1)
  1002. min_dis_w[0] = row[1]
  1003. if abs(row[1] - h2) < min_distance[1]:
  1004. min_distance[1] = abs(row[1] - h2)
  1005. min_dis_w[1] = row[1]
  1006. if min_dis_w[0] is not None:
  1007. col[1] = min_dis_w[0]
  1008. if min_dis_w[1] is not None:
  1009. col[3] = min_dis_w[1]
  1010. new_col_line_list.append(col)
  1011. # all_line_list = []
  1012. # # 横线竖线两个维度
  1013. # for i in range(2):
  1014. # axis = i
  1015. # cross_points.sort(key=lambda x: (x[axis], x[1-axis]))
  1016. # current_axis = cross_points[0][axis]
  1017. # points = []
  1018. # line_list = []
  1019. # for p in cross_points:
  1020. # if p[axis] == current_axis:
  1021. # points.append(p)
  1022. # else:
  1023. # if points:
  1024. # line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
  1025. # points = [p]
  1026. # current_axis = p[axis]
  1027. # if points:
  1028. # line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
  1029. # all_line_list.append(line_list)
  1030. # new_col_line_list, new_row_line_list = all_line_list
  1031. return new_col_line_list, new_row_line_list
  1032. def add_outline(cross_points, row_line_list, col_line_list):
  1033. cross_points.sort(key=lambda x: (x[0], x[1]))
  1034. left_up_p = cross_points[0]
  1035. right_down_p = cross_points[-1]
  1036. row_line_list.append([left_up_p[0], left_up_p[1], right_down_p[0], left_up_p[1]])
  1037. row_line_list.append([left_up_p[0], right_down_p[1], right_down_p[0], right_down_p[1]])
  1038. col_line_list.append([left_up_p[0], left_up_p[1], left_up_p[0], right_down_p[1]])
  1039. col_line_list.append([right_down_p[0], left_up_p[1], right_down_p[0], right_down_p[1]])
  1040. return row_line_list, col_line_list