get_table_by_bbox.py 127 KB


  1. import base64
  2. import copy
  3. import json
  4. import logging
  5. import math
  6. import random
  7. import re
  8. import traceback
  9. from glob import glob
  10. import cv2
  11. from sklearn.cluster import AffinityPropagation, DBSCAN
  12. # from tensorflow_version.table_head_predict import predict
  13. from botr.utils import request_post, line_iou, pil_resize, get_best_predict_size2, line_overlap
  14. import jieba
  15. import numpy as np
  16. from matplotlib import pyplot as plt
  17. def _plot(_line_list, mode=1):
  18. for _line in _line_list:
  19. if mode == 1:
  20. x0, y0, x1, y1 = _line.__dict__.get("bbox")
  21. elif mode == 2:
  22. x0, y0, x1, y1 = _line
  23. elif mode == 3:
  24. x0, y0 = _line[0]
  25. x1, y1 = _line[1]
  26. plt.plot([x0, x1], [y0, y1])
  27. plt.show()
  28. return
  29. def get_table_by_rule2(img, text_list, bbox_list, table_location, is_test=0):
  30. # 处理bbox,缩小框
  31. bbox_list = shrink_bbox(img, bbox_list)
  32. # 创建对应dict
  33. bbox_text_dict = {}
  34. for i in range(len(text_list)):
  35. bbox_text_dict[str(bbox_list[i])] = text_list[i]
  36. # 获取全局的按行排列bbox
  37. row_list = get_table_rows(bbox_list, bbox_text_dict)
  38. if len(row_list) == 0:
  39. return [], [], []
  40. # 删除只有一个bbox的第一行和最后一行
  41. if len(row_list[0]) == 1:
  42. table_location = [table_location[0], row_list[0][0][2][1],
  43. table_location[2], table_location[3]]
  44. row_list = row_list[1:]
  45. if len(row_list[-1]) == 1:
  46. table_location = [table_location[0], table_location[1],
  47. table_location[2], row_list[-1][0][0][1]]
  48. row_list = row_list[:-1]
  49. # 获取表格区域,以及区域里的按行排列bbox
  50. table_location_list = [[[int(table_location[0]), int(table_location[1])], [int(table_location[2]), int(table_location[3])]]]
  51. area_row_list = [row_list]
  52. area_row_list = merge_row_bbox_list(area_row_list)
  53. # 获取全局的按列排列bbox
  54. area_col_list = get_table_cols(bbox_list, table_location_list)
  55. # 获取行线、列线
  56. area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  57. if is_test:
  58. _plot(area_row_lines[0] + area_col_lines[0], mode=3)
  59. # 判断列线合法
  60. area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
  61. # 判断行线合法
  62. area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
  63. if is_test:
  64. _plot(area_row_lines[0] + area_col_lines[0], mode=3)
  65. # 由线得到按行列排列的bbox
  66. area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list)
  67. if is_test:
  68. for a in area_table_cell_list:
  69. for r in a:
  70. for c in r:
  71. cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1)
  72. # for b in c:
  73. # cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1)
  74. cv2.imshow('table_cell', img)
  75. cv2.waitKey(0)
  76. # 展示
  77. if is_test:
  78. show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list)
  79. if not area_row_lines or not area_col_lines:
  80. return [], [], []
  81. line_list = [[x[0][0], x[0][1], x[1][0], x[1][1]] for x in area_row_lines[0] + area_col_lines[0]]
  82. cell_list = area_table_cell_list[0]
  83. return line_list, cell_list, table_location
