# table_postprocess.py (28 KB)
  1. import copy
  2. import math
  3. import numpy as np
  4. import cv2
  5. import time
  6. def get_line_from_binary_image1(image_np, point_value=1, is_row=True, threshold=5,
  7. extend_px=0):
  8. """
  9. 根据像素点的变化,将像素点为特定值的转化为line,即找出端点坐标。
  10. 需要二值化的图。
  11. 仅支持竖线横线。
  12. :param image_np: numpy格式 image
  13. :param point_value: 像素点的特定值
  14. :param is_row: 是否是行,否则为列
  15. :param threshold: 行或列间的合并像素距离
  16. :param extend_px: 每条线延长的像素值
  17. :return: line list
  18. """
  19. # 取值大于point_value的点的坐标
  20. ys, xs = np.where(image_np >= point_value)
  21. points = [[xs[i], ys[i]] for i in range(len(xs))]
  22. lines = []
  23. # 提取横线
  24. if is_row:
  25. points.sort(key=lambda x: (x[1], x[0]))
  26. row_x, row_y = points[0]
  27. sub_line = []
  28. for p in points:
  29. # y在一定范围内,认为在同一列
  30. if row_y-threshold <= p[1] <= row_y+threshold:
  31. # 在同一行,且x连续
  32. if row_x-1 <= p[0] <= row_x+1:
  33. sub_line.append(p)
  34. # 在同一行,但x不连续
  35. else:
  36. if len(sub_line) >= 2:
  37. sub_line.sort(key=lambda x: (x[0], x[1]))
  38. lines.append([sub_line[0][0]-extend_px, sub_line[0][1],
  39. sub_line[-1][0]+extend_px, sub_line[0][1]])
  40. sub_line = []
  41. # 为了比较下个点是否连续,更新标准x值
  42. row_x = p[0]
  43. # 不在同一行
  44. else:
  45. row_y = p[1]
  46. if len(sub_line) >= 2:
  47. sub_line.sort(key=lambda x: (x[0], x[1]))
  48. lines.append([sub_line[0][0]-extend_px, sub_line[0][1],
  49. sub_line[-1][0]+extend_px, sub_line[0][1]])
  50. sub_line = []
  51. if len(sub_line) >= 2:
  52. sub_line.sort(key=lambda x: (x[0], x[1]))
  53. lines.append([sub_line[0][0]-extend_px, sub_line[0][1],
  54. sub_line[-1][0]+extend_px, sub_line[0][1]])
  55. # 提取竖线
  56. else:
  57. points.sort(key=lambda x: (x[0], x[1]))
  58. col_x, col_y = points[0]
  59. sub_line = []
  60. for p in points:
  61. # x在一定范围内,认为在同一列
  62. if col_x-threshold <= p[0] <= col_x+threshold:
  63. # 在同一列,且y连续
  64. if col_y-1 <= p[1] <= col_y+1:
  65. sub_line.append(p)
  66. # 在同一列,但y不连续
  67. else:
  68. if len(sub_line) >= 2:
  69. sub_line.sort(key=lambda x: (x[1], x[0]))
  70. lines.append([sub_line[0][0], sub_line[0][1]-extend_px,
  71. sub_line[0][0], sub_line[-1][1]+extend_px])
  72. sub_line = []
  73. # 为了比较下个点是否连续,更新标准y值
  74. col_y = p[1]
  75. # 不在同一列
  76. else:
  77. # 为了比较下一列,更新标准x值
  78. col_x = p[0]
  79. if len(sub_line) >= 2:
  80. sub_line.sort(key=lambda x: (x[1], x[0]))
  81. lines.append([sub_line[0][0], sub_line[0][1]-extend_px,
  82. sub_line[0][0], sub_line[-1][1]+extend_px])
  83. sub_line = []
  84. if len(sub_line) >= 2:
  85. sub_line.sort(key=lambda x: (x[1], x[0]))
  86. lines.append([sub_line[0][0], sub_line[0][1]-extend_px,
  87. sub_line[0][0], sub_line[-1][1]+extend_px])
  88. return lines
  89. def get_line_from_binary_image2(image_np, point_value=1, is_row=True, threshold=5,
  90. extend_px=0):
  91. """
  92. 根据像素点的变化,将像素点为特定值的转化为line,即找出端点坐标。
  93. 需要二值化的图。
  94. 仅支持竖线横线。
  95. :param image_np: numpy格式 image
  96. :param point_value: 像素点的特定值
  97. :param is_row: 是否是行,否则为列
  98. :param threshold: 行或列间的合并像素距离
  99. :param extend_px: 每条线延长的像素值
  100. :return: line list
  101. """
  102. def get_point_average(_list, axis=0):
  103. if axis:
  104. _list.sort(key=lambda x: (x[0], x[1]))
  105. else:
  106. _list.sort(key=lambda x: (x[1], x[0]))
  107. p_axis = 0
  108. for l in _list:
  109. p_axis += l[axis]
  110. p_axis = int(p_axis / len(_list))
  111. return p_axis
  112. def get_line_average(_list, axis=0):
  113. line = []
  114. if axis:
  115. sub_line.sort(key=lambda x: (x[1], x[0]))
  116. x = get_point_average(_list, 0)
  117. line.append([x, _list[0][1]-extend_px,
  118. x, _list[-1][1]+extend_px])
  119. else:
  120. _list.sort(key=lambda x: (x[0], x[1]))
  121. y = get_point_average(_list, 1)
  122. line.append([_list[0][0]-extend_px, y,
  123. _list[-1][0]+extend_px, y])
  124. return line
  125. # 取值大于point_value的点的坐标
  126. ys, xs = np.where(image_np >= point_value)
  127. points = [[xs[i], ys[i]] for i in range(len(xs))]
  128. lines = []
  129. used_points = []
  130. # 提取横线
  131. if is_row:
  132. points.sort(key=lambda x: (x[1], x[0]))
  133. row_x, row_y = points[0]
  134. sub_line = [points[0]]
  135. for p in points:
  136. if p in used_points:
  137. continue
  138. # y在一定范围内,认为在同一行
  139. if row_y-threshold <= p[1] <= row_y+threshold:
  140. # 在同一行,且x连续
  141. sub_line.sort(key=lambda z: z[0])
  142. if sub_line[0][0]-threshold <= p[0] <= sub_line[-1][0]+threshold:
  143. sub_line.append(p)
  144. # 在同一行,但x不连续
  145. else:
  146. if len(sub_line) >= 2:
  147. lines += get_line_average(sub_line, 0)
  148. used_points += sub_line
  149. sub_line = [p]
  150. # 不在同一行
  151. else:
  152. row_y = p[1]
  153. if len(sub_line) >= 2:
  154. lines += get_line_average(sub_line, 0)
  155. used_points += sub_line
  156. sub_line = [p]
  157. if len(sub_line) >= 2:
  158. lines += get_line_average(sub_line, 0)
  159. # 提取竖线
  160. else:
  161. points.sort(key=lambda x: (x[0], x[1]))
  162. col_x, col_y = points[0]
  163. sub_line = [points[0]]
  164. for p in points:
  165. if p in used_points:
  166. continue
  167. # x在一定范围内,认为在同一列
  168. if col_x-threshold <= p[0] <= col_x+threshold:
  169. # 在同一列,且y连续
  170. sub_line.sort(key=lambda z: z[1])
  171. if sub_line[0][1]-threshold <= p[1] <= sub_line[-1][1]+threshold:
  172. sub_line.append(p)
  173. # 在同一列,但y不连续
  174. else:
  175. if len(sub_line) >= 2:
  176. lines += get_line_average(sub_line, 1)
  177. used_points += sub_line
  178. sub_line = [p]
  179. # 不在同一列
  180. else:
  181. # 为了比较下一列,更新标准x值
  182. col_x = p[0]
  183. if len(sub_line) >= 2:
  184. lines += get_line_average(sub_line, 1)
  185. used_points += sub_line
  186. sub_line = [p]
  187. if len(sub_line) >= 2:
  188. lines += get_line_average(sub_line, 1)
  189. print("lines", lines)
  190. return lines
  191. def get_line_from_binary_image(image_np, point_value=1, axis=0):
  192. """
  193. 根据像素点的变化,将像素点为特定值的转化为line,即找出端点坐标。
  194. 需要二值化的图。
  195. 仅支持竖线横线。
  196. :param image_np: numpy格式 image
  197. :param point_value: 像素点的特定值
  198. :param is_row: 是否是行,否则为列
  199. :param threshold: 行或列间的合并像素距离
  200. :param extend_px: 每条线延长的像素值
  201. :return: line list
  202. """
  203. def get_axis_points(_list, axis=0):
  204. _list.sort(key=lambda x: (x[1-axis], x[axis]))
  205. standard_axis = points[axis][1-axis]
  206. axis_points = []
  207. sub_points = []
  208. for p in _list:
  209. if p[1-axis] == standard_axis:
  210. sub_points.append(p)
  211. else:
  212. standard_axis = p[1-axis]
  213. if sub_points:
  214. axis_points.append(sub_points)
  215. sub_points = []
  216. # 最后一行/列
  217. if sub_points:
  218. axis_points.append(sub_points)
  219. return axis_points
  220. def get_axis_lines(_list, axis=0):
  221. # 逐行/列判断,一行/列可能多条横线/竖线
  222. points_lines = []
  223. for axis_list in _list:
  224. sub_line = [axis_list[0]]
  225. for p in axis_list:
  226. # 设置基准点
  227. standard_p = sub_line[-1]
  228. # 判断连续
  229. if p[axis] - standard_p[axis] == 1:
  230. sub_line.append(p)
  231. else:
  232. points_lines.append(sub_line)
  233. sub_line = [p]
  234. # 最后一行/列
  235. if sub_line:
  236. points_lines.append(sub_line)
  237. # 许多点组成的line转为两点line
  238. lines = []
  239. for line in points_lines:
  240. line.sort(key=lambda x: (x[axis], x[1-axis]))
  241. lines.append([line[0][0], line[0][1], line[-1][0], line[-1][1]])
  242. return lines
  243. # 取值大于point_value的点的坐标
  244. ys, xs = np.where(image_np >= point_value)
  245. points = [[xs[i], ys[i]] for i in range(len(xs))]
  246. # 提出所有相同x或相同y的点
  247. # 提取行/列
  248. axis_points = get_axis_points(points, axis)
  249. # 提取每行/列的横线/竖线
  250. axis_lines = get_axis_lines(axis_points, axis)
  251. # print("axis_lines", axis_lines)
  252. return axis_lines
  253. def merge_line2(lines, axis, threshold=2):
  254. """
  255. 解决模型预测一条直线错开成多条直线,合并成一条直线
  256. :param lines: 线条列表
  257. :param axis: 0:横线 1:竖线
  258. :param threshold: 两条线间像素差阈值
  259. :return: 合并后的线条列表
  260. """
  261. # 竖线
  262. if axis:
  263. lines.sort(key=lambda x: (x[0], x[1]))
  264. # 循环找能合并的线,存储下标数组
  265. merge_list = []
  266. for i in range(len(lines)):
  267. col1 = lines[i]
  268. # 只需找一条
  269. sub_merge_list = [i]
  270. for j in range(i+1, len(lines)):
  271. col2 = lines[j]
  272. # x之间超出像素距离,跳出
  273. if abs(col1[0] - col2[0]) > threshold:
  274. break
  275. # 找到一条,跳出
  276. else:
  277. sub_merge_list.append(j)
  278. break
  279. # 找到加入
  280. if len(sub_merge_list) > 1:
  281. merge_list.append(sub_merge_list)
  282. # 横线
  283. else:
  284. lines.sort(key=lambda x: (x[1], x[0]))
  285. # 循环找能合并的线,存储下标数组
  286. merge_list = []
  287. for i in range(len(lines)):
  288. row1 = lines[i]
  289. # 只需找一条
  290. sub_merge_list = [i]
  291. for j in range(i+1,len(lines)):
  292. row2 = lines[j]
  293. # y之间超出像素距离,跳出
  294. if abs(row1[1] - row2[0]) > threshold:
  295. break
  296. # 找到一条,跳出
  297. else:
  298. sub_merge_list.append(j)
  299. break
  300. # 找到加入
  301. if len(sub_merge_list) > 1:
  302. merge_list.append(sub_merge_list)
  303. # 对所有下标待合并集合循环判断交集,有交集则并集
  304. intersection_list = []
  305. finished_list = []
  306. for i in range(len(merge_list)):
  307. # 处理过的下标跳过
  308. if i in finished_list:
  309. continue
  310. list1 = merge_list[i]
  311. sub_result_list = list1
  312. # 循环判断
  313. for j in range(len(merge_list)):
  314. # 处理过的下标跳过
  315. if j in finished_list:
  316. continue
  317. list2 = merge_list[j]
  318. # 交集
  319. if list(set(list1).intersection(set(list2))):
  320. # 并集
  321. sub_result_list = sub_result_list + list2
  322. finished_list.append(j)
  323. finished_list.append(i)
  324. sub_result_list = list(set(sub_result_list))
  325. sub_result_list.sort(key=lambda x: x)
  326. intersection_list.append(sub_result_list)
  327. # 根据不同情况保留组内的线
  328. hold_list = []
  329. # 竖线
  330. if axis:
  331. # 得到完整的线交集列表,选择保留哪一条
  332. for sub_result_list in intersection_list:
  333. # 有第一条
  334. if 0 in sub_result_list:
  335. # 保留分组中最后一条的x
  336. x1 = lines[sub_result_list[-1]][0]
  337. # 有最后一条或者是中间的线
  338. else:
  339. # 保留分组中第一条的x
  340. x1 = lines[sub_result_list[0]][0]
  341. # 取y最长的一条
  342. max_y_index = sub_result_list[0]
  343. max_y = 0
  344. for index in sub_result_list:
  345. if abs(lines[index][1] - lines[index][3]) > max_y:
  346. max_y = abs(lines[index][1] - lines[index][3])
  347. max_y_index = index
  348. y1 = lines[max_y_index][1]
  349. y2 = lines[max_y_index][3]
  350. hold_list.append([x1, y1, x1, y2])
  351. # 横线
  352. else:
  353. # 得到完整的线交集列表,选择保留哪一条
  354. for sub_result_list in intersection_list:
  355. # 有第一条
  356. if 0 in sub_result_list:
  357. # 保留分组中最后一条的y
  358. y1 = lines[sub_result_list[-1]][1]
  359. # 有最后一条或者是中间的线
  360. else:
  361. # 保留分组中第一条的y
  362. y1 = lines[sub_result_list[0]][1]
  363. # 取x最长的一条
  364. max_x_index = sub_result_list[0]
  365. max_x = 0
  366. for index in sub_result_list:
  367. if abs(lines[index][0] - lines[index][2]) > max_x:
  368. max_x = abs(lines[index][0] - lines[index][2])
  369. max_x_index = index
  370. x1 = lines[max_x_index][0]
  371. x2 = lines[max_x_index][2]
  372. hold_list.append([x1, y1, x2, y1])
  373. return hold_list
  374. def merge_line(lines, axis, threshold=5):
  375. """
  376. 解决模型预测一条直线错开成多条直线,合并成一条直线
  377. :param lines: 线条列表
  378. :param axis: 0:横线 1:竖线
  379. :param threshold: 两条线间像素差阈值
  380. :return: 合并后的线条列表
  381. """
  382. # 任意一条line获取该合并的line,横线往下找,竖线往右找
  383. lines.sort(key=lambda x: (x[axis], x[1-axis]))
  384. merged_lines = []
  385. used_lines = []
  386. for line1 in lines:
  387. if line1 in used_lines:
  388. continue
  389. merged_line = [line1]
  390. used_lines.append(line1)
  391. for line2 in lines:
  392. if line2 in used_lines:
  393. continue
  394. if line1[1-axis]-threshold <= line2[1-axis] <= line1[1-axis]+threshold:
  395. # 计算基准长度
  396. min_axis = 10000
  397. max_axis = 0
  398. for line3 in merged_line:
  399. if line3[axis] < min_axis:
  400. min_axis = line3[axis]
  401. if line3[axis+2] > max_axis:
  402. max_axis = line3[axis+2]
  403. # 判断两条线有无交集
  404. if min_axis <= line2[axis] <= max_axis \
  405. or min_axis <= line2[axis+2] <= max_axis:
  406. merged_line.append(line2)
  407. used_lines.append(line2)
  408. if merged_line:
  409. merged_lines.append(merged_line)
  410. # 合并line
  411. result_lines = []
  412. for merged_line in merged_lines:
  413. # 获取line宽的平均值
  414. axis_average = 0
  415. for line in merged_line:
  416. axis_average += line[1-axis]
  417. axis_average = int(axis_average/len(merged_line))
  418. # 获取最长line两端
  419. merged_line.sort(key=lambda x: (x[axis]))
  420. axis_start = merged_line[0][axis]
  421. merged_line.sort(key=lambda x: (x[axis+2]))
  422. axis_end = merged_line[-1][axis+2]
  423. if axis:
  424. result_lines.append([axis_average, axis_start, axis_average, axis_end])
  425. else:
  426. result_lines.append([axis_start, axis_average, axis_end, axis_average])
  427. return result_lines
  428. def fix_gap(rows, cols):
  429. def calculate_line_equation(lines):
  430. """
  431. 根据line的两点式求line的一般式方程
  432. :param lines:
  433. :return:
  434. """
  435. line_equations = {}
  436. for line in lines:
  437. point1 = (line[0], line[1])
  438. point2 = (line[2], line[3])
  439. A = point2[1] - point1[1]
  440. B = point1[0] - point2[0]
  441. C = point2[0] * point1[1] - point1[0] * point2[1]
  442. line_equation = {"A": A, "B": B, "C": C}
  443. line_equations[str(line)] = line_equation
  444. return line_equations
  445. def calculate_point_line_distance(point, line_equation):
  446. """
  447. 计算点到直线距离
  448. :param point:
  449. :param line_equation: line的一般式方程 {A:, B:, C:}
  450. :return: 距离
  451. """
  452. A = line_equation.get("A")
  453. B = line_equation.get("B")
  454. C = line_equation.get("C")
  455. if A == 0.:
  456. distance = abs(point[1] + C / B)
  457. elif B == 0.:
  458. distance = abs(point[0] + C / A)
  459. else:
  460. distance = abs(A * point[0] + B * point[1] + C) / \
  461. math.sqrt(math.pow(A, 2) + math.pow(B, 2))
  462. return distance
  463. def get_point_projection(point, line_equation):
  464. """
  465. 获取点到直线的投影
  466. :param point:
  467. :param line_equation: line的一般式方程 {A:, B:, C:}
  468. :return: 投影点坐标
  469. """
  470. A = line_equation.get("A")
  471. B = line_equation.get("B")
  472. C = line_equation.get("C")
  473. x0 = point[0]
  474. y0 = point[1]
  475. if A == 0.:
  476. x1 = x0
  477. y1 = -((A * x1 + C) / B)
  478. elif B == 0.:
  479. y1 = y0
  480. x1 = -((B * y1 + C) / A)
  481. return (x1, y1)
  482. def is_point_at_line(point, lines, axis=0):
  483. for line in lines:
  484. if point[axis] == line[axis]:
  485. print("line", line, point)
  486. if line[1-axis] <= point[1-axis] <= line[1-axis+2]:
  487. return True
  488. return False
  489. def connect_point_to_line(point, lines, line_equations, axis=0):
  490. distances = []
  491. # 找一条离点最近的线
  492. for line in lines:
  493. # 获取line方程
  494. line_equation = line_equations.get(str(line))
  495. # 计算距离
  496. distance = calculate_point_line_distance(point, line_equation)
  497. distances.append([line, distance])
  498. distances.sort(key=lambda x: x[1])
  499. connect_line = distances[0][0]
  500. print("connect_line", connect_line)
  501. print("distances[0]", distances[0])
  502. print("line_equation", line_equations.get(str(connect_line)))
  503. # 求点到直线的投影点,作为新的点返回
  504. new_point = get_point_projection(point, line_equations.get(str(connect_line)))
  505. return new_point
  506. # 计算所有line方程
  507. rows_equations = calculate_line_equation(rows)
  508. cols_equations = calculate_line_equation(cols)
  509. # 对任意一条line判断两端是否在其他垂直line上
  510. new_rows = []
  511. for line in rows:
  512. point1 = [line[0], line[1]]
  513. point2 = [line[2], line[3]]
  514. flag1 = is_point_at_line(point1, cols, axis=1)
  515. flag2 = is_point_at_line(point2, cols, axis=1)
  516. print("flag1, flag2", flag1, flag2)
  517. if flag1 and flag2:
  518. new_rows.append(line)
  519. elif flag1 and not flag2:
  520. new_point2 = connect_point_to_line(point2, cols, cols_equations, axis=1)
  521. new_rows.append([point1[0], point1[1],
  522. math.ceil(new_point2[0]), math.ceil(new_point2[1])
  523. ])
  524. print("new_point2", new_point2, point2)
  525. elif not flag1 and flag2:
  526. new_point1 = connect_point_to_line(point1, cols, cols_equations, axis=1)
  527. new_rows.append([math.floor(new_point1[0]), math.floor(new_point1[1]),
  528. point2[0], point2[1]
  529. ])
  530. print("new_point1", new_point1, point1)
  531. else:
  532. new_rows.append(line)
  533. new_cols = []
  534. for line in cols:
  535. point1 = [line[0], line[1]]
  536. point2 = [line[2], line[3]]
  537. flag1 = is_point_at_line(point1, rows, axis=0)
  538. flag2 = is_point_at_line(point2, rows, axis=0)
  539. if flag1 and flag2:
  540. new_cols.append(line)
  541. elif flag1 and not flag2:
  542. new_point2 = connect_point_to_line(point2, rows, rows_equations, axis=1)
  543. new_cols.append([point1[0], point1[1],
  544. math.ceil(new_point2[0]), math.ceil(new_point2[1])
  545. ])
  546. elif not flag1 and flag2:
  547. new_point1 = connect_point_to_line(point1, rows, rows_equations, axis=1)
  548. new_cols.append([math.floor(new_point1[0]), math.floor(new_point1[1]),
  549. point2[0], point2[1]
  550. ])
  551. else:
  552. new_cols.append(line)
  553. return new_rows, new_cols
  554. def get_points(row_lines, col_lines, image_size):
  555. """
  556. :param row_lines: 所有区域rows
  557. :param col_lines: 所有区域cols
  558. :param image_size: (h, w)
  559. :return: rows、cols交点
  560. """
  561. # 创建空图
  562. row_img = np.zeros(image_size, np.uint8)
  563. col_img = np.zeros(image_size, np.uint8)
  564. # 画线
  565. thresh = 3
  566. for row in row_lines:
  567. cv2.line(row_img, (int(row[0]-thresh), int(row[1])), (int(row[2]+thresh), int(row[3])), (255, 255, 255), 1)
  568. for col in col_lines:
  569. cv2.line(col_img, (int(col[0]), int(col[1]-thresh)), (int(col[2]), int(col[3]+thresh)), (255, 255, 255), 1)
  570. # 求出交点
  571. point_img = np.bitwise_and(row_img, col_img)
  572. # cv2.imshow("point_img", np.bitwise_not(point_img))
  573. # cv2.waitKey(0)
  574. # 识别黑白图中的白色交叉点,将横纵坐标取出
  575. ys, xs = np.where(point_img > 0)
  576. points = []
  577. for i in range(len(xs)):
  578. points.append((xs[i], ys[i]))
  579. points.sort(key=lambda x: (x[0], x[1]))
  580. return points
  581. def get_split_line(cols, image_size):
  582. """
  583. 解决一张图中多个表格,求出分割区域的线。(最多分割3个表格)
  584. :param cols: 所有区域cols
  585. :param image_size: (h, w)
  586. :return: 分割区域的线及其纵坐标
  587. """
  588. cols.sort(key=lambda x: (x[0], x[1]))
  589. standard_col = cols[0]
  590. split_col = []
  591. for col in cols:
  592. # 判断col是否与standard_col重合,重合则跳过
  593. if standard_col[1] <= col[1] <= standard_col[3] \
  594. or standard_col[1] <= col[3] <= standard_col[3]:
  595. # 获取standard col最大长度
  596. standard_col = [standard_col[0], min([standard_col[1], col[1]]),
  597. standard_col[2], max([standard_col[3], col[3]])]
  598. continue
  599. # 不重合则将standard col加入,不重合的col作为新的standard col
  600. else:
  601. # 判断该standard col与split_col里有无重合
  602. append_flag = 1
  603. for sc in split_col:
  604. if standard_col[1] <= sc[1] <= standard_col[3] \
  605. or standard_col[1] <= sc[3] <= standard_col[3]:
  606. append_flag = 0
  607. break
  608. if append_flag:
  609. split_col.append(standard_col)
  610. standard_col = col
  611. # 判断有3条线后跳出
  612. if len(split_col) == 3:
  613. break
  614. split_y = [0+5, image_size[0]-5]
  615. for col in split_col:
  616. if col[1]-5 > 0:
  617. y_min = col[1]-5
  618. split_y.append(int(y_min))
  619. if col[3]+5 < image_size[0]:
  620. y_max = col[3]+5
  621. split_y.append(int(y_max))
  622. split_y = list(set(split_y))
  623. split_y.sort(key=lambda x: x)
  624. return split_y
  625. def get_point_area(points, split_y):
  626. """
  627. :param points:所有区域points
  628. :param split_y: 区域分割线纵坐标
  629. :return: 多个区域points list
  630. """
  631. point_area_list = []
  632. for i in range(1, len(split_y)):
  633. area = (split_y[i-1], split_y[i])
  634. points.sort(key=lambda x: (x[1], x[0]))
  635. point_area = []
  636. for p in points:
  637. if area[0] <= p[1] <= area[1]:
  638. point_area.append(p)
  639. point_area.sort(key=lambda x: (x[0], x[1]))
  640. point_area_list.append(point_area)
  641. return point_area_list
  642. def get_line_area(lines, split_y):
  643. line_area_list = []
  644. for i in range(1, len(split_y)):
  645. area = (split_y[i-1], split_y[i])
  646. lines.sort(key=lambda x: (x[1], x[3]))
  647. line_area = []
  648. for l in lines:
  649. if area[0] <= l[1] and l[3] <= area[1]:
  650. line_area.append(l)
  651. line_area.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
  652. line_area_list.append(line_area)
  653. return line_area_list
  654. def fix_outline_area(rows_area, cols_area, points_area):
  655. """
  656. 解决表格本身无左右两边或无上下两边的情况,修补表格
  657. :param rows_area: 单个区域rows
  658. :param cols_area: 单个区域cols
  659. :param points_area: 单个区域points
  660. :return: 补线后的新rows、cols、points
  661. """
  662. # 通过rows,cols 取表格的四条边(会有超出表格部分)
  663. rows_area.sort(key=lambda x: (x[1], x[0]))
  664. # print(area)
  665. up_line1 = rows_area[0]
  666. bottom_line1 = rows_area[-1]
  667. cols_area.sort(key=lambda x: (x[0], x[1]))
  668. left_line1 = cols_area[0]
  669. right_line1 = cols_area[-1]
  670. print("left_line1", left_line1)
  671. print("right_line1", right_line1)
  672. # 通过points 取表格的四条边(无超出表格部分)
  673. points_area.sort(key=lambda x: (x[0], x[1]))
  674. left_up = points_area[0]
  675. right_bottom = points_area[-1]
  676. up_line2 = [left_up[0], left_up[1], right_bottom[0], left_up[1]]
  677. bottom_line2 = [left_up[0], right_bottom[1], right_bottom[0], right_bottom[1]]
  678. left_line2 = [left_up[0], left_up[1], left_up[0], right_bottom[1]]
  679. right_line2 = [right_bottom[0], left_up[1], right_bottom[0], right_bottom[1]]
  680. # 判断超出部分的长度,超出一定长度就补线
  681. new_row_lines = []
  682. new_col_lines = []
  683. longer_row_lines = []
  684. longer_col_lines = []
  685. # 补左右两条竖线超出来的线的row
  686. if left_line2[1] - left_line1[1] >= 30 and right_line2[1] - right_line1[1] >= 30:
  687. new_row_lines.append([left_line1[0], left_line1[1], right_line1[0], left_line1[1]])
  688. # 补了row,要将其他短的col连到row上
  689. new_col_y = min([left_line1[1], right_line1[1]])
  690. for col in cols_area:
  691. longer_col_lines.append([col[0], min([new_col_y, col[1]]), col[2], col[3]])
  692. if left_line1[3] - left_line2[3] >= 30 and right_line1[3] - right_line2[3] >= 30:
  693. new_row_lines.append([left_line1[2], left_line1[3], right_line1[2], left_line1[3]])
  694. # 补了row,要将其他短的col连到row上
  695. new_col_y = max([left_line1[3], right_line1[3]])
  696. for col in cols_area:
  697. longer_col_lines.append([col[0], col[1], col[2], max([new_col_y, col[3]])])
  698. # 补上下两条横线超出来的线的col
  699. if up_line2[0] - up_line1[0] >= 30 and bottom_line2[0] - bottom_line1[0] >= 30:
  700. new_col_lines.append([up_line1[0], up_line1[1], up_line1[0], bottom_line1[1]])
  701. # 补了col,要将其他短的row连到col上
  702. new_row_x = min([up_line1[0], bottom_line1[0]])
  703. for row in rows_area:
  704. longer_row_lines.append([min([new_row_x, row[0]]), row[1], row[2], row[3]])
  705. if up_line1[2] - up_line2[2] >= 30 and bottom_line1[2] - bottom_line2[2] >= 30:
  706. new_col_lines.append([up_line1[2], up_line1[3], up_line1[2], bottom_line1[3]])
  707. # 补了col,要将其他短的row连到col上
  708. new_row_x = max([up_line1[2], bottom_line1[2]])
  709. for row in rows_area:
  710. longer_row_lines.append([row[0], row[1], max([new_row_x, row[2]]), row[3]])
  711. return new_row_lines, new_col_lines, longer_row_lines, longer_col_lines
  712. def post_process():
  713. return