def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=1):
    """Experimental rule-based row/column line detection anchored on the top-left box.

    The box closest to the table's top-left corner is taken as the anchor; the
    first row and first column are derived from it, candidate column/row lines
    are generated from those, adjusted, and finally passed to
    ``get_table_bbox_list``. Progress is printed along the way.

    Args:
        img: image indexed as ``img[y0:y1, x0:x1, :]`` and drawn on when
            ``is_test`` is truthy.
        text_list: recognised texts, index-aligned with ``bbox_list``.
        bbox_list: OCR boxes; ``bbox[0]`` and ``bbox[2]`` are read as ``[x, y]``
            points. The list is sorted and modified in place.
        table_location: four numbers used as ``[left, top, right, bottom]``.
        is_test: when truthy (the default) the lines are drawn and shown in an
            OpenCV window, blocking on ``cv2.waitKey(0)``.

    Returns:
        Always ``([], [], [])``; the computed lines and cells are not returned.

    NOTE(review): the source indentation was lost; the nesting below is
    reconstructed and should be confirmed against the original file.
    """
    # Preprocess the boxes: shrink them.
    bbox_list = shrink_bbox(img, bbox_list)
    # Build the str(bbox) -> text lookup.
    bbox_text_dict = {}
    for i in range(len(text_list)):
        bbox_text_dict[str(bbox_list[i])] = text_list[i]
    # Find the box closest (Manhattan distance) to the table's top-left corner;
    # it becomes the anchor ("first") box.
    table_left_up_point = [table_location[0], table_location[1]]
    min_distance = 100000000000
    # NOTE(review): raises IndexError when bbox_list is empty.
    first_bbox = bbox_list[0]
    for bbox in bbox_list:
        distance = abs(bbox[0][0] - table_left_up_point[0]) + abs(bbox[0][1] - table_left_up_point[1])
        if distance < min_distance:
            min_distance = distance
            first_bbox = bbox
    # Preprocess the anchor box:
    # split it.
    new_bbox_list, bbox_text_dict = split_bbox(img, first_bbox, bbox_text_dict)
    if new_bbox_list:
        if first_bbox in bbox_list:
            bbox_list.remove(first_bbox)
        bbox_list += new_bbox_list
        new_bbox_list.sort(key=lambda x: (x[0][0]))
        # The left-most piece replaces the anchor.
        first_bbox = new_bbox_list[0]
    # Derive the first row from the anchor box.
    first_row = []
    bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
    for bbox in bbox_list:
        # Vertical extents intersect.
        if first_bbox[0][1] <= bbox[0][1] <= first_bbox[2][1] \
                or first_bbox[0][1] <= bbox[2][1] <= first_bbox[2][1] \
                or bbox[0][1] <= first_bbox[0][1] <= bbox[2][1] \
                or bbox[0][1] <= first_bbox[2][1] <= bbox[2][1]:
            first_row.append(bbox)
        # Box lies entirely above the anchor.
        elif bbox[2][1] <= first_bbox[0][1]:
            first_row.append(bbox)
    # Split the first row into columns: boxes whose horizontal extents
    # intersect the current seed box share a column.
    first_row.sort(key=lambda x: (x[0][0], x[0][1]))
    first_row_col = []
    used_bbox = []
    for bbox in first_row:
        if bbox in used_bbox:
            continue
        temp_col = []
        for bbox1 in first_row:
            if bbox1 in used_bbox:
                continue
            if bbox1[0][0] <= bbox[0][0] <= bbox1[2][0] \
                    or bbox1[0][0] <= bbox[2][0] <= bbox1[2][0] \
                    or bbox[0][0] <= bbox1[0][0] <= bbox[2][0] \
                    or bbox[0][0] <= bbox1[2][0] <= bbox[2][0]:
                temp_col.append(bbox1)
                used_bbox.append(bbox1)
        first_row_col.append(temp_col)
    # Derive the first column from the anchor box.
    first_col = []
    bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
    for bbox in bbox_list:
        # Horizontal extents intersect.
        if first_bbox[0][0] <= bbox[0][0] <= first_bbox[2][0] \
                or first_bbox[0][0] <= bbox[2][0] <= first_bbox[2][0] \
                or bbox[0][0] <= first_bbox[0][0] <= bbox[2][0] \
                or bbox[0][0] <= first_bbox[2][0] <= bbox[2][0]:
            first_col.append(bbox)
        # Box lies entirely to the left of the anchor.
        elif bbox[2][0] <= first_bbox[0][0]:
            first_col.append(bbox)
    # Split the first column into rows.
    first_col.sort(key=lambda x: (x[0][1], x[0][0]))
    first_col_row = []
    current_bbox = first_col[0]
    temp_row = []
    for bbox in first_col:
        if current_bbox[0][1] <= bbox[0][1] <= current_bbox[2][1] \
                or current_bbox[0][1] <= bbox[2][1] <= current_bbox[2][1] \
                or bbox[0][1] <= current_bbox[0][1] <= bbox[2][1] \
                or bbox[0][1] <= current_bbox[2][1] <= bbox[2][1]:
            temp_row.append(bbox)
        else:
            # No vertical overlap: close the current row and start a new one.
            if temp_row:
                temp_row.sort(key=lambda x: x[0][1])
                first_col_row.append(temp_row)
            temp_row = [bbox]
            current_bbox = bbox
    # Flush the last row.
    if temp_row:
        temp_row.sort(key=lambda x: x[0][1])
        first_col_row.append(temp_row)
    print('len(first_row)', len(first_row))
    print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
    print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
    print('len(first_col)', len(first_col))
    print('len(first_row_col)', len(first_row_col))
    print('len(first_col_row)', len(first_col_row))
    # Draw column lines.
    col_line_list = []
    for col in first_row_col:
        # Two lines per column: at its left-most and right-most box edge.
        min_w, max_w = 1000000, 0
        print('col', [bbox_text_dict.get(str(x)) for x in col])
        for bbox in col:
            if bbox[0][0] < min_w:
                min_w = bbox[0][0]
            if bbox[2][0] > max_w:
                max_w = bbox[2][0]
        col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
        col_line_list.append([max_w, table_location[1], max_w, table_location[3]])
    # Draw row lines.
    row_line_list = []
    last_max_h = None
    for row in first_col_row:
        # Lines at the top-most and bottom-most box edge of the row (the
        # original comment mentions three lines; the third is commented out).
        min_h, max_h = 1000000, 0
        for bbox in row:
            if bbox[0][1] < min_h:
                min_h = bbox[0][1]
            if bbox[2][1] > max_h:
                max_h = bbox[2][1]
        row_line_list.append([table_location[0], min_h, table_location[2], min_h])
        row_line_list.append([table_location[0], max_h, table_location[2], max_h])
        # if last_max_h:
        # row_line_list.append([table_location[0], int((min_h+last_max_h)/2), table_location[2], int((min_h+last_max_h)/2)])
        last_max_h = max_h
    print('len(col_line_list)', len(col_line_list))
    print('col_line_list', col_line_list)
    print('len(row_line_list)', len(row_line_list))
    # Check whether a column line sits on dark pixels; if so, move it.
    # Lines are visited in pairs: right edge of one column, left edge of the next.
    temp_list = []
    for i in range(1, len(col_line_list), 2):
        # Right edge line of the previous column.
        line1 = col_line_list[i]
        line1 = [int(x) for x in line1]
        # Left edge line of the next column.
        if i+1 >= len(col_line_list):
            break
        line2 = col_line_list[i+1]
        line2 = [int(x) for x in line2]
        max_black_cnt = 10
        black_threshold = 150
        black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold)
        print('col black_cnt2', black_cnt2)
        if black_cnt2 <= max_black_cnt:
            temp_list.append(line2)
        else:
            black_cnt1 = count_black(img[line1[1]:line1[3], line1[0]:line1[2]+1, :], threshold=black_threshold)
            print('col black_cnt1', black_cnt1)
            if black_cnt1 <= max_black_cnt:
                temp_list.append(line1)
            else:
                # Neither line qualifies: scan from right to left for a clean x.
                for j in range(line2[0], line1[0], -1):
                    black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
                    print('col black_cnt', black_cnt)
                    if black_cnt <= max_black_cnt:
                        temp_list.append([j, line2[1], j, line2[3]])
                        break
    col_line_list = temp_list
    # Assign boxes to columns using the column lines.
    last_line = [0, 0, 0, 0]
    col_bbox_list = []
    # used_bbox_list = []
    # NOTE(review): the closing sentinel uses img.shape[0] as an x coordinate;
    # for an image array indexed [y, x, c] that is the height - confirm whether
    # img.shape[1] was intended.
    for line in col_line_list + [[img.shape[0], 0, img.shape[0], 0]]:
        col = []
        for bbox in bbox_list:
            # if bbox in used_bbox_list:
            # continue
            # print('last_line, line, bbox', last_line, line, bbox)
            iou = line_iou([[last_line[0], 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0)
            if iou >= 0.6:
                col.append(bbox)
                # used_bbox_list.append(bbox)
        col.sort(key=lambda x: x[0][1])
        col_bbox_list.append(col)
        last_line = line
    # Examine the row lines.
    temp_list = []
    for i in range(1, len(row_line_list), 2):
        # Bottom edge line of the previous row.
        line1 = row_line_list[i]
        line1 = [int(x) for x in line1]
        # Top edge line of the next row.
        if i+1 >= len(row_line_list):
            break
        line2 = row_line_list[i+1]
        line2 = [int(x) for x in line2]
        # Collect the boxes lying between the two lines (with tolerance) and
        # decide which row each belongs to.
        sub_bbox_list = []
        threshold = 5
        for bbox in bbox_list:
            if line1[1] - threshold <= bbox[0][1] <= bbox[2][1] <= line2[1]+threshold:
                sub_bbox_list.append(bbox)
        # Use the row's y range and the column assignment to decide whether a
        # box belongs to the upper or the lower row.
        line1_bbox_list = []
        line2_bbox_list = []
        if sub_bbox_list:
            sub_bbox_list.sort(key=lambda x: x[0][1])
            min_h = sub_bbox_list[0][0][1] - 1
            max_h = sub_bbox_list[-1][2][1] + 1
            for bbox in sub_bbox_list:
                # Find the column this box belongs to.
                current_col = None
                for col in col_bbox_list:
                    if bbox in col:
                        current_col = copy.deepcopy(col)
                        break
                if current_col:
                    # Insert the two bounds as pseudo-boxes into the column so
                    # they can serve as reference positions after sorting.
                    line1_bbox = [[0, min_h], [], [0, min_h], []]
                    line2_bbox = [[0, max_h], [], [0, max_h], []]
                    current_col += [line1_bbox, line2_bbox]
                    current_col.sort(key=lambda x: x[0][1])
                    bbox_index = current_col.index(bbox)
                    line1_bbox_index = current_col.index(line1_bbox)
                    line2_bbox_index = current_col.index(line2_bbox)
                    print('current_col', [bbox_text_dict.get(str(x)) for x in current_col])
                    print('line1_bbox_index, bbox_index, line2_bbox_index', line1_bbox_index, bbox_index, line2_bbox_index)
                    # Compute distances. Each loop overwrites its result, so
                    # only the centre distance of the last neighbouring pair
                    # visited is kept.
                    distance1 = 10000
                    for index in range(line1_bbox_index, bbox_index):
                        h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
                        h2 = (current_col[index+1][0][1] + current_col[index+1][2][1]) / 2
                        # print(bbox_text_dict.get())
                        distance1 = abs(h1 - h2)
                    distance2 = 10000
                    for index in range(line2_bbox_index, bbox_index, -1):
                        h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
                        h2 = (current_col[index-1][0][1] + current_col[index-1][2][1]) / 2
                        distance2 = abs(h1 - h2)
                    print(bbox_text_dict.get(str(bbox)), distance1, distance2)
                    ratio = 1.5
                    # Belongs to the lower row.
                    if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
                        line2_bbox_list.append(bbox)
                    # Belongs to the upper row.
                    elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
                        line1_bbox_list.append(bbox)
                    else:
                        # Distances are inconclusive (the message says an NSP
                        # model would be needed to decide).
                        print('距离不明确,需要nsp模型介入判断')
            if line1_bbox_list:
                # Move the upper line down to the bottom of its lowest box.
                print('line1_bbox_list', [bbox_text_dict.get(str(x)) for x in line1_bbox_list])
                line1_bbox_list.sort(key=lambda x: x[0][1])
                b = line1_bbox_list[-1]
                line1 = [line1[0], b[2][1], line1[2], b[2][1]]
            if line2_bbox_list:
                # Move the lower line up to the top of its highest box.
                print('line2_bbox_list', [bbox_text_dict.get(str(x)) for x in line2_bbox_list])
                line2_bbox_list.sort(key=lambda x: x[0][1])
                b = line2_bbox_list[0]
                line2 = [line2[0], b[0][1], line2[2], b[0][1]]
        # The final row line is the midpoint between the two adjusted lines.
        _line = [line1[0], (line1[1]+line2[1])/2, line1[2], (line1[3]+line2[3])/2]
        _line = [int(x) for x in _line]
        temp_list.append(_line)
    row_line_list = temp_list
    # Add the table outline.
    row_line_list.append([table_location[0], table_location[1], table_location[2], table_location[1]])
    row_line_list.append([table_location[0], table_location[3], table_location[2], table_location[3]])
    col_line_list.append([table_location[0], table_location[1], table_location[0], table_location[3]])
    col_line_list.append([table_location[2], table_location[1], table_location[2], table_location[3]])
    # Turn the lines into boxes arranged by row and column.
    # NOTE(review): the results are assigned but never used or returned.
    area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list], [table_location], bbox_list)
    # Show the lines.
    if is_test:
        for line in col_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
        for line in row_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)
        cv2.imshow('img', cv2.resize(img, (768, 1024)))
        cv2.waitKey(0)
    return [], [], []
  354. def split_bbox_by_kmeans(img, bbox, bbox_text_dict):
  355. sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
  356. # 从左至右扫描
  357. def get_table():
  358. # 1. 一个单元格多行合并需解决 √
  359. # 2. 一行多个单字合并 1007.jpg √
  360. # 3. ocr识别错误bbox剔除
  361. # 4. 上下表格合并 距离近,列数一样,或只少了第一列 1005.jpg 1014.jpg 1033.jpg √
  362. # 5. 相近行列线合并 1020.jpg 1025.jpg 1054.jpg 1068.jpg
  363. # 6. 行线在合并bbox中间,需向上或向下移动 105.jpg 1054.jpg 1020.jpg
  364. # 7. 贴着左边框的长bbox也当做标题分开表格 1047.jpg 1059.jpg √
  365. # 8. 判断非规整表格,单个单元格多个bbox,排除上下连接的bbox 105.jpg
  366. # 9. 判断非规整表格,ocr识别漏,黑色像素多 1050.jpg √
  367. # 10. 第一列序号ocr识别漏 1051.jpg
  368. # 11. 用其他列作为分行标准,作为辅助,挑平均间隔最大的,行数也够的列 1085.jpg
  369. # 12. 判断表格 两个bbox靠的太近的不能作为开始行 1106.jpg √
  370. # 13. 列中所有行间隔都很小,聚类距离统一值 1098.jpg √
  371. # 14. 漏列(需剔除表格中非表格部分) 1059.jpg
  372. # 15. 漏行 1064.jpg 1065.jpg 1067.jpg 1085.jpg 1097.jpg 1101.jpg √
  373. # 16. 表格分割错误 1045.jpg 1051.jpg 1078.jpg 1079.jpg √
  374. # 17. 分列时,第一行的表头选定 1051.jpg 1106.jpg 1129.jpg
  375. # 18. 分割同一行中多个列 1093.jpg 1095.jpg 110.jpg
  376. # 19. 表格漏了 1119.jpg 1141.jpg
  377. # 20. 非规整表格判断错误,黑色像素 1122.jpg 1121.jpg √
  378. # 21. 分列错误 1125.jpg 1158.jpg 1020.jpg √
  379. # 22. 分行分列错误(需在第一列排除过长bbox) 1131.jpg 1132.jpg √
  380. # 1135.jpg 1136.jpg 1147.jpg
  381. # 23. 表格范围外,与单元格内的文字上下相连 1134.jpg 1142.jpg
  382. # 24. 第一列空单元格太多可列为非规整
  383. # 25. 竖线跨越多个bbox的较中心位置,考虑剔除
  384. # 26. 竖线跨越bbox,考虑竖线缩短,将跨越的那一截去掉 1020.jpg
  385. # 27. 竖线插在一列中间,需调整其向右找到空白位置 1023.jpg
  386. # label_path = glob('../data/borderless_tables/*_label.jpg')
  387. # temp_label_path = []
  388. # label_row_dict = {}
  389. # for p in label_path:
  390. # img = cv2.imread(p)
  391. # row_img, col_img = get_lines_from_img(img)
  392. # label_row_list, is_standard = get_bbox_by_img(row_img, col_img)
  393. # label_row_dict[p] = label_row_list
  394. # if is_standard:
  395. # temp_label_path.append(p)
  396. # label_path = temp_label_path
  397. # print('len(label_path)', len(label_path))
  398. # for p in label_path:
  399. # print(p)
  400. with open('standard_table.txt', 'r') as f:
  401. label_path_list = f.readlines()
  402. # paths = glob('../data/borderless_tables/1.jpg') # merge_row
  403. # paths = glob('../data/borderless_tables/5.jpg') # title
  404. # paths = glob('../data/borderless_tables/26.jpg') # merge_col
  405. paths = glob('../data/borderless_tables/59.jpg') # split bbox
  406. paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/62.jpg')
  407. # paths = glob('../data/borderless_tables/57.jpg')
  408. paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/3.jpg') # not standard table
  409. # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png')
  410. # label_path_list.append(r'C:\Users\Administrator\Desktop\test_pdf_table\1_label.jpg\n')
  411. paths = glob('../data/borderless_tables/*.jpg')
  412. # paths = glob('../data/standard_tables/*.jpg')
  413. path_cnt = 0
  414. all_teds = 0
  415. all_standard_cnt = 0
  416. for p in paths:
  417. if 'label' in p:
  418. continue
  419. label_p = p[:-4] + '_label.jpg\n'
  420. if label_p not in label_path_list:
  421. continue
  422. # if path_cnt <= 10:
  423. # path_cnt += 1
  424. # continue
  425. path_cnt += 1
  426. img = cv2.imread(p)
  427. result = test_ocr_model(p)
  428. print(p)
  429. # print(result)
  430. bbox_list = eval(result.get('bbox'))
  431. text_list = eval(result.get('text'))
  432. bbox_text_dict = {}
  433. for i in range(len(text_list)):
  434. bbox_text_dict[str(bbox_list[i])] = text_list[i]
  435. # split_bbox(img, text_list, bbox_list)
  436. # 获取全局的按行排列bbox
  437. row_list = get_table_rows(bbox_list)
  438. # bbox预处理
  439. bbox_list, text_list, bbox_text_dict = bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict)
  440. # bbox处理后再按行排列bbox
  441. row_list = get_table_rows(bbox_list)
  442. # 获取表格区域,以及区域里的按行排列bbox
  443. table_location_list, area_row_list = get_table_location(row_list)
  444. # 表格分割
  445. table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict)
  446. table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict)
  447. print('fix_table_location_list', table_location_list)
  448. # print('fix_area_row_list', area_row_list)
  449. # 获取表格区域里,按列排序bbox
  450. area_col_list = get_table_cols(bbox_list, table_location_list)
  451. # 合并一列中多行bbox
  452. area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list)
  453. # 排除非规整表格
  454. table_standard_list = delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict)
  455. # 上下表格合并
  456. area_row_list, area_col_list, table_location_list = merge_table(area_row_list, area_col_list, table_location_list, bbox_list)
  457. # 获取行线、列线
  458. area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  459. # 根据行列线生成对应bbox行列
  460. area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list)
  461. # 添加列线
  462. add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict)
  463. for j in range(len(area_col_lines)):
  464. area_col_lines[j] += add_area_col_lines[j]
  465. # 判断列线合法
  466. area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
  467. area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1)
  468. area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0)
  469. #
  470. # for a in area_col_list:
  471. # for c in a:
  472. # print('area_col_list', [bbox_text_dict.get(str(x)) for x in c])
  473. #
  474. # # 合并一列中多行bbox
  475. # area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list)
  476. #
  477. # # 获取行线、列线
  478. # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  479. #
  480. # add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict)
  481. #
  482. # for j in range(len(area_col_lines)):
  483. # area_col_lines[j] += add_area_col_lines[j]
  484. #
  485. # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list)
  486. #
  487. # area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1)
  488. # area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0)
  489. #
  490. #
  491. add_area_row_lines = add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines)
  492. for j in range(len(area_row_lines)):
  493. area_row_lines[j] += add_area_row_lines[j]
  494. #
  495. area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
  496. # 合并相近线
  497. for j in range(len(area_col_lines)):
  498. area_col_lines[j] = merge_lines(area_col_lines[j], axis=1)
  499. area_row_lines[j] = merge_lines(area_row_lines[j], axis=0)
  500. # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list)
  501. # area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list)
  502. # # 重新生成
  503. # table_location_list = []
  504. # temp_area_row_list = []
  505. # for temp_row_list in area_row_list:
  506. # location_list, temp_row_list = get_table_location(temp_row_list)
  507. # table_location_list += location_list
  508. # temp_area_row_list += temp_row_list
  509. # area_col_list = get_table_cols(bbox_list, table_location_list)
  510. # area_row_list = temp_area_row_list
  511. #
  512. # # 获取行线、列线
  513. # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  514. #
  515. # print('len(table_location_list)', len(table_location_list))
  516. # for bbox in bbox_list:
  517. # cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
  518. # (0, 0, 255), 1)
  519. #
  520. # for i in range(len(table_location_list)):
  521. # # location = table_location_list[i]
  522. # # cv2.rectangle(img, location[0], location[1], (0, 255, 0), 1)
  523. #
  524. # row_lines = area_row_lines[i]
  525. # col_lines = area_col_lines[i]
  526. # for r in row_lines:
  527. # cv2.line(img, r[0], r[1], (0, 255, 0), 1)
  528. # for c in col_lines:
  529. # cv2.line(img, c[0], c[1], (0, 255, 0), 1)
  530. #
  531. # cv2.imshow('img', img)
  532. # cv2.waitKey(0)
  533. # 计算标注表格和生成表格的相似度
  534. if len(table_location_list) == 1:
  535. # if not table_standard_list[0]:
  536. # continue
  537. row_lines = area_row_lines[0]
  538. col_lines = area_col_lines[0]
  539. row_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
  540. col_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
  541. for r in row_lines:
  542. cv2.line(row_img, r[0], r[1], (255, 255, 255), 1)
  543. for c in col_lines:
  544. cv2.line(col_img, c[0], c[1], (255, 255, 255), 1)
  545. row_list, is_standard = get_bbox_by_img(row_img, col_img)
  546. if not is_standard:
  547. continue
  548. row_list = merge_text_and_table(bbox_list, row_list)
  549. continue_flag = 0
  550. for row in row_list:
  551. for b in row:
  552. if len(b) > 1:
  553. continue_flag = 1
  554. break
  555. if continue_flag:
  556. continue
  557. max_len = 1
  558. continue_flag = 0
  559. for row in row_list:
  560. if abs(max_len - len(row)) > 2:
  561. continue_flag = 1
  562. break
  563. if len(row) > max_len:
  564. max_len = len(row)
  565. if continue_flag:
  566. continue
  567. img_label = cv2.imread(label_p[:-1])
  568. row_img1, col_img1 = get_lines_from_img(img_label)
  569. label_row_list, label_is_standard = get_bbox_by_img(row_img1, col_img1)
  570. if not label_is_standard:
  571. continue
  572. label_row_list = merge_text_and_table(bbox_list, label_row_list)
  573. add_flag = 0
  574. modify_flag = 0
  575. for i in range(len(row_list)):
  576. if i >= len(label_row_list):
  577. continue
  578. row = row_list[i]
  579. label_row = label_row_list[i]
  580. for r in label_row:
  581. if r not in row:
  582. add_flag += 1
  583. else:
  584. if label_row.index(r) != row.index(r):
  585. modify_flag += 1
  586. bbox_cnt = 0
  587. for row in row_list:
  588. for b in row:
  589. bbox_cnt += 1
  590. label_bbox_cnt = 0
  591. for row in label_row_list:
  592. for b in row:
  593. label_bbox_cnt += 1
  594. teds = 1 - (add_flag + modify_flag) / max(bbox_cnt, label_bbox_cnt)
  595. print('add_flag', add_flag, 'modify_flag', modify_flag, 'bbox_cnt', bbox_cnt, 'label_bbox_cnt', label_bbox_cnt)
  596. print('TEDS:', teds, p)
  597. all_teds += teds
  598. all_standard_cnt += 1
  599. # if teds <= 0.8:
  600. # print('row_list', [y for y in [x for x in row_list]])
  601. # print('label_row_list', [y for y in [x for x in label_row_list]])
  602. # cv2.imshow('model_table', row_img+col_img)
  603. # cv2.imshow('label_table', row_img1+col_img1)
  604. # cv2.waitKey(0)
  605. # for i in range(len(row_list)):
  606. try:
  607. avg_teds = all_teds / all_standard_cnt
  608. except:
  609. avg_teds = 0
  610. print('standard table cnt', all_standard_cnt)
  611. print('Avg TEDS', avg_teds)
  612. return
  613. def get_table_new():
  614. with open('standard_table.txt', 'r') as f:
  615. label_path_list = f.readlines()
  616. # 表格分割问题:1019.jpg, 1020.jpg, 1023.jpg, 1027.jpg, 1029.jpg, 1030.jpg, 1031.jpg, 1035.jpg, 1040.jpg, 1042.jpg, 1046.jpg, 1047.jpg, 1061.jpg, 1064.jpg, 1067.jpg, 1072.jpg
  617. # 分列问题:1059.jpg,
  618. paths = glob('../data/borderless_tables/*.jpg')
  619. # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png')
  620. paths = ['1019.jpg', '1020.jpg', '1023.jpg', '1027.jpg', '1029.jpg', '1030.jpg', '1031.jpg', '1035.jpg', '1040.jpg', '1042.jpg', '1046.jpg', '1047.jpg', '1061.jpg', '1064.jpg', '1067.jpg', '1072.jpg']
  621. paths = ['../data/borderless_tables/' + x for x in paths]
  622. path_cnt = 0
  623. for p in paths:
  624. if 'label' in p:
  625. continue
  626. # label_p = p[:-4] + '_label.jpg\n'
  627. # if label_p not in label_path_list:
  628. # continue
  629. # if path_cnt <= 22:
  630. # path_cnt += 1
  631. # continue
  632. path_cnt += 1
  633. img = cv2.imread(p)
  634. result = test_ocr_model(p)
  635. print(p)
  636. bbox_list = eval(result.get('bbox'))
  637. text_list = eval(result.get('text'))
  638. # 处理bbox,缩小框
  639. bbox_list = shrink_bbox(img, bbox_list)
  640. # 创建对应dict
  641. bbox_text_dict = {}
  642. for i in range(len(text_list)):
  643. bbox_text_dict[str(bbox_list[i])] = text_list[i]
  644. # 获取全局的按行排列bbox
  645. row_list = get_table_rows(bbox_list, bbox_text_dict)
  646. # 获取表格区域,以及区域里的按行排列bbox
  647. table_location_list, area_row_list = get_table_location(row_list, bbox_text_dict)
  648. area_row_list = merge_row_bbox_list(area_row_list)
  649. # for a in area_row_list:
  650. # i = 0
  651. # for r in a:
  652. # print('row', i)
  653. # i += 1
  654. # for b in r:
  655. # print(bbox_text_dict.get(str(b)))
  656. # 获取全局的按列排列bbox
  657. area_col_list = get_table_cols(bbox_list, table_location_list)
  658. # 获取行线、列线
  659. area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  660. # 判断列线合法
  661. area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
  662. # # 判断行线合法
  663. area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
  664. # 由线得到按行列排列的bbox
  665. area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list)
  666. for a in area_table_bbox_list:
  667. for r in a:
  668. for c in r:
  669. # cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1)
  670. for b in c:
  671. cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1)
  672. cv2.imshow('table_cell', img)
  673. # 分割表格
  674. # table_location_list, _ = split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict)
  675. # table_location_list, _ = split_table(table_location_list, area_row_list, bbox_text_dict)
  676. # table_location_list = split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict)
  677. # table_location_list = split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict)
  678. # 重新生成按行排列bbox
  679. area_row_list = get_table_rows2(area_row_list, table_location_list)
  680. # for a in area_row_list:
  681. # for r in a:
  682. # for b in r:
  683. # cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1)
  684. # cv2.imshow('area_row_list', img)
  685. # 获取全局的按列排列bbox
  686. area_col_list = get_table_cols(bbox_list, table_location_list)
  687. # 获取行线、列线
  688. area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
  689. # 判断列线合法
  690. area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
  691. # 判断行线合法
  692. area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
  693. # 展示
  694. show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list)
  695. return
  696. def show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list):
  697. for bbox in bbox_list:
  698. cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
  699. (0, 0, 255), 1)
  700. for i in range(len(table_location_list)):
  701. # location = table_location_list[i]
  702. # cv2.rectangle(img, location[0], location[1], (0, 255, 0), 1)
  703. row_lines = area_row_lines[i]
  704. col_lines = area_col_lines[i]
  705. for r in row_lines:
  706. cv2.line(img, r[0], r[1], (0, 255, 0), 1)
  707. for c in col_lines:
  708. cv2.line(img, c[0], c[1], (0, 255, 0), 1)
  709. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  710. cv2.imshow('img', img)
  711. cv2.waitKey(0)
  712. return
  713. def get_table_borders(area_row_list, area_col_list, table_location_list):
  714. area_row_lines = []
  715. area_col_lines = []
  716. # 循环每个表格
  717. for i in range(len(area_row_list)):
  718. row_list = area_row_list[i]
  719. col_list = area_col_list[i]
  720. location = table_location_list[i]
  721. # 获取行线
  722. row_lines = []
  723. row_lines.append([[location[0][0], location[0][1]], [location[1][0], location[0][1]]])
  724. for row in row_list:
  725. max_h = 0
  726. for bbox in row:
  727. if bbox[2][1] > max_h:
  728. max_h = int(bbox[2][1])
  729. row_lines.append([[location[0][0], max_h], [location[1][0], max_h]])
  730. row_lines[-1][0][1] = max(location[1][1], row_lines[-1][0][1])
  731. row_lines[-1][1][1] = max(location[1][1], row_lines[-1][1][1])
  732. # 补充表格行范围
  733. table_location_list[i][1][1] = max(location[1][1], row_lines[-1][1][1])
  734. location = table_location_list[i]
  735. # 获取列线
  736. col_lines = []
  737. col_lines.append([[location[0][0], location[0][1]], [location[0][0], location[1][1]]])
  738. for col in col_list:
  739. max_w = 0
  740. for bbox in col:
  741. if bbox[2][0] > max_w:
  742. max_w = int(bbox[2][0])
  743. col_lines.append([[max_w, location[0][1]], [max_w, location[1][1]]])
  744. # 补充表格列范围
  745. table_location_list[i][1][0] = max(location[1][0], col_lines[-1][1][0])
  746. location = table_location_list[i]
  747. for row in row_lines:
  748. row[0][0] = location[0][0]
  749. row[1][0] = location[1][0]
  750. area_row_lines.append(row_lines)
  751. area_col_lines.append(col_lines)
  752. return area_row_lines, area_col_lines
def get_table_location(row_list, bbox_text_dict):
    """Scan rows top-down and group runs of multi-bbox rows into tables.

    A row with at least two bboxes that passes the filters below counts as a
    table row and enlarges the running bounding box.  Rows with fewer than two
    bboxes are tolerated (up to 3 per run) once a table has started; the next
    such row closes the run, which is emitted if it has more than 2 table rows.

    NOTE(review): the nesting of this function was reconstructed from a
    flattened listing; confirm the placement of the ``continue`` statements
    and of the state reset against the original file.

    :param row_list: list of rows, each a list of bboxes (lists of [x, y]
                     points; points 0 and 2 are used as opposite corners)
    :param bbox_text_dict: maps ``str(bbox)`` to the bbox text
    :return: (table_location_list, area_row_list) where each location is
             [[left, top], [right, bottom]] as ints
    """
    # for r in row_list:
    # print('row', r)
    # Running bounding box of the current candidate table.
    up_h = 10000
    bottom_h = 0
    left_w = 10000
    right_w = 0
    table_rows = 0
    tolerance_list = []
    area_row_list = []
    temp_row_list = []
    table_location_list = []
    catalog_text_cnt = 0
    for row in row_list:
        if len(row) >= 2:
            if not temp_row_list:
                # In the first row the gap between bboxes must exceed a threshold.
                max_distance = 0
                row.sort(key=lambda x: x[0][0])
                row_text_list = []
                catalog_text_cnt = 0
                bbox_height_list = [abs(row[-1][0][1] - row[-1][2][1])]
                for i in range(1, len(row)):
                    # Horizontal gap between neighbouring bboxes.
                    dis = row[i][0][0] - row[i-1][2][0]
                    if dis >= max_distance:
                        max_distance = dis
                    # NOTE(review): `text` is None when the key is missing,
                    # which would make re.findall raise — confirm callers.
                    text = bbox_text_dict.get(str(row[i-1]))
                    row_text_list.append(bbox_text_dict.get(str(row[i-1])))
                    # Dots followed by digits covering the whole text looks
                    # like a table-of-contents entry.  The pattern is a
                    # non-raw string, so `\d` is an invalid escape sequence.
                    match = re.findall('\\.+\d+', text)
                    if match and len(match[0]) == len(text):
                        catalog_text_cnt += 1
                    bbox_height_list.append(abs(row[i][0][1] - row[i][2][1]))
                # Exclude
                # if len(row) == 2:
                # if max_distance <= abs(row[0][2][0] - row[0][0][0]):
                # continue
                # else:
                if max_distance <= 5:
                    continue
            # Exclude spaced-out two-character labels (e.g. 'address', 'name')
            # if len(row) == 2 and len(bbox_text_dict.get(str(row[0]))) == 1:
            # continue
            row_text_list = []
            bbox_height_list = []
            for i in range(len(row)):
                # row[i-1] starts at row[-1], so every bbox text is still
                # visited once; heights use row[i].
                text = bbox_text_dict.get(str(row[i-1]))
                row_text_list.append(bbox_text_dict.get(str(row[i-1])))
                match = re.findall('\\.+\d+', text)
                if match and len(match[0]) == len(text):
                    catalog_text_cnt += 1
                bbox_height_list.append(abs(row[i][0][1] - row[i][2][1]))
            # Exclude rows whose bbox heights differ too much.
            bbox_height_list.sort(key=lambda x: x)
            if bbox_height_list[-1] - bbox_height_list[0] > bbox_height_list[0]:
                continue
            # Exclude table-of-contents lines.
            if catalog_text_cnt >= 3:
                continue
            # Exclude watermark images (too many repeated texts in the row).
            if len(list(set(row_text_list))) < 2/3 * len(row):
                continue
            # Exclude rows with underlines (no check implemented here).
            table_rows += 1
            temp_row_list.append(row)
            # Grow the running bounding box.
            for bbox in row:
                if up_h > bbox[0][1]:
                    up_h = bbox[0][1]
                if bottom_h < bbox[2][1]:
                    bottom_h = bbox[2][1]
                if left_w > bbox[0][0]:
                    left_w = bbox[0][0]
                if right_w < bbox[2][0]:
                    right_w = bbox[2][0]
        else:
            # Tolerate up to 3 short rows inside a started table.
            if len(tolerance_list) < 3 and table_rows > 0:
                tolerance_list.append(row)
                temp_row_list.append(row)
                continue
            if table_rows > 2 and up_h < bottom_h:
                table_location_list.append([[int(left_w), int(up_h)],
                                            [int(right_w), int(bottom_h)]])
                # Drop a trailing tolerated row from the emitted table.
                if tolerance_list[-1] == temp_row_list[-1]:
                    area_row_list.append(temp_row_list[:-1])
                else:
                    area_row_list.append(temp_row_list)
            # Reset the state for the next candidate table.
            up_h = 10000
            bottom_h = 0
            left_w = 10000
            right_w = 0
            table_rows = 0
            tolerance_list = []
            temp_row_list = []
    # Emit a table that is still open at the end of the input.
    if temp_row_list:
        if table_rows > 2 and up_h < bottom_h:
            table_location_list.append([[int(left_w), int(up_h)],
                                        [int(right_w), int(bottom_h)]])
            area_row_list.append(temp_row_list)
    return table_location_list, area_row_list
  851. def get_table_rows(bbox_list, bbox_text_dict):
  852. bbox_list.sort(key=lambda x: (x[0][1], x[2][1], x[0][0], x[2][0]))
  853. row_list = []
  854. used_bbox_list = []
  855. for b1 in bbox_list:
  856. if b1 in used_bbox_list:
  857. continue
  858. temp_bbox_list = [b1]
  859. used_bbox_list.append(b1)
  860. for b2 in bbox_list:
  861. if b2 in used_bbox_list:
  862. continue
  863. if abs((b1[0][1] + b1[2][1]) / 2 - (b2[0][1] + b2[2][1]) / 2) <= 10 \
  864. and line_overlap(b1[0][1], b1[2][1], b2[0][1], b2[2][1]) >= 1/2*min(b1[2][1]-b1[0][1], b2[2][1]-b2[0][1]):
  865. temp_bbox_list.append(b2)
  866. used_bbox_list.append(b2)
  867. row_list.append(temp_bbox_list)
  868. return row_list
  869. def get_table_rows2(area_row_list, table_location_list):
  870. temp_area_row_list = []
  871. for area in area_row_list:
  872. temp_area_row_list += area
  873. area_row_list = []
  874. for location in table_location_list:
  875. row_list = []
  876. for row in temp_area_row_list:
  877. if location[0][1] <= row[0][0][1] <= row[0][2][1] <= location[1][1]:
  878. row_list.append(row)
  879. area_row_list.append(row_list)
  880. return area_row_list
  881. def get_table_bbox_row_or_col(bbox_list, axis=0):
  882. bbox_list.sort(key=lambda x: (x[0][1-axis], x[2][1-axis], x[0][axis], x[2][axis]))
  883. row_list = []
  884. used_bbox_list = []
  885. for b1 in bbox_list:
  886. if b1 in used_bbox_list:
  887. continue
  888. temp_bbox_list = [b1]
  889. used_bbox_list.append(b1)
  890. for b2 in bbox_list:
  891. if b2 in used_bbox_list:
  892. continue
  893. if abs((b1[0][1-axis] + b1[2][1-axis]) / 2 - (b2[0][1-axis] + b2[2][1-axis]) / 2) <= 10:
  894. temp_bbox_list.append(b2)
  895. used_bbox_list.append(b2)
  896. row_list.append(temp_bbox_list)
  897. return row_list
  898. def get_table_cols(bbox_list, table_location_list):
  899. bbox_list.sort(key=lambda x: (x[0][0], x[2][0], x[0][1], x[2][1]))
  900. all_col_list = []
  901. used_bbox_list = []
  902. for location in table_location_list:
  903. sub_bbox_list = []
  904. for b in bbox_list:
  905. if location[0][1] <= (b[0][1] + b[2][1])/2 <= location[1][1]:
  906. sub_bbox_list.append(b)
  907. col_list = []
  908. for b1 in sub_bbox_list:
  909. if b1 in used_bbox_list:
  910. continue
  911. col_width = [b1[0][0], b1[2][0]]
  912. temp_bbox_list = [b1]
  913. used_bbox_list.append(b1)
  914. for b2 in sub_bbox_list:
  915. if b2 in used_bbox_list:
  916. continue
  917. # 判断同一列
  918. # 1. 中心点相差一定范围内
  919. # 2. 左边点相差一定范围内
  920. # 3. 行范围包含
  921. # 4. iou大于一定值
  922. if abs((b1[0][0] + b1[2][0]) / 2 - (b2[0][0] + b2[2][0]) / 2) <= 10 \
  923. or abs(b1[0][0] - b2[0][0]) <= 10 \
  924. or col_width[0] <= b2[0][0] <= b2[2][0] <= col_width[1] \
  925. or b2[0][0] <= col_width[0] <= col_width[1] <= b2[2][0] \
  926. or line_iou([[col_width[0], 0], [col_width[1], 0]], [[b2[0][0], 0], [b2[1][0], 0]], axis=0) >= 0.6:
  927. temp_bbox_list.append(b2)
  928. used_bbox_list.append(b2)
  929. if b2[0][0] < col_width[0]:
  930. col_width[0] = b2[0][0]
  931. if b2[2][0] > col_width[1]:
  932. col_width[1] = b2[2][0]
  933. col_list.append(temp_bbox_list)
  934. all_col_list.append(col_list)
  935. return all_col_list
def merge_col_bbox_by_cluster(img, area_row_list, area_col_list, bbox_text_dict, all_bbox_list, table_location_list):
    """Merge vertically adjacent bboxes inside each column of each table.

    For every column the vertical gaps between consecutive bboxes are
    computed; gaps below 5 are clustered (``dbscan`` when more than two such
    gaps exist) and each run of consecutive small gaps is merged into one
    enclosing bbox.  ``all_bbox_list`` is modified in place: merged bboxes are
    removed and the enclosing bbox is appended.  Debug information is printed
    and, for the first column of each table, row lines are drawn on a copy of
    ``img`` and shown with ``cv2.imshow``.

    NOTE(review): nesting reconstructed from a flattened listing; confirm
    against the original file.  ``row``, ``new_col`` and ``t_list`` are
    assigned but never used.

    :param img: image, deep-copied for the debug drawing
    :param area_row_list: per table rows; only its length is used for looping
    :param area_col_list: per table, a list of columns (lists of bboxes)
    :param bbox_text_dict: maps ``str(bbox)`` to text (used for debug output)
    :param all_bbox_list: global bbox list, mutated and returned
    :param table_location_list: per table, [[left, top], [right, bottom]]
    :return: ``all_bbox_list``
    """
    temp_img = copy.deepcopy(img)
    # Loop over each table.
    for i in range(len(area_row_list)):
        row = area_row_list[i]
        col = area_col_list[i]
        # Loop over each column and compute the gaps between its rows.
        new_col = []
        col_cnt = 0
        for bbox_list in col:
            # Collect the vertical gaps between consecutive bboxes.
            distance_list = []
            bbox_list.sort(key=lambda x: (x[0][1], x[1][1]))
            text_list = [bbox_text_dict.get(str(x)) for x in bbox_list]
            for j in range(1, len(bbox_list)):
                dis = bbox_list[j][0][1] - bbox_list[j-1][2][1]
                # Overlapping bboxes count as a zero gap.
                if dis < 0:
                    dis = 0.
                distance_list.append(dis)
            print("\n")
            print("distance_list", distance_list)
            # Cluster the gaps to obtain groups.
            data_list = [[0, x] for x in distance_list]
            # Exclude gaps at or above the threshold; remember which were kept.
            data_mask_list = []
            temp_data_list = []
            for j in range(len(data_list)):
                if data_list[j][1] < 5.:
                    data_mask_list.append(True)
                    temp_data_list.append(data_list[j])
                else:
                    data_mask_list.append(False)
            data_list = temp_data_list
            print("data_list", data_list)
            cluster_list = []
            if len(data_list) > 2:
                # Cluster.
                pred_list = dbscan(data_list)
                print('pred_list', pred_list)
                # Re-expand the labels to the full gap list; excluded gaps
                # get the label -1.
                temp_pred_list = []
                for j in data_mask_list:
                    if j:
                        temp_pred_list.append(pred_list.pop(0))
                    else:
                        temp_pred_list.append(-1)
                pred_list = temp_pred_list
                print('pred_list', pred_list)
                cluster_num = len(list(set(pred_list)))
                # For each label, collect runs of consecutive gap indices.
                for k in range(cluster_num):
                    temp_list = []
                    for j in range(len(pred_list)):
                        if pred_list[j] == k:
                            if temp_list:
                                if j - temp_list[-1] == 1:
                                    temp_list.append(j)
                            else:
                                temp_list.append(j)
                        else:
                            if temp_list:
                                cluster_list.append(temp_list)
                                temp_list = []
                    if temp_list:
                        cluster_list.append(temp_list)
            elif len(data_list) > 0:
                # Too few small gaps to cluster: use runs of consecutive
                # gaps below the threshold directly.
                temp_list = []
                for j in range(len(distance_list)):
                    if distance_list[j] < 5.0:
                        temp_list.append(j)
                    else:
                        if temp_list:
                            cluster_list.append(temp_list)
                            temp_list = []
                if temp_list:
                    cluster_list.append(temp_list)
            # cluster_list.append([x for x in range(len(distance_list))])
            print('text_list', text_list)
            print('cluster_list', cluster_list)
            # Merge bboxes.
            new_bbox_list = copy.deepcopy(bbox_list)
            for cluster in cluster_list:
                # Only merge when every gap of the cluster is below 5.
                merge_flag = 1
                for dis in [distance_list[x] for x in cluster]:
                    if dis >= 5.0:
                        merge_flag = 0
                        break
                if merge_flag:
                    # Gap indices k..m correspond to bboxes k..m+1.
                    b_list = bbox_list[cluster[0]:cluster[-1]+2]
                    t_list = text_list[cluster[0]:cluster[-1]+2]
                    min_w = 10000
                    max_w = 0
                    min_h = 10000
                    max_h = 0
                    # De-duplicate via the string form.  NOTE(review): this
                    # uses eval() on str(bbox); it relies on the repr being
                    # a valid expression and should not see untrusted data.
                    b_list = [eval(x) for x in list(set([str(x) for x in b_list]))]
                    for bbox in b_list:
                        if bbox in new_bbox_list:
                            new_bbox_list.remove(bbox)
                        if bbox in all_bbox_list:
                            all_bbox_list.remove(bbox)
                        if bbox[0][0] < min_w:
                            min_w = bbox[0][0]
                        if bbox[0][1] < min_h:
                            min_h = bbox[0][1]
                        if bbox[2][0] > max_w:
                            max_w = bbox[2][0]
                        if bbox[2][1] > max_h:
                            max_h = bbox[2][1]
                    # Enclosing bbox as four corner points.
                    new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]]
                    new_bbox_list.append(new_bbox)
                    all_bbox_list.append(new_bbox)
            # Use the merge result of the first column to guide the other
            # columns (currently only drawn for inspection).
            if col_cnt == 0:
                first_col_rows = get_first_col_rows(new_bbox_list, table_location_list[i])
                for r in first_col_rows:
                    cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 0, 255), 1)
                cv2.imshow('temp_img', temp_img)
                # cv2.waitKey(0)
            col_cnt += 1
            # new_col.append(new_bbox_list)
    return all_bbox_list
def merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list):
    """Rebuild the rows of each table from the blocks of its first column.

    The first column is clustered and merged (``distance_cluster`` /
    ``merge_cluster``), horizontal row boundaries are derived from it
    (``get_first_col_rows``), and every bbox of the table is then assigned to
    the band between two consecutive boundaries.  Bboxes that straddle a
    boundary are assigned by comparing their vertical centre with the nearest
    first-column bboxes above and below, or carried over to the next row.
    ``area_row_list`` is updated in place and returned.  Debug information is
    printed.

    NOTE(review): nesting reconstructed from a flattened listing; confirm
    against the original file.  ``temp_img``, ``row_list``, ``r_cnt`` and
    ``first_col_bbox2`` are assigned but not otherwise used.

    :param img: image (only deep-copied here)
    :param area_row_list: per table rows; replaced per table by the new rows
    :param area_col_list: per table, a list of columns (lists of bboxes)
    :param bbox_text_dict: maps ``str(bbox)`` to text (used for debug output)
    :param bbox_list: all bboxes
    :param table_location_list: per table, [[left, top], [right, bottom]]
    :return: ``area_row_list``
    """
    temp_img = copy.deepcopy(img)
    # Loop over each table.
    for i in range(len(area_row_list)):
        row_list = area_row_list[i]
        col_list = area_col_list[i]
        table_location = table_location_list[i]
        # Keep the bboxes whose point 0 or point 1 lies vertically inside the table.
        sub_bbox_list = []
        for bbox in bbox_list:
            if table_location[0][1] <= bbox[0][1] <= table_location[1][1] \
                    or table_location[0][1] <= bbox[1][1] <= table_location[1][1]:
                sub_bbox_list.append(bbox)
        # Cluster and merge the first column, then split rows by blank space.
        first_col = col_list[0]
        cluster_list, distance_list = distance_cluster(first_col, axis=1)
        merge_first_col = merge_cluster(first_col, cluster_list, distance_list)
        merge_first_col.sort(key=lambda x: (x[0][1], x[0][0]))
        row_lines = get_first_col_rows(merge_first_col, table_location)
        # Cluster and merge the other columns (disabled).
        # merge_bbox_list = [] + first_col
        # for col in col_list[1:]:
        # cluster_list = distance_cluster(col, axis=1)
        # merge_col = merge_cluster(col, cluster_list)
        # merge_bbox_list += merge_col
        # Loop over the row bands and assign bboxes to them.
        new_row_list = []
        row_lines.sort(key=lambda x: x)
        row_cnt = 0
        need_add_bbox = []
        # for c in first_col:
        # print('first col ', bbox_text_dict.get(str(c)))
        for j in range(1, len(row_lines)):
            print('\n')
            top_line = row_lines[j-1]
            bottom_line = row_lines[j]
            new_row = []
            # Bboxes deferred from the previous band start this row.
            if need_add_bbox:
                # print('add')
                new_row += need_add_bbox
                print('add', bbox_text_dict.get(str(new_row[0])))
                need_add_bbox = []
            # Assignment conditions:
            # 1. the bbox is fully contained in the band;
            # 2. the bbox straddles a boundary: compare its vertical distance
            #    to the first-column bboxes of the two rows.
            for bbox in sub_bbox_list:
                if top_line <= bbox[0][1] <= bbox[2][1] <= bottom_line:
                    new_row.append(bbox)
                    # print('bbox, line', bbox_text_dict.get(str(bbox)), top_line, bottom_line)
                else:
                    if bbox in first_col:
                        continue
                    # If the first column has a single row, bboxes on the
                    # boundary are deferred.
                    if len(first_col) == 1:
                        need_add_bbox.append(bbox)
                        continue
                    # Find the nearest first-column bboxes above and below
                    # this bbox (by vertical centre).
                    first_col_center_h1 = 0
                    first_col_center_h2 = 10000
                    first_col_bbox1 = None
                    first_col_bbox2 = None
                    bbox_center_h = (bbox[0][1] + bbox[2][1]) / 2
                    for b in first_col:
                        b_center_h = (b[0][1] + b[2][1]) / 2
                        # if bbox[0][1] <= b_center_h <= bbox[2][1]:
                        # first_col_center_h2 = b_center_h
                        # break
                        # Closest centre at or above the bbox centre.
                        if bbox_center_h >= b_center_h and bbox_center_h - b_center_h <= bbox_center_h - first_col_center_h1:
                            first_col_center_h1 = b_center_h
                            first_col_bbox1 = b
                        # Closest centre at or below the bbox centre.
                        if b_center_h >= bbox_center_h and b_center_h - bbox_center_h <= first_col_center_h2 - bbox_center_h:
                            first_col_center_h2 = b_center_h
                            first_col_bbox2 = b
                    # The nearest first-column bbox above does not belong to
                    # this row: defer the bbox to the next row.
                    if new_row and first_col_bbox1 != new_row[0] and top_line < bbox[0][1] < bottom_line:
                        need_add_bbox.append(bbox)
                        continue
                    # if top_line <= bbox[2][1] <= bottom_line \
                    # and abs(first_col_center_h1 - bbox_center_h) >= abs(first_col_center_h2 - bbox_center_h):
                    # new_row.append(bbox)
                    # if first_col_bbox1 and first_col_bbox2:
                    # print('bbox1, bbox2', bbox_text_dict[str(first_col_bbox1)], bbox_text_dict[str(first_col_bbox2)],
                    # bbox_text_dict[str(bbox)])
                    # The bbox starts inside the band: keep it here when the
                    # upper first-column bbox is at least as close as the
                    # lower one, otherwise defer it.
                    if top_line < bbox[0][1] < bottom_line \
                            and abs(first_col_center_h1 - bbox_center_h) <= abs(first_col_center_h2 - bbox_center_h):
                        new_row.append(bbox)
                    elif top_line < bbox[0][1] < bottom_line:
                        need_add_bbox.append(bbox)
            for r in need_add_bbox:
                print("next_row bbox", bbox_text_dict.get(str(r)))
            print('row', row_cnt, len(new_row))
            for b in new_row:
                print(bbox_text_dict.get(str(b)))
            row_cnt += 1
            new_row_list.append(new_row)
        area_row_list[i] = new_row_list
        # show
        r_cnt = 0
        # for r in row_lines:
        # if r_cnt == 0 or r_cnt == len(row_lines) - 1:
        # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (255, 0, 0), 1)
        # else:
        # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 255, 0), 1)
        # r_cnt += 1
        # for b in merge_bbox_list:
        # cv2.rectangle(temp_img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (0, 0, 255), 1)
        # cv2.imshow('temp_img', temp_img)
    return area_row_list
  1162. def distance_cluster(bbox_list, max_distance=5., axis=1):
  1163. # 获取间隔距离
  1164. distance_list = []
  1165. bbox_list.sort(key=lambda x: (x[0][1], x[1][1]))
  1166. for j in range(1, len(bbox_list)):
  1167. dis = bbox_list[j][0][axis] - bbox_list[j-1][2][axis]
  1168. if dis < 0:
  1169. dis = 0.
  1170. distance_list.append(dis)
  1171. print("\n")
  1172. print("distance_list", distance_list)
  1173. # 聚类获取类别组
  1174. data_list = [[0, x] for x in distance_list]
  1175. # 排除距离大于一定值的
  1176. data_mask_list = []
  1177. temp_data_list = []
  1178. for j in range(len(data_list)):
  1179. if data_list[j][1] < max_distance:
  1180. data_mask_list.append(True)
  1181. temp_data_list.append(data_list[j])
  1182. else:
  1183. data_mask_list.append(False)
  1184. data_list = temp_data_list
  1185. print("data_list", data_list)
  1186. cluster_list = []
  1187. if len(data_list) > 2:
  1188. # 聚类
  1189. pred_list = dbscan(data_list)
  1190. print('pred_list', pred_list)
  1191. temp_pred_list = []
  1192. for j in data_mask_list:
  1193. if j:
  1194. temp_pred_list.append(pred_list.pop(0))
  1195. else:
  1196. temp_pred_list.append(-1)
  1197. pred_list = temp_pred_list
  1198. print('pred_list', pred_list)
  1199. cluster_num = len(list(set(pred_list)))
  1200. for k in range(cluster_num):
  1201. temp_list = []
  1202. for j in range(len(pred_list)):
  1203. if pred_list[j] == k:
  1204. if temp_list:
  1205. if j - temp_list[-1] == 1:
  1206. temp_list.append(j)
  1207. else:
  1208. temp_list.append(j)
  1209. else:
  1210. if temp_list:
  1211. cluster_list.append(temp_list)
  1212. temp_list = []
  1213. if temp_list:
  1214. cluster_list.append(temp_list)
  1215. elif len(data_list) > 0:
  1216. temp_list = []
  1217. for j in range(len(distance_list)):
  1218. if distance_list[j] < max_distance:
  1219. temp_list.append(j)
  1220. else:
  1221. if temp_list:
  1222. cluster_list.append(temp_list)
  1223. temp_list = []
  1224. if temp_list:
  1225. cluster_list.append(temp_list)
  1226. print('cluster_list', cluster_list)
  1227. return cluster_list, distance_list
  1228. def merge_cluster(bbox_list, cluster_list, distance_list):
  1229. new_bbox_list = copy.deepcopy(bbox_list)
  1230. # 特殊情况:每行之间空隙小,且规律,会全被分到一个类
  1231. if len(cluster_list) == 1 and len(cluster_list[0]) >= 4:
  1232. cluster_list = [[x] for x in cluster_list[0]]
  1233. # 每行的空隙小且均匀
  1234. if distance_list:
  1235. if max(distance_list) - min(distance_list) <= 5.5:
  1236. cluster_list = [[i] for i in range(len(distance_list))]
  1237. # 去掉一个最大值,其他的空隙小且均匀
  1238. if distance_list and max(distance_list) - min(distance_list) >= 10:
  1239. index = distance_list.index(max(distance_list))
  1240. if index <= 2 and len(distance_list[index+1:]) >= 3 and max(distance_list[index+1:]) - min(distance_list[index+1:]) <= 5.5:
  1241. if index == 0:
  1242. cluster_list = [[i] for i in range(len(distance_list[index+1:]))]
  1243. else:
  1244. if max(distance_list[:index]) - min(distance_list[:index]) <= 5.5:
  1245. cluster_list = [[i] for i in range(len(distance_list[:index]))]
  1246. cluster_list += [[i] for i in range(len(distance_list[index+1:]))]
  1247. for cluster in cluster_list:
  1248. b_list = bbox_list[cluster[0]:cluster[-1]+2]
  1249. min_w = 10000
  1250. max_w = 0
  1251. min_h = 10000
  1252. max_h = 0
  1253. b_list = [eval(x) for x in list(set([str(x) for x in b_list]))]
  1254. for bbox in b_list:
  1255. if bbox in new_bbox_list:
  1256. new_bbox_list.remove(bbox)
  1257. if bbox[0][0] < min_w:
  1258. min_w = bbox[0][0]
  1259. if bbox[0][1] < min_h:
  1260. min_h = bbox[0][1]
  1261. if bbox[2][0] > max_w:
  1262. max_w = bbox[2][0]
  1263. if bbox[2][1] > max_h:
  1264. max_h = bbox[2][1]
  1265. new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]]
  1266. new_bbox_list.append(new_bbox)
  1267. return new_bbox_list
  1268. def get_first_col_rows(first_col, table_location):
  1269. """
  1270. 根据第一列的bbox,分行
  1271. :return:
  1272. """
  1273. location_top = table_location[0][1]
  1274. location_bottom = table_location[1][1]
  1275. row_block_list = [table_location[0][1]]
  1276. for i in range(len(first_col)):
  1277. bbox = first_col[i]
  1278. if i + 1 < len(first_col):
  1279. next_bbox = first_col[i+1]
  1280. bbox_distance = abs(bbox[2][1] - next_bbox[0][1])
  1281. else:
  1282. bbox_distance = 10000
  1283. if i == 0:
  1284. top_block = abs(bbox[0][1] - location_top)
  1285. bottom_block = min(top_block, bbox_distance)
  1286. sub_row = bbox[2][1] + bottom_block
  1287. else:
  1288. top_block = abs(bbox[0][1] - row_block_list[-1])
  1289. bottom_block = min(top_block, bbox_distance)
  1290. sub_row = bbox[2][1] + bottom_block
  1291. row_block_list.append(sub_row)
  1292. if len(row_block_list) == 2:
  1293. row_block_list.append(location_bottom)
  1294. else:
  1295. row_block_list[-1] = max(row_block_list[-1], location_bottom)
  1296. return row_block_list
def judge_standard_table(row_list):
    """Find runs of rows that all have the same number (>= 2) of bboxes.

    Rows with at least two bboxes extend the current run as long as their
    bbox count equals ``now_row_len`` (set by the first such row and never
    reset afterwards).  A row with fewer than two bboxes is tolerated once per
    run; the next one closes the run, which is emitted when it has more than
    one table row.

    NOTE(review): nesting reconstructed from a flattened listing; confirm
    where ``init_flag = 1`` and ``continue`` sit in the original file.  As
    written here, a run still open when ``row_list`` ends is not emitted.

    :param row_list: list of rows, each a list of bboxes (lists of [x, y]
                     points; points 0 and 2 are used as opposite corners)
    :return: (table_location_list, area_row_list) where each location is
             [[left, top], [right, bottom]] as ints
    """
    # Running bounding box of the current run.
    up_h = 10000
    bottom_h = 0
    left_w = 10000
    right_w = 0
    table_rows = 0
    now_row_len = 0
    init_flag = 0
    tolerance_list = []
    area_row_list = []
    temp_row_list = []
    table_location_list = []
    for row in row_list:
        # A reset was requested by the previous iteration.
        if init_flag:
            up_h = 10000
            bottom_h = 0
            left_w = 10000
            right_w = 0
            table_rows = 0
            tolerance_list = []
            temp_row_list = []
            init_flag = 0
        if len(row) >= 2:
            if now_row_len == 0:
                now_row_len = len(row)
            else:
                # A different bbox count ends the run without emitting it.
                if len(row) != now_row_len:
                    init_flag = 1
                    continue
            table_rows += 1
            temp_row_list.append(row)
            # Grow the running bounding box.
            for bbox in row:
                if up_h > bbox[0][1]:
                    up_h = bbox[0][1]
                if bottom_h < bbox[2][1]:
                    bottom_h = bbox[2][1]
                if left_w > bbox[0][0]:
                    left_w = bbox[0][0]
                if right_w < bbox[2][0]:
                    right_w = bbox[2][0]
        else:
            # Tolerate a single short row inside a started run.
            if len(tolerance_list) < 1 and table_rows > 0:
                tolerance_list.append(row)
                temp_row_list.append(row)
                continue
            if table_rows > 1 and up_h < bottom_h:
                table_location_list.append([[int(left_w), int(up_h)],
                                            [int(right_w), int(bottom_h)]])
                # Drop a trailing tolerated row from the emitted table.
                if tolerance_list[-1] == temp_row_list[-1]:
                    area_row_list.append(temp_row_list[:-1])
                else:
                    area_row_list.append(temp_row_list)
            init_flag = 1
    return table_location_list, area_row_list
def split_bbox(img, bbox, bbox_text_dict):
    """Split one bbox horizontally at wide blank gaps and split its text too.

    The image crop of the bbox is scanned column by column to find
    transitions between "dark" and "blank" pixel columns.  Pairs of
    neighbouring transitions that are at least 10 apart with almost no dark
    pixels between them become split gaps.  The text is then distributed over
    the resulting bboxes in proportion to their widths, and
    ``bbox_text_dict`` is updated in place.  Debug lines are drawn on ``img``
    and shown with ``cv2.imshow`` (blocks on a key press).

    NOTE(review): nesting reconstructed from a flattened listing; confirm
    whether ``last_l = l`` and ``last_line = line`` are inside or after their
    ``if``.  ``split_line_list[0]`` raises IndexError when no transition is
    found, and ``len(text)`` must be non-zero — confirm callers guarantee it.

    :param img: image indexed as img[y, x, channel]
    :param bbox: bbox as a list of [x, y] points (points 0 and 2 used)
    :param bbox_text_dict: maps ``str(bbox)`` to text; updated with new bboxes
    :return: (split_bbox_list, bbox_text_dict); new bboxes have the form
             [[x0, y0], [], [x1, y1], []]
    """
    text = bbox_text_dict.get(str(bbox))
    sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
    split_line_list = []
    last_i_status = 1
    # Walk over the crop from left to right.
    for i in range(1, sub_img.shape[1]):
        # The column counts as dark (status 0) when the number of values
        # below 200 exceeds one fifth of the crop height.
        if np.where(sub_img[:, i, :] < 200)[0].size > sub_img.shape[0]/5:
            i_status = 0
        else:
            i_status = 1
        # XOR: the status changed between the previous and this column.
        if last_i_status ^ i_status:
            split_line_list.append(int(i))
            last_i_status = i_status
    # Drop split lines that are too close to the previous one.
    min_len = 5
    last_l = split_line_list[0]
    temp_list = [split_line_list[0]]
    for l in split_line_list[1:]:
        if l - last_l > min_len:
            temp_list.append(l)
        last_l = l
    split_line_list = temp_list
    # Two neighbouring split lines with (almost) no dark pixels in between
    # mark a gap where the bbox should be split.
    split_pair_list = []
    last_line = split_line_list[0]
    for line in split_line_list[1:]:
        print('last_line, line', last_line, line, np.where(sub_img[:, last_line:line, :] < 100)[0].size)
        if line - last_line >= 10 and np.where(sub_img[:, last_line:line, :] < 100)[0].size < 10:
            split_pair_list.append([last_line, line])
        last_line = line
    print('split_pair_list', split_pair_list)
    # Debug drawing: split lines and the bbox outline.
    for l in split_line_list:
        l = int(l + bbox[0][0])
        cv2.line(img, (l, int(bbox[0][1])), (l, int(bbox[2][1])), (0, 255, 0), 2)
    cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
                  (0, 0, 255), 1)
    cv2.imshow('img', img)
    cv2.waitKey(0)
    # Build the new bboxes from the gaps.
    split_bbox_list = []
    if split_pair_list:
        start_line = 0
        for line1, line2 in split_pair_list:
            # Piece from the previous gap end to this gap start.
            w1 = start_line + bbox[0][0]
            w2 = line1 + bbox[0][0]
            start_line = line2
            split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])
        # Remaining piece up to the right edge of the bbox.
        w1 = start_line + bbox[0][0]
        w2 = bbox[2][0]
        split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])
    print('split_bbox_list', split_bbox_list)
    # Estimate the width of a single character.
    all_len = 0
    bbox_len_list = []
    for bbox in split_bbox_list:
        _len = abs(bbox[2][0] - bbox[0][0])
        all_len += _len
        bbox_len_list.append(_len)
    single_char_len = all_len / len(text)
    # Cut the text according to the ratio of bbox width to character width.
    split_text_list = []
    text_start = 0
    for _len in bbox_len_list:
        text_num = int(_len / single_char_len + 0.5)
        text_end = text_start+text_num
        if text_end >= len(text):
            text_end = len(text)
        split_text_list.append(text[text_start:text_end])
        text_start = text_end
    print('split_text_list', split_text_list)
    # Update bbox_text_dict with the new bboxes.
    for i, bbox in enumerate(split_bbox_list):
        bbox_text_dict[str(bbox)] = split_text_list[i]
    return split_bbox_list, bbox_text_dict
  1428. def split_table(table_location_list, area_row_list, bbox_text_dict):
  1429. temp_location_list = []
  1430. temp_area_row_list = []
  1431. for i in range(len(table_location_list)):
  1432. location = table_location_list[i]
  1433. sub_row_list = area_row_list[i]
  1434. # 截断标题,对只有行中间或行开头一个bbox的行进行排除
  1435. need_split_index = []
  1436. for j in range(len(sub_row_list)):
  1437. row = sub_row_list[j]
  1438. if len(row) == 1:
  1439. threshold = (row[0][2][0]-row[0][0][0])*1/3
  1440. if row[0][0][0] + threshold <= (location[0][0]+location[1][0])/2 <= row[0][2][0] - threshold:
  1441. need_split_index.append(j)
  1442. elif abs(location[0][0] - row[0][0][0]) <= 5 \
  1443. and row[0][2][0] - row[0][0][0] >= 1/5 * (location[1][0]-location[0][0]):
  1444. need_split_index.append(j)
  1445. print('need_split_index', need_split_index)
  1446. if not need_split_index:
  1447. temp_location_list.append(location)
  1448. temp_area_row_list.append(sub_row_list)
  1449. else:
  1450. last_index = 0
  1451. need_split_index.append(len(sub_row_list))
  1452. for index in need_split_index:
  1453. if index == last_index:
  1454. last_index = index + 1
  1455. continue
  1456. if len(sub_row_list[last_index:index]) < 2:
  1457. last_index = index + 1
  1458. continue
  1459. temp_area_row_list.append(sub_row_list[last_index:index])
  1460. min_w = 10000
  1461. max_w = 0
  1462. min_h = 10000
  1463. max_h = 0
  1464. for row in sub_row_list[last_index:index]:
  1465. for bbox in row:
  1466. if bbox[0][0] < min_w:
  1467. min_w = bbox[0][0]
  1468. if bbox[0][1] < min_h:
  1469. min_h = bbox[0][1]
  1470. if bbox[2][0] > max_w:
  1471. max_w = bbox[2][0]
  1472. if bbox[2][1] > max_h:
  1473. max_h = bbox[2][1]
  1474. temp_location_list.append([[int(min_w), int(min_h)], [int(max_w), int(max_h)]])
  1475. last_index = index+1
  1476. return temp_location_list, temp_area_row_list
def split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict):
    """Detect rows where a cell gains bboxes compared with the row above.

    For each table, row ``j`` is recorded as a split candidate when at least
    one of its cells holds more bboxes than the same cell of row ``j - 1``
    and at least one such cell was non-empty in row ``j - 1``.  The candidate
    indices are only printed; ``table_location_list`` is returned unchanged.

    NOTE(review): nesting reconstructed from a flattened listing; confirm the
    placement of ``diff_num += 1`` and of the print.  ``location`` and
    ``bbox_text_dict`` are not used.

    :param table_location_list: per table location
    :param area_table_bbox_list: per table, rows of cells, each cell a list
                                 of bboxes
    :param bbox_text_dict: unused here
    :return: ``table_location_list``
    """
    for i in range(len(table_location_list)):
        location = table_location_list[i]
        table_bbox_list = area_table_bbox_list[i]
        # Loop over each row, comparing it with the previous one.
        split_index_list = []
        for j in range(1, len(table_bbox_list)):
            row = table_bbox_list[j]
            last_row = table_bbox_list[j-1]
            # Number of bboxes per cell.
            row_bbox_cnt_list = [len(x) for x in row]
            last_row_bbox_cnt_list = [len(x) for x in last_row]
            diff_num = 0
            diff_flag = 0
            for k in range(len(row_bbox_cnt_list)):
                if row_bbox_cnt_list[k] > last_row_bbox_cnt_list[k]:
                    # Only a previously non-empty cell sets the flag.
                    if last_row_bbox_cnt_list[k] != 0:
                        diff_flag = 1
                    diff_num += 1
            if diff_num > 0 and diff_flag:
                split_index_list.append(j)
                continue
        print('split_index_list', split_index_list)
    return table_location_list
  1500. # def split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict):
  1501. # new_table_location_list = []
  1502. # for i in range(len(table_location_list)):
  1503. # location = table_location_list[i]
  1504. # table_bbox_list = area_table_bbox_list[i]
  1505. #
  1506. # # 每行单独进行表头预测
  1507. # table_head_row_list = []
  1508. # for j in range(len(table_bbox_list)):
  1509. # row = table_bbox_list[j]
  1510. # print('row', row)
  1511. #
  1512. # if row.count([]) == len(row):
  1513. # table_head_row_list.append([['', 0]])
  1514. # continue
  1515. #
  1516. # row_bbox_list = []
  1517. # for col in row:
  1518. # for b in col:
  1519. # new_b = bbox_text_dict.get(str(b))
  1520. # new_b = re.sub("^[^\u4e00-\u9fa5a-zA-Z0-9]+", "", new_b)
  1521. # new_b = re.sub("[^\u4e00-\u9fa5a-zA-Z0-9]+$", "", new_b)
  1522. # row_bbox_list.append(new_b)
  1523. # result_list = predict([row_bbox_list])
  1524. # # 组合结果
  1525. # for m in range(len(result_list)):
  1526. # for n in range(len(result_list[m])):
  1527. # result_list[m][n] = [row_bbox_list[n], int(result_list[m][n])]
  1528. # result_list = result_list[0]
  1529. # print('table_head', result_list)
  1530. # table_head_row_list.append(result_list)
  1531. #
  1532. # # 根据表头分割
  1533. # split_index_list = []
  1534. # for j in range(1, len(table_head_row_list)):
  1535. # row_head = [x[1] for x in table_head_row_list[j]]
  1536. # last_row_head = [x[1] for x in table_head_row_list[j-1]]
  1537. #
  1538. # # [['6', 0], ['税费', 0], ['依法缴纳', 0], ['1', 0], ['次', 0], ['25000', 0], ['25000', 0]]
  1539. # # [['大写', 1], ['肆抢柒万元整', 0]]
  1540. # if 1 in row_head and 1 not in last_row_head:
  1541. # split_index_list.append(j)
  1542. #
  1543. # # [['供应商', 1], ['广东一线达通网络科技有限公司', 0]]
  1544. # # [['货物明细', 1], ['单价金额(元', 1], ['数量', 1], ['总计金额(元', 1]]
  1545. # if 1 in row_head and 1 in last_row_head and 0 not in row_head and row_head.count(1) != last_row_head.count(1):
  1546. # split_index_list.append(j)
  1547. # print('split_index_list', split_index_list)
  1548. #
  1549. # new_location_list = table_split_by_index(location, split_index_list, table_bbox_list)
  1550. # print('new_location_list, location', new_location_list, location)
  1551. # new_table_location_list += new_location_list
  1552. # print('new_table_location_list', new_table_location_list)
  1553. # return new_table_location_list
  1554. def table_split_by_index(table_location, split_index_list, table_bbox_list):
  1555. if split_index_list:
  1556. # 分割表格
  1557. split_index_list = [0] + split_index_list + [len(table_bbox_list)]
  1558. split_index_list = list(set(split_index_list))
  1559. split_index_list.sort(key=lambda x: x)
  1560. print('split_index_list', split_index_list)
  1561. new_location_list = []
  1562. for l in range(1, len(split_index_list)):
  1563. index = split_index_list[l]
  1564. last_index = split_index_list[l-1]
  1565. # if index - last_index <= 2:
  1566. # continue
  1567. # 获取范围
  1568. rows = table_bbox_list[last_index:index]
  1569. min_h, min_w = 10000, 10000
  1570. max_h, max_w = 0, 0
  1571. for r in rows:
  1572. for c in r:
  1573. for b in c:
  1574. if b:
  1575. if b[0][0] < min_w:
  1576. min_w = int(b[0][0])
  1577. if b[0][1] < min_h:
  1578. min_h = int(b[0][1])
  1579. if b[2][0] > max_w:
  1580. max_w = int(b[2][0])
  1581. if b[2][1] > max_h:
  1582. max_h = int(b[2][1])
  1583. new_location = [[min_w, min_h], [max_w, max_h]]
  1584. new_location_list.append(new_location)
  1585. print('new_location', new_location)
  1586. if new_location_list:
  1587. return new_location_list
  1588. else:
  1589. return [table_location]
  1590. else:
  1591. return [table_location]
def split_table_new(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict):
    """Split each table at rows that look like separators between two tables.

    A row becomes a split row when it holds exactly one bbox and either

    1. the longest bbox in that row's only occupied column contains, or is
       crossed by, at least two side-by-side bboxes found in the same column
       of a neighbouring row (up to 3 rows before and after), or
    2. that longest bbox reaches more than 15 units outside its own cell.

    The rows between two consecutive split rows form a new table; pieces that
    are too short are dropped. Tables without any split row are kept as-is.

    :param table_location_list: one ``[[x0, y0], [x1, y1]]`` location per table
    :param area_table_bbox_list: per table, rows -> columns -> bboxes
    :param area_table_cell_list: per table, rows -> cells; only ``cell[0][0]``
        and ``cell[1][0]`` are read here (presumably left/right x - confirm)
    :param area_row_list: per table, the row-wise bbox lists to slice
    :param bbox_text_dict: maps ``str(bbox)`` to its text (debug printing only)
    :return: ``(table_location_list, area_row_list)`` for the resulting tables
    """
    temp_location_list = []
    temp_area_row_list = []
    for k in range(len(table_location_list)):
        table = area_table_bbox_list[k]
        location = table_location_list[k]
        row_list = area_row_list[k]
        table_cell_list = area_table_cell_list[k]
        split_row_index_list = []
        # Walk over every row of the table.
        for i in range(len(table)):
            row = table[i]
            # print('row', i)
            # for j in range(len(row)):
            #     col = row[j]
            #     print('col', j, ';'.join([bbox_text_dict.get(str(x)) for x in col]))
            # A row is a table-splitting row when:
            # 1. only one column of the row has a value, and the longest bbox of that
            #    column contains 2 or more bboxes of the same column in other rows
            # 2. only one column of the row has a value, and the longest bbox of that
            #    column spans several columns
            # Collect up to n rows before and after the current row.
            n = 3
            if i-n < 0:
                last_n_rows = table[0:i]
            else:
                last_n_rows = table[i-n:i]
            if i+1 >= len(table):
                next_n_rows = []
            elif i+n+1 >= len(table):
                next_n_rows = table[i+1:len(table)]
            else:
                next_n_rows = table[i+1:i+n+1]
            # Look for a row with data in a single cell. Note the counter adds up
            # bboxes (len(col)), so the test below means "exactly one bbox in the row".
            not_empty_col_cnt = 0
            only_one_index = -1
            for j in range(len(row)):
                col = row[j]
                if col:
                    not_empty_col_cnt += len(col)
                    only_one_index = j
            if not_empty_col_cnt == 1:
                print('only_one_index, i', only_one_index, i)
                # Compare against the same column in the n rows before and after.
                for r in last_n_rows+next_n_rows:
                    col = r[only_one_index]
                    if len(col) > 1:
                        print('col', [bbox_text_dict.get(str(x)) for x in col])
                        # Find side-by-side bboxes held by the same cell of the other row.
                        # range() is evaluated once, so entries appended inside the inner
                        # loop are not visited during the same pass.
                        col_bbox_list = [col[0]]
                        for bbox in col:
                            for j in range(len(col_bbox_list)):
                                bbox1 = col_bbox_list[j]
                                if bbox1[0][0] <= bbox[0][0] <= bbox[2][0] <= bbox1[2][0]:
                                    col_bbox_list[j] = bbox
                                elif bbox[0][0] <= bbox1[0][0] <= bbox1[2][0] <= bbox[2][0]:
                                    continue
                                else:
                                    col_bbox_list.append(bbox)
                        if len(col_bbox_list) > 1:
                            # Check whether the longest bbox of this row's column contains
                            # several bboxes of the same column in the other row.
                            col = row[only_one_index]
                            print('long col', [bbox_text_dict.get(str(x)) for x in col])
                            col.sort(key=lambda x: abs(x[2][0]-x[0][0]))
                            longest_bbox = col[-1]
                            contain_cnt = 0
                            cross_cnt = 0
                            for bbox in col_bbox_list:
                                if longest_bbox[0][0] <= bbox[0][0] <= bbox[2][0] <= longest_bbox[2][0]:
                                    contain_cnt += 1
                                if bbox[0][0] < longest_bbox[0][0] < bbox[2][0] or bbox[0][0] < longest_bbox[2][0] < bbox[2][0]:
                                    cross_cnt += 1
                            print('cross_cnt', cross_cnt)
                            if contain_cnt >= 2 or cross_cnt >= 2:
                                print('包含多个横向排列bbox', i)
                                split_row_index_list.append(i)
                # Check whether the longest bbox of this row's column leaves its cell.
                col = row[only_one_index]
                col.sort(key=lambda x: abs(x[2][0]-x[0][0]))
                longest_bbox = col[-1]
                cell_row = table_cell_list[i]
                cell_col = cell_row[only_one_index]
                threshold = 15
                if cell_col[0][0]-threshold <= longest_bbox[0][0] <= longest_bbox[2][0] <= cell_col[1][0]+threshold:
                    pass
                else:
                    print('最长bbox跨单元格', i)
                    split_row_index_list.append(i)
        if split_row_index_list:
            # Split the table. -1 and len(table) act as virtual split rows at both
            # ends; set() removes rows that were recorded more than once.
            split_row_index_list.insert(0, -1)
            split_row_index_list.insert(len(split_row_index_list), len(table))
            split_row_index_list = list(set(split_row_index_list))
            split_row_index_list.sort(key=lambda x: x)
            print('split_row_index_list', split_row_index_list, len(table))
            for l in range(1, len(split_row_index_list)):
                index = split_row_index_list[l]
                last_index = split_row_index_list[l-1]
                # Fewer than two rows between the split rows: nothing to keep.
                if index - last_index <= 2:
                    continue
                start_row_index = last_index+1
                end_row_index = index-1
                start_row = table[last_index+1]
                end_row = table[index-1]
                # Flatten rows -> bboxes and drop empty entries.
                start_row = [x for y in start_row for x in y]
                end_row = [x for y in end_row for x in y]
                start_row = list(filter(lambda x: x != [], start_row))
                end_row = list(filter(lambda x: x != [], end_row))
                # An empty boundary row is skipped once, moving one row inwards.
                if not start_row:
                    start_row_index = last_index + 2
                    start_row = table[start_row_index]
                    start_row = [x for y in start_row for x in y]
                    start_row = list(filter(lambda x: x != [], start_row))
                if not end_row:
                    end_row_index = index - 2
                    end_row = table[end_row_index]
                    end_row = [x for y in end_row for x in y]
                    end_row = list(filter(lambda x: x != [], end_row))
                if not start_row or not end_row or end_row_index-start_row_index < 1:
                    continue
                # Vertical extent comes from the first/last row; the horizontal
                # extent is inherited from the original table location.
                start_row.sort(key=lambda x: x[0][1])
                min_h = int(start_row[0][0][1])
                min_w = location[0][0]
                end_row.sort(key=lambda x: x[2][1])
                max_h = int(end_row[-1][2][1])
                max_w = location[1][0]
                new_location = [[min_w, min_h], [max_w, max_h]]
                temp_location_list.append(new_location)
                # NOTE(review): this slice ignores the start/end adjustment made above,
                # unlike split_table_new2 which slices start_row_index:end_row_index+1
                # - confirm which one is intended.
                temp_area_row_list.append(row_list[last_index+1:index])
        else:
            temp_location_list.append(location)
            temp_area_row_list.append(row_list)
    table_location_list = temp_location_list
    area_row_list = temp_area_row_list
    return table_location_list, area_row_list
def split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict):
    """Split each table at rows whose longest bbox spans several bboxes of a previous row.

    For every row the longest bbox (by width) is compared with each of the up
    to 2 preceding rows that belong to the current piece. When that bbox fully
    contains, or has an edge well inside, at least two horizontally separate
    bboxes of a preceding row, the row is recorded as a split row and the next
    piece starts right after it. The rows between consecutive split rows form
    the new tables; pieces that are too short are dropped. Tables without any
    split row are kept as-is.

    :param table_location_list: one ``[[x0, y0], [x1, y1]]`` location per table
    :param area_table_bbox_list: per table, rows -> columns -> bboxes
    :param area_table_cell_list: per table cell grid (read but not used here)
    :param area_row_list: per table, the row-wise bbox lists to slice
    :param bbox_text_dict: maps ``str(bbox)`` to its text (only referenced by
        commented-out debug prints)
    :return: ``(table_location_list, area_row_list)`` for the resulting tables
    """
    temp_location_list = []
    temp_area_row_list = []
    for k in range(len(table_location_list)):
        table = area_table_bbox_list[k]
        location = table_location_list[k]
        row_list = area_row_list[k]
        table_cell_list = area_table_cell_list[k]
        split_row_index_list = []
        # Walk over every row of the table.
        table_start_row_index = 0
        for i in range(len(table)):
            row = table[i]
            # A row is a table-splitting row when:
            # 1. only one column of the row has a value, and the longest bbox of that
            #    column contains 2 or more bboxes of the same column in other rows
            # 2. only one column of the row has a value, and the longest bbox of that
            #    column spans several columns
            # print(i, [bbox_text_dict.get(str(y)) for x in row for y in x])
            # table_start_row_index is moved forward each time a split row is found.
            if table_start_row_index >= len(table):
                break
            # Collect the previous n rows, never reaching back past the current piece.
            n = 2
            if i-n < table_start_row_index:
                last_n_rows = table[table_start_row_index:i]
            else:
                last_n_rows = table[i-n:i]
            # Find the longest bbox of the row (first one wins on ties).
            max_len_bbox = []
            for col in row:
                for b in col:
                    if not max_len_bbox:
                        max_len_bbox = b
                    else:
                        if b[2][0] - b[0][0] > max_len_bbox[2][0]-max_len_bbox[0][0]:
                            max_len_bbox = b
            # Compare against each of the previous n rows.
            for r in last_n_rows:
                b_list = [y for x in r for y in x]
                # Keep only bboxes of that row that do not overlap horizontally
                # with one already kept.
                temp_b_list = []
                for b in b_list:
                    if not temp_b_list:
                        temp_b_list.append(b)
                    else:
                        find_flag = 0
                        for tb in temp_b_list:
                            if line_overlap(tb[0][0], tb[2][0], b[0][0], b[2][0]) > 0:
                                find_flag = 1
                                break
                        if not find_flag:
                            temp_b_list.append(b)
                b_list = temp_b_list
                if len(b_list) > 1 and max_len_bbox:
                    # Does the longest bbox cover several bboxes of that row?
                    contain_cnt = 0
                    for b in b_list:
                        # A quarter of the bbox width is ignored at both ends when
                        # testing whether an edge of the longest bbox falls inside it.
                        threshold = (b[2][0]-b[0][0])/4
                        if max_len_bbox[0][0] <= b[0][0] <= b[2][0] <= max_len_bbox[2][0]:
                            contain_cnt += 1
                        if b[0][0]+threshold < max_len_bbox[0][0] < b[2][0]-threshold \
                                or b[0][0]+threshold < max_len_bbox[2][0] < b[2][0]-threshold:
                            contain_cnt += 1
                    # print('contain_cnt', contain_cnt)
                    if contain_cnt >= 2:
                        # print('包含多个横向排列bbox', i)
                        split_row_index_list.append(i)
                        table_start_row_index = i+1
        if split_row_index_list:
            # Split the table. -1 and len(table) act as virtual split rows at both
            # ends; set() removes rows that were recorded more than once.
            split_row_index_list.insert(0, -1)
            split_row_index_list.insert(len(split_row_index_list), len(table))
            split_row_index_list = list(set(split_row_index_list))
            split_row_index_list.sort(key=lambda x: x)
            print('split_row_index_list', split_row_index_list, len(table))
            for l in range(1, len(split_row_index_list)):
                index = split_row_index_list[l]
                last_index = split_row_index_list[l-1]
                # Fewer than two rows between the split rows: nothing to keep.
                if index - last_index <= 2:
                    continue
                start_row_index = last_index+1
                end_row_index = index-1
                start_row = table[last_index+1]
                end_row = table[index-1]
                # Flatten rows -> bboxes and drop empty entries.
                start_row = [x for y in start_row for x in y]
                end_row = [x for y in end_row for x in y]
                start_row = list(filter(lambda x: x != [], start_row))
                end_row = list(filter(lambda x: x != [], end_row))
                # An empty boundary row is skipped once, moving one row inwards.
                if not start_row:
                    start_row_index = last_index + 2
                    start_row = table[start_row_index]
                    start_row = [x for y in start_row for x in y]
                    start_row = list(filter(lambda x: x != [], start_row))
                if not end_row:
                    end_row_index = index - 2
                    end_row = table[end_row_index]
                    end_row = [x for y in end_row for x in y]
                    end_row = list(filter(lambda x: x != [], end_row))
                if not start_row or not end_row or end_row_index-start_row_index < 1:
                    continue
                # Vertical extent comes from the first/last row; the horizontal
                # extent is inherited from the original table location.
                start_row.sort(key=lambda x: x[0][1])
                min_h = int(start_row[0][0][1])
                min_w = location[0][0]
                end_row.sort(key=lambda x: x[2][1])
                # print('end_row', [bbox_text_dict.get(str(x)) for x in end_row])
                max_h = int(end_row[-1][2][1])
                max_w = location[1][0]
                new_location = [[min_w, min_h], [max_w, max_h]]
                temp_location_list.append(new_location)
                temp_area_row_list.append(row_list[start_row_index:end_row_index+1])
        else:
            temp_location_list.append(location)
            temp_area_row_list.append(row_list)
    table_location_list = temp_location_list
    area_row_list = temp_area_row_list
    return table_location_list, area_row_list
  1840. def delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict):
  1841. table_standard_list = []
  1842. for i in range(len(table_location_list)):
  1843. row_list = area_row_list[i]
  1844. col_list = area_col_list[i]
  1845. location = table_location_list[i]
  1846. table_standard = True
  1847. # 1. 只有单行或单列
  1848. if len(row_list) <= 1 or len(col_list) <= 1:
  1849. table_standard = False
  1850. table_standard_list.append(table_standard)
  1851. continue
  1852. # 1. 单个单元格过多bbox
  1853. for row in row_list:
  1854. for col in col_list:
  1855. inter = [j for j in row if j in col]
  1856. inter = [eval(x) for x in list(set([str(x) for x in inter]))]
  1857. if len(inter) >= 8:
  1858. table_standard = False
  1859. break
  1860. # 1. 判断表格中,不在bbox中的黑色像素
  1861. table_black_cnt = count_black(img[location[0][1]:location[1][1], location[0][0]:location[1][0], :])
  1862. bbox_black_cnt = 0
  1863. for bbox in bbox_list:
  1864. if location[0][1] <= bbox[0][1] <= location[1][1]:
  1865. sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
  1866. if sub_img.shape[0] >= 3 and sub_img.shape[1] >= 3:
  1867. bbox_black_cnt += count_black(sub_img)
  1868. print('table_black_cnt, bbox_black_cnt', table_black_cnt, bbox_black_cnt, bbox_black_cnt / table_black_cnt)
  1869. if bbox_black_cnt / table_black_cnt < 0.5:
  1870. table_standard = False
  1871. table_standard_list.append(table_standard)
  1872. print('table_standard_list', table_standard_list)
  1873. return table_standard_list
  1874. def bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict):
  1875. # 合并同一行中多个单字bbox
  1876. for row in row_list:
  1877. single_bbox_list = []
  1878. row.sort(key=lambda x: x[0][0])
  1879. i = 0
  1880. for bbox in row:
  1881. if len(bbox_text_dict.get(str(bbox))) == 1 and i != len(row) - 1:
  1882. single_bbox_list.append(bbox)
  1883. else:
  1884. if len(single_bbox_list) >= 3:
  1885. if len(bbox_text_dict.get(str(bbox))) == 1:
  1886. single_bbox_list.append(bbox)
  1887. new_bbox = single_bbox_list[0]
  1888. new_text = ""
  1889. single_bbox_list.sort(key=lambda x: x[0][0])
  1890. for b in single_bbox_list:
  1891. new_bbox = [[new_bbox[0][0], new_bbox[0][1]],
  1892. [b[2][0], new_bbox[0][1]],
  1893. [b[2][0], b[2][1]],
  1894. [new_bbox[0][0], b[2][1]],
  1895. ]
  1896. bbox_list.remove(b)
  1897. new_text += bbox_text_dict.get(str(b))
  1898. text_list.remove(bbox_text_dict.get(str(b)))
  1899. # print('new_bbox, new_text', new_bbox, new_text)
  1900. bbox_list.append(new_bbox)
  1901. text_list.append(new_text)
  1902. bbox_text_dict[str(new_bbox)] = new_text
  1903. single_bbox_list = []
  1904. i += 1
  1905. return bbox_list, text_list, bbox_text_dict
  1906. def merge_table(area_row_list, area_col_list, table_location_list, bbox_list):
  1907. table_location_list.sort(key=lambda x: x[0][1])
  1908. merge_index_list = []
  1909. temp_merge_list = []
  1910. for i in range(1, len(table_location_list)):
  1911. last_col_list = area_col_list[i-1]
  1912. col_list = area_col_list[i]
  1913. last_location = table_location_list[i-1]
  1914. location = table_location_list[i]
  1915. merge_flag = 0
  1916. # 获取每个列的宽度
  1917. col_width_list = []
  1918. for col in col_list:
  1919. col.sort(key=lambda x: x[0][0])
  1920. min_w = col[0][0][0]
  1921. col.sort(key=lambda x: x[2][0])
  1922. max_w = col[-1][2][0]
  1923. col_width_list.append([min_w, max_w])
  1924. # 获取两个表格之间的bbox,判断bbox是否跨越多列
  1925. threshold = 5
  1926. merge_flag2 = 1
  1927. for bbox in bbox_list:
  1928. if last_location[1][1]-threshold <= bbox[0][1] <= bbox[2][1] <= location[0][1]+threshold:
  1929. if bbox[0][0] < col_width_list[0][0] or bbox[2][0] > col_width_list[-1][1]:
  1930. merge_flag2 = 0
  1931. break
  1932. for w in col_width_list:
  1933. if w[0] <= bbox[0][0] <= w[1] and bbox[2][0] - bbox[0][0] > w[1] - w[0]:
  1934. merge_flag2 = 0
  1935. break
  1936. # if location[0][1] - last_location[1][1] <= 20:
  1937. if merge_flag2:
  1938. if len(last_col_list) == len(col_list):
  1939. temp_merge_list += [i-1, i]
  1940. merge_flag = 1
  1941. if not merge_flag:
  1942. if temp_merge_list:
  1943. merge_index_list.append(temp_merge_list)
  1944. else:
  1945. merge_index_list.append([i-1])
  1946. temp_merge_list = []
  1947. if temp_merge_list:
  1948. merge_index_list.append(temp_merge_list)
  1949. else:
  1950. merge_index_list.append([len(table_location_list)-1])
  1951. # print('merge_index_list', merge_index_list)
  1952. if not merge_index_list:
  1953. return area_row_list, area_col_list, table_location_list
  1954. new_table_location_list = []
  1955. new_area_row_list = []
  1956. new_area_col_list = []
  1957. for index_list in merge_index_list:
  1958. if not table_location_list:
  1959. break
  1960. index_list = list(set(index_list))
  1961. temp_table = table_location_list[index_list[0]]
  1962. new_area_row_list.append(area_row_list[index_list[0]])
  1963. new_area_col_list.append(area_col_list[index_list[0]])
  1964. for index in index_list[1:]:
  1965. temp_table = [[min(temp_table[0][0], table_location_list[index][0][0]),
  1966. min(temp_table[0][1], table_location_list[index][0][1])],
  1967. [max(temp_table[1][0], table_location_list[index][1][0]),
  1968. max(temp_table[1][1], table_location_list[index][1][1])]
  1969. ]
  1970. new_area_row_list[-1] += area_row_list[index]
  1971. new_area_col_list[-1] += area_col_list[index]
  1972. new_table_location_list.append(temp_table)
  1973. return new_area_row_list, new_area_col_list, new_table_location_list
  1974. def add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict):
  1975. """
  1976. 对单个单元格内多列的,增加列线
  1977. :return:
  1978. """
  1979. add_area_col_lines = []
  1980. for i in range(len(table_location_list)):
  1981. row_list = area_row_list[i]
  1982. col_list = area_col_list[i]
  1983. location = table_location_list[i]
  1984. add_col_lines = []
  1985. new_col_list = []
  1986. for col in col_list:
  1987. row_cnt = 0
  1988. new_row_list = []
  1989. cell_col_lines = []
  1990. col.sort(key=lambda x: (x[0][1], x[0][0]))
  1991. # print('col')
  1992. for row in row_list:
  1993. row.sort(key=lambda x: (x[0][0], x[0][1]))
  1994. inter = [j for j in row if j in col]
  1995. inter = [eval(x) for x in list(set([str(x) for x in inter]))]
  1996. inter.sort(key=lambda x: (x[0][1], x[0][0]))
  1997. new_row = []
  1998. print('inter', [bbox_text_dict.get(str(x)) for x in inter])
  1999. # if inter:
  2000. # # 先将同个单元格内上下重叠的bbox合并
  2001. # temp_inter = []
  2002. #
  2003. # used_bbox_list = []
  2004. # for bbox1 in inter:
  2005. # if bbox1 in used_bbox_list:
  2006. # continue
  2007. # temp_merge_bbox = [bbox1]
  2008. # for bbox2 in inter:
  2009. # if bbox2 in used_bbox_list:
  2010. # continue
  2011. # if line_overlap(bbox1[0][0], bbox1[2][0], bbox2[0][0], bbox2[2][0]) >= 2/3 * min(bbox1[2][0]-bbox1[0][0], bbox2[2][0], bbox2[0][0]) \
  2012. # and line_overlap(bbox1[0][1], bbox1[2][1], bbox2[0][1], bbox2[2][1]) > 0:
  2013. # temp_merge_bbox += [bbox1, bbox2]
  2014. # used_bbox_list += [bbox1, bbox2]
  2015. # temp_merge_bbox = [eval(y) for y in list(set([str(x) for x in temp_merge_bbox]))]
  2016. # temp_inter.append(temp_merge_bbox)
  2017. #
  2018. # inter = []
  2019. # for m_bbox in temp_inter:
  2020. # min_w, min_h, max_w, max_h = 10000, 10000, 0, 0
  2021. # temp_text = ""
  2022. # for bbox in m_bbox:
  2023. # if bbox[0][0] < min_w:
  2024. # min_w = bbox[0][0]
  2025. # if bbox[0][1] < min_h:
  2026. # min_h = bbox[0][1]
  2027. # if bbox[2][0] > max_w:
  2028. # max_w = bbox[2][0]
  2029. # if bbox[2][1] > max_h:
  2030. # max_h = bbox[2][1]
  2031. # temp_text += bbox_text_dict.get(str(bbox)) + ' '
  2032. # inter.append([[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]])
  2033. # bbox_text_dict[str(inter[-1])] = temp_text
  2034. # print('merge inter', [bbox_text_dict.get(str(x)) for x in inter])
  2035. # 一个单元格内多个bbox
  2036. if len(inter) > 1:
  2037. # 单元格内分行
  2038. cell_row = []
  2039. temp_row = [inter[0]]
  2040. row_len = [inter[0][0][1], inter[0][2][1]]
  2041. for bbox in inter[1:]:
  2042. temp_bbox = temp_row[0]
  2043. bbox_h_len = bbox[2][1] - bbox[0][1]
  2044. temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1]
  2045. # if temp_bbox[0][1]-5 <= bbox[0][1] <= bbox[2][1] <= temp_bbox[2][1]+5 \
  2046. # or bbox[0][1]-5 <= temp_bbox[0][1] <= temp_bbox[2][1] <= bbox[2][1]+5 \
  2047. if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len):
  2048. temp_row.append(bbox)
  2049. row_len[0] = min(row_len[0], bbox[0][1])
  2050. row_len[1] = max(row_len[1], bbox[2][1])
  2051. # print('in row', bbox_text_dict.get(str(bbox)), bbox_text_dict.get(str(temp_bbox)),
  2052. # row_len[0], row_len[1], bbox[0][1], bbox[2][1],
  2053. # line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]),
  2054. # 1/3 * min(bbox_h_len, temp_bbox_h_len))
  2055. else:
  2056. # print('not in row', bbox_text_dict.get(str(bbox)), bbox_text_dict.get(str(temp_bbox)),
  2057. # line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]), 1/3 * min(bbox_h_len, temp_bbox_h_len))
  2058. # print(bbox_text_dict.get(str(bbox)), temp_bbox[2][1] - bbox[0][1], 1/2 * min(bbox_h_len, temp_bbox_h_len),
  2059. # bbox[2][1] - temp_bbox[0][1], 1/2 * min(bbox_h_len, temp_bbox_h_len),
  2060. # line_overlap(temp_bbox[0][1], temp_bbox[2][1], bbox[0][1], bbox[2][1]), 1/3 * min(bbox_h_len, temp_bbox_h_len),
  2061. # temp_bbox[0][1], temp_bbox[2][1], bbox[0][1], bbox[2][1], bbox_text_dict.get(str(temp_bbox)))
  2062. cell_row.append(temp_row)
  2063. temp_row = [bbox]
  2064. row_len = [bbox[0][1], bbox[2][1]]
  2065. if temp_row:
  2066. cell_row.append(temp_row)
  2067. print('row_cnt', row_cnt)
  2068. for c in cell_row:
  2069. c.sort(key=lambda x: x[0][0])
  2070. print('cell_row', [bbox_text_dict.get(str(x)) for x in c])
  2071. if row_cnt == 0:
  2072. # 获取最大列数的列
  2073. temp_cell_row = copy.deepcopy(cell_row)
  2074. temp_cell_row.sort(key=lambda x: len(x))
  2075. max_cell_row = temp_cell_row[-1]
  2076. # 对行内上下堆叠的进行合并
  2077. max_cell_row.sort(key=lambda x: (x[0][0], x[0][1]))
  2078. used_bbox_list = []
  2079. merge_bbox_list = []
  2080. for bbox1 in max_cell_row:
  2081. temp_merge_bbox = [bbox1]
  2082. if bbox1 in used_bbox_list:
  2083. continue
  2084. for bbox2 in max_cell_row:
  2085. if bbox2 in used_bbox_list:
  2086. continue
  2087. if line_overlap(bbox1[0][0], bbox1[2][0], bbox2[0][0], bbox2[2][0]) >= 2/3 * min(bbox1[2][0]-bbox1[0][0], bbox2[2][0], bbox2[0][0]):
  2088. temp_merge_bbox.append(bbox2)
  2089. used_bbox_list += [bbox1, bbox2]
  2090. # 选范围最大的bbox
  2091. temp_merge_bbox.sort(key=lambda x: (x[2][0], -x[0][0]))
  2092. merge_bbox_list.append(temp_merge_bbox[-1])
  2093. temp_cell_row[-1] = merge_bbox_list
  2094. print('temp_cell_row', [bbox_text_dict.get(str(x)) for x in temp_cell_row[-1]])
  2095. # print('temp_cell_row', temp_cell_row[-1])
  2096. for c in temp_cell_row[-1]:
  2097. cell_col_lines.append([c[0][0], c[2][0]])
  2098. cell_col_lines.sort(key=lambda x: x[0])
  2099. for c in cell_col_lines:
  2100. add_col_lines.append([[int(c[1]), location[0][1]], [int(c[1]), location[1][1]]])
  2101. add_area_col_lines.append(add_col_lines)
  2102. # # 循环所有行,若跨行
  2103. # cell_col_lines.sort(key=lambda x: x[0])
  2104. # cell_row.sort(key=lambda x: (x[0][0], x[0][1]))
  2105. # print('sorted cell_col_lines', cell_col_lines)
  2106. # for r in cell_row:
  2107. # right_bbox = []
  2108. # for bbox in r:
  2109. # for k in range(len(cell_col_lines)):
  2110. # if k == 0:
  2111. # min_w = -10000
  2112. # if len(cell_col_lines) <= 1:
  2113. # max_w = cell_col_lines[k][1]
  2114. # else:
  2115. # max_w = cell_col_lines[k+1][0]
  2116. # # + 1/4*(cell_col_lines[k+1][1]-cell_col_lines[k+1][0])
  2117. # elif k == len(cell_col_lines) - 1:
  2118. # max_w = 10000
  2119. # if len(cell_col_lines) <= 1:
  2120. # min_w = cell_col_lines[k-1][1]
  2121. # else:
  2122. # min_w = cell_col_lines[k][0]
  2123. # else:
  2124. # if len(cell_col_lines) <= 1:
  2125. # min_w = -10000
  2126. # max_w = 10000
  2127. # else:
  2128. # min_w = cell_col_lines[k-1][1]
  2129. # max_w = cell_col_lines[k+1][0]
  2130. # # + 1/4*(cell_col_lines[k+1][1]-cell_col_lines[k+1][0])
  2131. #
  2132. # # 判断跨行
  2133. # if min_w <= bbox[0][0] <= bbox[2][0] <= max_w:
  2134. # new_row.append(bbox)
  2135. # right_bbox.append(bbox)
  2136. # # else:
  2137. # # print(min_w, bbox[0][0], bbox[2][0], max_w,
  2138. # # bbox_text_dict.get(str(bbox)))
  2139. #
  2140. # # 有跨行,该行舍弃
  2141. # if len(right_bbox) != len(r):
  2142. # for r1 in r:
  2143. # if r1 in new_row:
  2144. # new_row.remove(r1)
  2145. #
  2146. # # 单元格只有一个bbox
  2147. # else:
  2148. # new_row = inter
  2149. # print('new_row', [bbox_text_dict.get(str(x)) for x in new_row])
  2150. # new_row.sort(key=lambda x: x[0][0])
  2151. # new_row_list.append(new_row)
  2152. # row_cnt += 1
  2153. # new_col_list.append(new_row_list)
  2154. #
  2155. # new_row_list = [x for x in new_col_list[0]]
  2156. # for col in new_col_list[1:]:
  2157. # for j in range(len(col)):
  2158. # new_row_list[j] += col[j]
  2159. #
  2160. # temp_new_row_list = []
  2161. # for r in new_row_list:
  2162. # if r:
  2163. # temp_new_row_list.append(r)
  2164. # print('new_row_list', [bbox_text_dict.get(str(x)) for x in r])
  2165. # new_row_list = temp_new_row_list
  2166. # area_row_list[i] = new_row_list
  2167. return add_area_col_lines
  2168. def judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict):
  2169. new_area_col_lines = []
  2170. for i in range(len(table_location_list)):
  2171. location = table_location_list[i]
  2172. col_lines = area_col_lines[i]
  2173. col_lines.sort(key=lambda x: x[0][0])
  2174. sub_bbox_list = []
  2175. for bbox in bbox_list:
  2176. if location[0][1] <= bbox[0][1] <= location[1][1]:
  2177. sub_bbox_list.append(bbox)
  2178. # 判断线穿过bbox,那一行的线去掉
  2179. # temp_col_lines = []
  2180. # for c in col_lines:
  2181. # for bbox in sub_bbox_list:
  2182. # 判断新增线有没有压在黑色像素上或有没有在bbox之间
  2183. # temp_col_lines = []
  2184. # for c in col_lines:
  2185. # if c[1][1] >= img.shape[1] or c[0][1] <= 0:
  2186. # continue
  2187. #
  2188. # black_cnt = count_black(img[c[0][1]:c[1][1], c[0][0]:c[1][0]+1, :])
  2189. # if black_cnt > 10:
  2190. # continue
  2191. # temp_col_lines.append(c)
  2192. # col_lines = temp_col_lines
  2193. # 判断两线之间有没有完整bbox
  2194. col_lines = [eval(y) for y in list(set([str(x) for x in col_lines]))]
  2195. col_lines.sort(key=lambda x: x[0][0])
  2196. threshold = 5
  2197. if not col_lines:
  2198. new_area_col_lines.append([])
  2199. continue
  2200. temp_col_lines = [col_lines[0]]
  2201. for j in range(1, len(col_lines)):
  2202. last_col_w = temp_col_lines[-1][0][0]
  2203. col_w = col_lines[j][0][0]
  2204. for bbox in sub_bbox_list:
  2205. if last_col_w-threshold <= bbox[0][0] <= bbox[2][0] <= col_w+threshold:
  2206. temp_col_lines.append(col_lines[j])
  2207. break
  2208. temp_col_lines.append(col_lines[-1])
  2209. col_lines = temp_col_lines
  2210. # 判断线穿过bbox,向右移动到空位
  2211. for col in col_lines:
  2212. cross_bbox_list = []
  2213. for bbox in sub_bbox_list:
  2214. if bbox[0][0] < col[0][0] < bbox[2][0]:
  2215. cross_bbox_list.append(bbox)
  2216. if cross_bbox_list:
  2217. # cross_bbox_list.sort(key=lambda x: x[2][0], reverse=True)
  2218. # for bbox in cross_bbox_list:
  2219. # line_now_w = col[0][0]
  2220. # line_move_w = bbox[2][0]
  2221. # find_flag = 0
  2222. # for bbox1 in sub_bbox_list:
  2223. # if bbox1 in cross_bbox_list:
  2224. # continue
  2225. # if line_now_w <= bbox1[0][0] <= line_move_w:
  2226. # find_flag = 1
  2227. # break
  2228. #
  2229. # if not find_flag:
  2230. # col[0][0] = int(line_move_w)
  2231. # col[1][0] = int(line_move_w)
  2232. # break
  2233. cross_bbox_list.sort(key=lambda x: x[2][0], reverse=True)
  2234. line_move_w = cross_bbox_list[0][2][0]
  2235. line_now_w = col[0][0]
  2236. for bbox1 in sub_bbox_list:
  2237. if bbox1 in cross_bbox_list:
  2238. continue
  2239. if line_now_w <= bbox1[0][0] <= line_move_w:
  2240. line_now_w = line_move_w
  2241. line_move_w = bbox1[2][0]
  2242. col[0][0] = int(line_move_w)
  2243. col[1][0] = int(line_move_w)
  2244. # 将边框线加上
  2245. left_col = [[location[0][0], location[0][1]], [location[0][0], location[1][1]]]
  2246. right_col = [[location[1][0], location[0][1]], [location[1][0], location[1][1]]]
  2247. if left_col not in col_lines:
  2248. col_lines.append(left_col)
  2249. if right_col not in col_lines:
  2250. col_lines.append(right_col)
  2251. new_area_col_lines.append(col_lines)
  2252. return new_area_col_lines
def add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines):
    """Propose extra row lines from the text lines found inside each cell.

    For every row/column intersection the shared bboxes are grouped into text
    lines by vertical overlap. A cell with a single text line contributes one
    line at the bottom of that text line's leftmost bbox. A cell with several
    text lines contributes, for each pair of consecutive text lines that has
    no existing row line between them, one line below the upper text line,
    shifted down by the gap found above that upper text line.

    Rows, columns and the in-cell groups are sorted in place.

    :param area_row_list: per table, list of rows (lists of bboxes)
    :param area_col_list: per table, list of columns (lists of bboxes)
    :param table_location_list: per table, ``[[x0, y0], [x1, y1]]``
    :param bbox_text_dict: maps ``str(bbox)`` to its text (debug printing only)
    :param area_row_lines: per table, the existing row lines ``[[x0, y], [x1, y]]``
    :return: per table, the list of proposed row lines ``[[x0, h], [x1, h]]``
    """
    add_area_row_lines = []
    for i in range(len(table_location_list)):
        row_list = area_row_list[i]
        col_list = area_col_list[i]
        location = table_location_list[i]
        row_lines = area_row_lines[i]
        add_row_lines = []
        for row in row_list:
            col_cnt = 0
            row.sort(key=lambda x: (x[0][0], x[0][1]))
            # # Use only the first column as the reference
            # first_col = col_list[0]
            # first_col.sort(key=lambda x: (x[0][1], x[0][0]))
            # inter = [j for j in row if j in first_col]
            # inter = [eval(x) for x in list(set([str(x) for x in inter]))]
            # inter.sort(key=lambda x: (x[0][1], x[0][0]))
            # Every column takes part.
            for col in col_list:
                col.sort(key=lambda x: (x[0][1], x[0][0]))
                # Bboxes shared by this row and this column, i.e. one cell.
                inter = [j for j in row if j in col]
                print('col', col_cnt, [bbox_text_dict.get(str(x)) for x in col], [bbox_text_dict.get(str(x)) for x in row])
                # De-duplicate through the str() form; eval() rebuilds the lists,
                # so the entries are copies of the original bboxes.
                inter = [eval(x) for x in list(set([str(x) for x in inter]))]
                inter.sort(key=lambda x: (x[0][1], x[0][0]))
                print('add_row_lines inter', [bbox_text_dict.get(str(x)) for x in inter])
                if len(inter) > 0:
                    # Split the cell into text lines: a bbox joins the current line
                    # when it overlaps the line's vertical span by at least a third
                    # of the smaller height.
                    cell_row = []
                    temp_row = [inter[0]]
                    row_len = [inter[0][0][1], inter[0][2][1]]
                    for bbox in inter[1:]:
                        temp_bbox = temp_row[0]
                        bbox_h_len = bbox[2][1] - bbox[0][1]
                        temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1]
                        if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len):
                            temp_row.append(bbox)
                            row_len[0] = min(row_len[0], bbox[0][1])
                            row_len[1] = max(row_len[1], bbox[2][1])
                        else:
                            cell_row.append(temp_row)
                            temp_row = [bbox]
                            row_len = [bbox[0][1], bbox[2][1]]
                    if temp_row:
                        cell_row.append(temp_row)
                    print('col_cnt', col_cnt)
                    for c in cell_row:
                        c.sort(key=lambda x: x[0][0])
                        print('cell_row', [bbox_text_dict.get(str(x)) for x in c])
                    # Handle the text lines of the cell.
                    if len(cell_row) > 0:
                        if len(cell_row) == 1:
                            # Single text line: line at the bottom of its leftmost bbox.
                            h = int(cell_row[0][0][2][1])
                            add_row_lines.append([[location[0][0], h], [location[1][0], h]])
                        for j in range(1, len(cell_row)):
                            last_row = cell_row[j-1]
                            row1 = cell_row[j]
                            last_row.sort(key=lambda x: x[2][1])
                            row1.sort(key=lambda x: x[0][1])
                            # Is there already a row line between the two text lines?
                            find_flag = 0
                            for l in row_lines:
                                if last_row[-1][2][1] <= l[0][1] <= row1[0][0][1]:
                                    find_flag = 1
                                    break
                            if not find_flag:
                                # Start at the bottom of the upper text line and add the
                                # gap above it: to the table top for the first pair,
                                # otherwise to the bottom of the text line two above.
                                h = int(last_row[-1][2][1])
                                if j == 1:
                                    last_row.sort(key=lambda x: x[0][1])
                                    h += int(last_row[0][0][1] - location[0][1])
                                else:
                                    last_two_row = cell_row[j-2]
                                    last_two_row.sort(key=lambda x: x[2][1])
                                    last_row.sort(key=lambda x: x[0][1])
                                    h += int(last_row[0][0][1] - last_two_row[-1][2][1])
                                add_row_lines.append([[location[0][0], h], [location[1][0], h]])
                col_cnt += 1
        add_area_row_lines.append(add_row_lines)
    return add_area_row_lines
  2330. def judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict):
  2331. new_area_row_lines = []
  2332. for i in range(len(table_location_list)):
  2333. location = table_location_list[i]
  2334. row_lines = area_row_lines[i]
  2335. sub_bbox_list = []
  2336. for bbox in bbox_list:
  2337. if location[0][1] <= bbox[0][1] <= location[1][1]:
  2338. sub_bbox_list.append(bbox)
  2339. # 判断新增线有没有压在黑色像素上或有没有在bbox之间
  2340. # temp_row_lines = []
  2341. # for c in row_lines:
  2342. # if c[1][1] >= img.shape[1] or c[0][1] <= 0:
  2343. # continue
  2344. # if c[1][1] > location[1][1] or c[1][1] < location[0][1]:
  2345. # continue
  2346. # black_cnt = count_black(img[c[0][1]:c[1][1]+1, c[0][0]:c[1][0], :])
  2347. # if black_cnt > 10:
  2348. # continue
  2349. # temp_row_lines.append(c)
  2350. # row_lines = temp_row_lines
  2351. # if not row_lines:
  2352. # new_area_row_lines.append([])
  2353. # continue
  2354. # 判断线穿过bbox,向下移动到空位
  2355. row_lines.sort(key=lambda x: x[0][1])
  2356. for row in row_lines:
  2357. # 循环找出空位
  2358. while True:
  2359. cross_bbox_list = []
  2360. for bbox in sub_bbox_list:
  2361. # if (bbox[0][1]+bbox[2][1])/2 < row[0][1] < bbox[2][1]:
  2362. if bbox[0][1] < row[0][1] < bbox[2][1]:
  2363. cross_bbox_list.append(bbox)
  2364. if cross_bbox_list:
  2365. # print('row cross_bbox_list', [bbox_text_dict.get(str(x)) for x in cross_bbox_list])
  2366. cross_bbox_list.sort(key=lambda x: x[2][1], reverse=True)
  2367. line_move_h = cross_bbox_list[0][2][1]
  2368. move_bbox = cross_bbox_list[0]
  2369. row[0][1] = int(line_move_h)
  2370. row[1][1] = int(line_move_h)
  2371. # print('move_bbox', bbox_text_dict.get(str(move_bbox)))
  2372. else:
  2373. break
  2374. # 判断两线之间有没有完整bbox
  2375. row_lines.sort(key=lambda x: x[0][1])
  2376. threshold = 5
  2377. temp_row_lines = [row_lines[0]]
  2378. for j in range(1, len(row_lines)):
  2379. last_row_w = temp_row_lines[-1][0][1]
  2380. row_w = row_lines[j][0][1]
  2381. for bbox in sub_bbox_list:
  2382. if last_row_w-threshold <= bbox[0][1] <= bbox[2][1] <= row_w+threshold:
  2383. temp_row_lines.append(row_lines[j])
  2384. break
  2385. temp_row_lines.append(row_lines[-1])
  2386. row_lines = temp_row_lines
  2387. # 将边框线加上
  2388. up_row = [[location[0][0], location[0][1]], [location[1][0], location[0][1]]]
  2389. bottom_row = [[location[0][0], location[1][1]], [location[1][0], location[1][1]]]
  2390. if up_row not in row_lines:
  2391. row_lines.append(up_row)
  2392. if bottom_row not in row_lines:
  2393. row_lines.append(bottom_row)
  2394. new_area_row_lines.append(row_lines)
  2395. return new_area_row_lines
  2396. def merge_lines(lines, axis=0, threshold=5):
  2397. lines.sort(key=lambda x: x[0][1-axis])
  2398. used_lines = []
  2399. new_lines = copy.deepcopy(lines)
  2400. for line1 in lines:
  2401. if line1 in used_lines:
  2402. continue
  2403. current_line = line1
  2404. temp_merge_lines = [line1]
  2405. for line2 in lines:
  2406. if line2 in used_lines:
  2407. continue
  2408. if abs(current_line[0][1-axis] - line2[0][1-axis]) <= threshold:
  2409. temp_merge_lines.append(line2)
  2410. used_lines.append(line2)
  2411. current_line = line2
  2412. # 取最右或最下线
  2413. temp_merge_lines.sort(key=lambda x: x[0][1-axis])
  2414. for l in temp_merge_lines:
  2415. if l in new_lines:
  2416. new_lines.remove(l)
  2417. new_lines.append(temp_merge_lines[-1])
  2418. new_lines.sort(key=lambda x: x[0][1-axis])
  2419. return new_lines
  2420. def merge_row_bbox_list(area_row_list):
  2421. new_area_row_list = []
  2422. for row_list in area_row_list:
  2423. new_row_list = copy.deepcopy(row_list)
  2424. # 针对表头有多行被分在不同行
  2425. for i in range(1, len(row_list)-1):
  2426. last_row = row_list[i-1]
  2427. row = row_list[i]
  2428. next_row = row_list[i+1]
  2429. merge_cnt = 0
  2430. for last_b in last_row:
  2431. find_flag1 = 0
  2432. for next_b in next_row:
  2433. if line_overlap(last_b[0][0], last_b[2][0], next_b[0][0], next_b[2][0]) \
  2434. >= 0.8*min(last_b[2][0] - last_b[0][0], next_b[2][0] - next_b[0][0]):
  2435. find_flag1 = 1
  2436. break
  2437. find_flag2 = 0
  2438. if find_flag1:
  2439. for b in row:
  2440. if line_overlap(last_b[0][0], last_b[2][0], b[0][0], b[2][0]) \
  2441. >= 0.8*min(last_b[2][0] - last_b[0][0], b[2][0] - b[0][0]):
  2442. find_flag2 = 1
  2443. break
  2444. if find_flag1 and not find_flag2:
  2445. merge_cnt += 1
  2446. if merge_cnt == len(last_row) and merge_cnt > 1:
  2447. new_row_list = new_row_list[:i-1] + [last_row+row+next_row] + new_row_list[i+2:]
  2448. new_area_row_list.append(new_row_list)
  2449. return new_area_row_list
  2450. def count_black(image_np, threshold=150):
  2451. lower = np.array([0, 0, 0])
  2452. upper = np.array([threshold, threshold, threshold])
  2453. mask = cv2.inRange(image_np, lower, upper)
  2454. cnt = np.sum(mask != 0)
  2455. # print("count color ", cnt)
  2456. return cnt
  2457. def get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0):
  2458. area_row_list = []
  2459. for i in range(len(table_location_list)):
  2460. row_lines = area_row_lines[i]
  2461. col_lines = area_col_lines[i]
  2462. # 求线交点
  2463. cross_points = get_points_by_line(img, row_lines, col_lines)
  2464. # 交点分行
  2465. cross_points.sort(key=lambda x: (x[1-axis], x[axis]))
  2466. row_point_list = []
  2467. current_row = [cross_points[0]]
  2468. for p in cross_points[1:]:
  2469. if current_row[0][1-axis] == p[1-axis]:
  2470. current_row.append(p)
  2471. else:
  2472. row_point_list.append(current_row)
  2473. current_row = [p]
  2474. if current_row:
  2475. row_point_list.append(current_row)
  2476. used_bbox_list = []
  2477. row_list = []
  2478. for j in range(1, len(row_point_list)):
  2479. last_row = row_point_list[j-1]
  2480. row = row_point_list[j]
  2481. sub_row_list = []
  2482. for k in range(1, len(row)):
  2483. last_p = last_row[k-1]
  2484. p = row[k]
  2485. for bbox in bbox_list:
  2486. if bbox in used_bbox_list:
  2487. continue
  2488. bbox_h_center = (bbox[0][1-axis]+bbox[2][1-axis]) / 2
  2489. bbox_w_center = (bbox[0][axis]+bbox[2][axis]) / 2
  2490. if last_p[1-axis] <= bbox_h_center <= p[1-axis] and last_p[axis] <= bbox_w_center <= p[axis]:
  2491. sub_row_list.append(bbox)
  2492. used_bbox_list.append(bbox)
  2493. row_list.append(sub_row_list)
  2494. area_row_list.append(row_list)
  2495. return area_row_list
  2496. def get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list):
  2497. area_table_bbox_list = []
  2498. area_table_cell_list = []
  2499. for i in range(len(table_location_list)):
  2500. row_lines = area_row_lines[i]
  2501. col_lines = area_col_lines[i]
  2502. # 求线交点
  2503. cross_points = get_points_by_line(img, row_lines, col_lines)
  2504. # 交点分行
  2505. cross_points.sort(key=lambda x: (x[1], x[0]))
  2506. row_point_list = []
  2507. if not cross_points:
  2508. area_table_bbox_list.append([])
  2509. area_table_cell_list.append([])
  2510. continue
  2511. current_row = [cross_points[0]]
  2512. for p in cross_points[1:]:
  2513. if current_row[0][1] == p[1]:
  2514. current_row.append(p)
  2515. else:
  2516. row_point_list.append(current_row)
  2517. current_row = [p]
  2518. if current_row:
  2519. row_point_list.append(current_row)
  2520. # bbox以表格格式排列
  2521. used_bbox_list = []
  2522. row_list = []
  2523. row_cell_list = []
  2524. for j in range(1, len(row_point_list)):
  2525. last_row = row_point_list[j-1]
  2526. row = row_point_list[j]
  2527. col_list = []
  2528. col_cell_list = []
  2529. for k in range(1, len(row)):
  2530. last_p = last_row[k-1]
  2531. p = row[k]
  2532. cell = []
  2533. for bbox in bbox_list:
  2534. if bbox in used_bbox_list:
  2535. continue
  2536. bbox_h_center = (bbox[0][1]+bbox[2][1]) / 2
  2537. bbox_w_center = (bbox[0][0]+bbox[2][0]) / 2
  2538. if last_p[1] <= bbox_h_center <= p[1] and last_p[0] <= bbox_w_center <= p[0]:
  2539. cell.append(bbox)
  2540. used_bbox_list.append(bbox)
  2541. col_list.append(cell)
  2542. col_cell_list.append([last_p, p])
  2543. row_list.append(col_list)
  2544. row_cell_list.append(col_cell_list)
  2545. area_table_bbox_list.append(row_list)
  2546. area_table_cell_list.append(row_cell_list)
  2547. return area_table_bbox_list, area_table_cell_list
  2548. def get_lines_from_img(img):
  2549. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  2550. # 开操作提取水平线
  2551. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 1))
  2552. binary1 = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)
  2553. # cv2.imshow('7,1', binary1)
  2554. # 开操作提取垂直线
  2555. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 7))
  2556. binary2 = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)
  2557. # cv2.imshow('1,7', binary2)
  2558. #
  2559. # cv2.imshow('table', binary1+binary2)
  2560. # cv2.waitKey(0)
  2561. return binary1, binary2
  2562. def get_bbox_by_img(row_img, col_img):
  2563. # cv2.imshow('table', row_img+col_img)
  2564. # 求线交点
  2565. point_img = np.bitwise_and(row_img, col_img)
  2566. # cv2.imshow('point_img', point_img)
  2567. # cv2.waitKey(0)
  2568. # 识别黑白图中的白色交叉点,将横纵坐标取出
  2569. ys, xs = np.where(point_img > 200)
  2570. cross_points = []
  2571. for i in range(len(xs)):
  2572. cross_points.append((xs[i], ys[i]))
  2573. cross_points.sort(key=lambda x: (x[0], x[1]))
  2574. temp_cross_points = []
  2575. for p1 in cross_points:
  2576. find_flag = 0
  2577. for p2 in temp_cross_points:
  2578. if abs(p1[1] - p2[1]) <= 5 and abs(p1[0] - p2[0]) <= 5:
  2579. find_flag = 1
  2580. break
  2581. if not find_flag:
  2582. temp_cross_points.append(p1)
  2583. cross_points = temp_cross_points
  2584. if not cross_points:
  2585. return [], False
  2586. print('cross_points', len(cross_points))
  2587. axis = 0
  2588. # 交点分行
  2589. cross_points.sort(key=lambda x: (x[1-axis], x[axis]))
  2590. row_point_list = []
  2591. current_row = [cross_points[0]]
  2592. for p in cross_points[1:]:
  2593. if abs(current_row[0][1-axis] - p[1-axis]) <= 5:
  2594. current_row.append(p)
  2595. else:
  2596. current_row.sort(key=lambda x: x[0])
  2597. row_point_list.append(current_row)
  2598. current_row = [p]
  2599. if current_row:
  2600. current_row.sort(key=lambda x: x[0])
  2601. row_point_list.append(current_row)
  2602. row_len = len(row_point_list[0])
  2603. for row in row_point_list:
  2604. # print('row_point_list row', len(row), row)
  2605. if row_len != len(row):
  2606. return [], False
  2607. row_list = []
  2608. standard_flag = True
  2609. for j in range(1, len(row_point_list)):
  2610. last_row = row_point_list[j-1]
  2611. row = row_point_list[j]
  2612. sub_row = []
  2613. for k in range(1, len(row)):
  2614. if k-1 >= len(last_row):
  2615. # print(len(last_row), len(row))
  2616. standard_flag = False
  2617. break
  2618. last_p = last_row[k-1]
  2619. p = row[k]
  2620. bbox = [last_p, p]
  2621. sub_row.append(bbox)
  2622. row_list.append(sub_row)
  2623. if not row_list:
  2624. return [], False
  2625. row_len = len(row_list[0])
  2626. for row in row_list:
  2627. if len(row) != row_len:
  2628. standard_flag = False
  2629. break
  2630. print('standard_flag', standard_flag)
  2631. if standard_flag:
  2632. new_img = np.zeros((row_img.shape[0], row_img.shape[1], 3), dtype=np.uint8)
  2633. # for row in row_list:
  2634. # for b in row:
  2635. # print('b', b)
  2636. # cv2.rectangle(new_img, [int(b[0][0]), int(b[0][1])], [int(b[1][0]), int(b[1][1])],
  2637. # (0, 0, 255), 1)
  2638. # cv2.imshow('new_img', new_img)
  2639. # cv2.waitKey(0)
  2640. return row_list, standard_flag
  2641. def get_points_by_line(img, row_lines, col_lines):
  2642. row_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
  2643. col_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
  2644. for r in row_lines:
  2645. cv2.line(row_img, r[0], r[1], (255, 255, 255), 1)
  2646. for c in col_lines:
  2647. cv2.line(col_img, c[0], c[1], (255, 255, 255), 1)
  2648. point_img = np.bitwise_and(row_img, col_img)
  2649. # 识别黑白图中的白色交叉点,将横纵坐标取出
  2650. ys, xs = np.where(point_img > 0)
  2651. points = []
  2652. for i in range(len(xs)):
  2653. points.append((xs[i], ys[i]))
  2654. points.sort(key=lambda x: (x[0], x[1]))
  2655. return points
  2656. def merge_text_and_table(text_bbox_list, table_row_list):
  2657. used_bbox_list = []
  2658. row_list = []
  2659. for row in table_row_list:
  2660. sub_row_list = []
  2661. row.sort(key=lambda x: x[0][0])
  2662. for bbox1 in row:
  2663. sub_bbox_list = []
  2664. for bbox2 in text_bbox_list:
  2665. if bbox2 in used_bbox_list:
  2666. continue
  2667. bbox_h_center = (bbox2[0][1]+bbox2[2][1]) / 2
  2668. bbox_w_center = (bbox2[0][0]+bbox2[2][0]) / 2
  2669. if bbox1[0][1] <= bbox_h_center <= bbox1[1][1] and bbox1[0][0] <= bbox_w_center <= bbox1[1][0]:
  2670. sub_bbox_list.append(bbox2)
  2671. used_bbox_list.append(bbox2)
  2672. sub_row_list.append(sub_bbox_list)
  2673. row_list.append(sub_row_list)
  2674. return row_list
  2675. def shrink_bbox(img, bbox_list):
  2676. def return_first_black_index(image_np):
  2677. lower = np.array([0, 0, 0])
  2678. upper = np.array([150, 150, 150])
  2679. mask = cv2.inRange(image_np, lower, upper)
  2680. black_index_list = np.where(mask != 0)
  2681. return black_index_list
  2682. new_bbox_list = []
  2683. for bbox in bbox_list:
  2684. img_bbox = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
  2685. if 0 in img_bbox.shape:
  2686. new_bbox_list.append(bbox)
  2687. continue
  2688. # 左右上下开始扫描,碰到黑像素即停
  2689. index_list = return_first_black_index(img_bbox[:, :, :])
  2690. if index_list[0].size == 0 or index_list[1].size == 0:
  2691. new_bbox_list.append(bbox)
  2692. continue
  2693. min_h = index_list[0][0]
  2694. max_h = index_list[0][-1]
  2695. img_bbox1 = np.swapaxes(img_bbox, 0, 1)
  2696. index_list = return_first_black_index(img_bbox1[:, :, :])
  2697. if index_list[0].size == 0 or index_list[1].size == 0:
  2698. new_bbox_list.append(bbox)
  2699. continue
  2700. min_w = index_list[0][0]
  2701. max_w = index_list[0][-1]
  2702. real_min_w = bbox[0][0] + min_w
  2703. real_max_w = bbox[0][0] + max_w
  2704. real_min_h = bbox[0][1] + min_h
  2705. real_max_h = bbox[0][1] + max_h
  2706. new_bbox = [[real_min_w, real_min_h], [real_min_w, real_max_h], [real_max_w, real_max_h], [real_max_w, real_min_h]]
  2707. new_bbox_list.append(new_bbox)
  2708. # cv2.imshow('img', img_bbox)
  2709. # cv2.imshow('shrink', img[int(new_bbox[0][1]):int(new_bbox[2][1]), int(new_bbox[0][0]):int(new_bbox[2][0]), :])
  2710. # cv2.waitKey(0)
  2711. return new_bbox_list
  2712. def affinity_propagation(data_list):
  2713. """
  2714. 聚类:近邻传播
  2715. :return:
  2716. """
  2717. data_np = np.array(data_list)
  2718. random_state = 170
  2719. model = AffinityPropagation(damping=0.5, convergence_iter=15, random_state=random_state).fit(data_np)
  2720. # cluster_centers_indices = model.cluster_centers_indices_
  2721. y_pred = model.labels_
  2722. if y_pred[0] == -1:
  2723. print('ap dp0.5 ci50')
  2724. model = AffinityPropagation(convergence_iter=50, random_state=random_state).fit(data_np)
  2725. y_pred = model.labels_
  2726. if y_pred[0] == -1:
  2727. print('ap dp0.7 ci15')
  2728. model = AffinityPropagation(damping=0.7, convergence_iter=15, random_state=random_state).fit(data_np)
  2729. y_pred = model.labels_
  2730. if y_pred[0] == -1:
  2731. print('ap dp0.7 ci50')
  2732. model = AffinityPropagation(damping=0.7, convergence_iter=50, random_state=random_state).fit(data_np)
  2733. y_pred = model.labels_
  2734. if y_pred[0] == -1:
  2735. print('all -1')
  2736. y_pred = np.zeros(y_pred.shape[0])
  2737. y_pred = y_pred.tolist()
  2738. return y_pred
  2739. def dbscan(data_list):
  2740. """
  2741. 聚类:dbscan
  2742. :return:
  2743. """
  2744. data_np = np.array(data_list)
  2745. model = DBSCAN(eps=3, min_samples=2).fit(data_np)
  2746. y_pred = model.labels_
  2747. y_pred = y_pred.tolist()
  2748. return y_pred
  2749. def test_ocr_model(img_path):
  2750. with open(img_path, "rb") as f:
  2751. file_bytes = f.read()
  2752. file_base64 = base64.b64encode(file_bytes)
  2753. file_json = {"data": file_base64, "md5": 0}
  2754. _url = "http://192.168.2.103:17000/ocr"
  2755. # _url = "http://127.0.0.1:17000/ocr"
  2756. result = json.loads(request_post(_url, file_json))
  2757. return result
  2758. def test_cho_model(text):
  2759. # text = "巧克力"
  2760. text = [x for x in text]
  2761. data_json = {"data": json.dumps(text)}
  2762. _url = "http://192.168.2.103:17058/cho"
  2763. result = json.loads(request_post(_url, data_json))
  2764. if result.get("success"):
  2765. decode_list = result.get("data")
  2766. print("char_list", text)
  2767. print("decode_list", decode_list)
  2768. return decode_list
  2769. else:
  2770. print("failed!")
if __name__ == '__main__':
    # Script entry point: runs get_table_new when the file is executed directly.
    # NOTE(review): get_table_new is not visible in this part of the file --
    # confirm it is defined above.
    get_table_new()
    # Leftover manual experiments for the clustering helpers and shrink_bbox,
    # kept commented out:
    # _l = [[18, 0], [0, 0], [14, 0], [0, 0], [12, 0], [0, 0], [14, 0], [2, 0], [15, 0], [0, 0]]
    # # _l = [[27, 0], [26, 0], [17, 0]]
    # print(affinity_propagation(_l))
    # print(dbscan(_l))
    # _img = cv2.imread(r'C:\Users\Administrator\Desktop\111.jpg')
    # shrink_bbox(_img, [[[0, 0], [0, 0], [_img.shape[1], _img.shape[0]], [_img.shape[1], _img.shape[0]]]])