get_table_by_rules.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. import copy
  2. import re
  3. import cv2
  4. import numpy as np
  5. from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
  6. from botr.utils import line_iou
  7. # from format_convert.utils import log
  8. def get_table_by_rule(img, text_list, bbox_list, table_location, show=0):
  9. if show:
  10. print('get_table_by_rule bbox_list', bbox_list)
  11. if not bbox_list:
  12. return [], [], [], {}
  13. if show:
  14. img_show = copy.deepcopy(img)
  15. img_result = copy.deepcopy(img)
  16. # 处理bbox,缩小框
  17. bbox_list = shrink_bbox(img, bbox_list)
  18. # 创建对应dict
  19. bbox_text_dict = {}
  20. temp_list = []
  21. for i in range(len(text_list)):
  22. # 排除text为空的
  23. if not text_list[i]:
  24. continue
  25. if re.sub(' ', '', text_list[i]) == '':
  26. continue
  27. # text中间为空格,其实是两列的
  28. match = re.search('[ ]{3,}', text_list[i])
  29. if match:
  30. # print(text_list[i][match.span()[1]:], re.match('[((]', text_list[i][match.span()[1]:]))
  31. text = text_list[i]
  32. bbox = bbox_list[i]
  33. blank_index = (match.span()[0] + match.span()[1]) / 2
  34. chinese_cnt = len(re.findall('[\u4e00-\u9fff()?。,!【】¥《》]', text)) * 1.5
  35. char_cnt = len(re.findall('[ .?!,+*&^%$#@~=:;/<>()a-zA-Z0-9{}]', text))
  36. # print(text, match.span()[0], match.span()[1], blank_index, chinese_cnt, char_cnt)
  37. char_cnt += chinese_cnt
  38. char_pixel = abs(bbox[0][0] - bbox[2][0]) / char_cnt
  39. index_pixel = char_pixel * blank_index
  40. # print(abs(bbox[0][0] - bbox[2][0]), char_cnt, char_pixel, index_pixel)
  41. bbox1 = [bbox[0], bbox[1], [bbox[0][0] + index_pixel, bbox[2][1]], bbox[3]]
  42. bbox1 = shrink_bbox(img, [bbox1])[0]
  43. text1 = text[:match.span()[0]]
  44. bbox2 = [[bbox[0][0]+index_pixel, bbox[0][1]], bbox[1], bbox[2], bbox[3]]
  45. bbox2 = shrink_bbox(img, [bbox2])[0]
  46. text2 = text[match.span()[1]:]
  47. if re.sub(' ', '', text1) != '':
  48. bbox_text_dict[str(bbox1)] = text1
  49. temp_list.append(bbox1)
  50. if re.sub(' ', '', text2) != '':
  51. bbox_text_dict[str(bbox2)] = text2
  52. temp_list.append(bbox2)
  53. # 正常的bbox
  54. else:
  55. bbox_text_dict[str(bbox_list[i])] = text_list[i]
  56. temp_list.append(bbox_list[i])
  57. bbox_list = temp_list
  58. if show:
  59. print('bbox_text_dict', bbox_text_dict)
  60. for bbox in bbox_list:
  61. cv2.rectangle(img_show, (int(bbox[0][0]), int(bbox[0][1])),
  62. (int(bbox[2][0]), int(bbox[2][1])), (255, 0, 0), 2)
  63. cv2.imshow('bbox_list', img_show)
  64. cv2.waitKey(0)
  65. # 根据bbox_list,计算与table_location左上角坐标距离,锁定第一个bbox
  66. table_left_up_point = [table_location[0], table_location[1]]
  67. min_distance = 100000000000
  68. first_bbox = bbox_list[0]
  69. for bbox in bbox_list:
  70. distance = abs(bbox[0][0] - table_left_up_point[0]) + abs(bbox[0][1] - table_left_up_point[1])
  71. if distance < min_distance:
  72. min_distance = distance
  73. first_bbox = bbox
  74. # print('first_bbox', first_bbox, bbox_text_dict.get(str(first_bbox)))
  75. # # 对first_bbox预处理
  76. # # 分割
  77. # new_bbox_list, bbox_text_dict = split_bbox(img, first_bbox, bbox_text_dict)
  78. # if new_bbox_list:
  79. # if first_bbox in bbox_list:
  80. # bbox_list.remove(first_bbox)
  81. # bbox_list += new_bbox_list
  82. # new_bbox_list.sort(key=lambda x: (x[0][0]))
  83. # first_bbox = new_bbox_list[0]
  84. # 根据第一个bbox,得到第一行
  85. first_row = []
  86. bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
  87. for bbox in bbox_list:
  88. # h有交集
  89. if first_bbox[0][1] <= bbox[0][1] <= first_bbox[2][1] \
  90. or first_bbox[0][1] <= bbox[2][1] <= first_bbox[2][1] \
  91. or bbox[0][1] <= first_bbox[0][1] <= bbox[2][1] \
  92. or bbox[0][1] <= first_bbox[2][1] <= bbox[2][1]:
  93. first_row.append(bbox)
  94. # h小于first_box
  95. elif bbox[2][1] <= first_bbox[0][1]:
  96. first_row.append(bbox)
  97. # 对第一行分列
  98. first_row.sort(key=lambda x: (x[0][0], x[0][1]))
  99. first_row_col = []
  100. used_bbox = []
  101. for bbox in first_row:
  102. if bbox in used_bbox:
  103. continue
  104. temp_col = []
  105. for bbox1 in first_row:
  106. if bbox1 in used_bbox:
  107. continue
  108. if bbox1[0][0] <= bbox[0][0] <= bbox1[2][0] \
  109. or bbox1[0][0] <= bbox[2][0] <= bbox1[2][0] \
  110. or bbox[0][0] <= bbox1[0][0] <= bbox[2][0] \
  111. or bbox[0][0] <= bbox1[2][0] <= bbox[2][0]:
  112. temp_col.append(bbox1)
  113. used_bbox.append(bbox1)
  114. first_row_col.append(temp_col)
  115. # 根据第一个bbox,得到第一列
  116. first_col = []
  117. bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
  118. for bbox in bbox_list:
  119. # w有交集
  120. if first_bbox[0][0] <= bbox[0][0] <= first_bbox[2][0] \
  121. or first_bbox[0][0] <= bbox[2][0] <= first_bbox[2][0] \
  122. or bbox[0][0] <= first_bbox[0][0] <= bbox[2][0] \
  123. or bbox[0][0] <= first_bbox[2][0] <= bbox[2][0]:
  124. first_col.append(bbox)
  125. # w小于first_box
  126. elif bbox[2][0] <= first_bbox[0][0]:
  127. first_col.append(bbox)
  128. # 对第一列分行
  129. first_col.sort(key=lambda x: (x[0][1], x[0][0]))
  130. first_col_row = []
  131. current_bbox = first_col[0]
  132. temp_row = []
  133. for bbox in first_col:
  134. if current_bbox[0][1] <= bbox[0][1] <= current_bbox[2][1] \
  135. or current_bbox[0][1] <= bbox[2][1] <= current_bbox[2][1] \
  136. or bbox[0][1] <= current_bbox[0][1] <= bbox[2][1] \
  137. or bbox[0][1] <= current_bbox[2][1] <= bbox[2][1]:
  138. temp_row.append(bbox)
  139. else:
  140. if temp_row:
  141. temp_row.sort(key=lambda x: x[0][1])
  142. first_col_row.append(temp_row)
  143. temp_row = [bbox]
  144. current_bbox = bbox
  145. if temp_row:
  146. temp_row.sort(key=lambda x: x[0][1])
  147. first_col_row.append(temp_row)
  148. if show:
  149. print('len(first_row)', len(first_row))
  150. print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
  151. print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
  152. print('len(first_col)', len(first_col))
  153. print('len(first_row_col)', len(first_row_col))
  154. print('len(first_col_row)', len(first_col_row))
  155. # 划线 列
  156. col_line_list = []
  157. for col in first_row_col:
  158. # 画2条线,根据左右bbox
  159. min_w, max_w = 1000000, 0
  160. # print('col', [bbox_text_dict.get(str(x)) for x in col])
  161. for bbox in col:
  162. if bbox[0][0] < min_w:
  163. min_w = bbox[0][0]
  164. if bbox[2][0] > max_w:
  165. max_w = bbox[2][0]
  166. col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
  167. col_line_list.append([max_w, table_location[1], max_w, table_location[3]])
  168. # 划线 行
  169. row_line_list = []
  170. last_max_h = None
  171. for row in first_col_row:
  172. # 画3条线,根据上下bbox
  173. min_h, max_h = 1000000, 0
  174. for bbox in row:
  175. if bbox[0][1] < min_h:
  176. min_h = bbox[0][1]
  177. if bbox[2][1] > max_h:
  178. max_h = bbox[2][1]
  179. row_line_list.append([table_location[0], min_h, table_location[2], min_h])
  180. row_line_list.append([table_location[0], max_h, table_location[2], max_h])
  181. # if last_max_h:
  182. # row_line_list.append([table_location[0], int((min_h+last_max_h)/2), table_location[2], int((min_h+last_max_h)/2)])
  183. last_max_h = max_h
  184. if show:
  185. print('len(col_line_list)', len(col_line_list))
  186. print('col_line_list', col_line_list)
  187. print('len(row_line_list)', len(row_line_list))
  188. # 判断列线有没有压在黑色像素上,若有则移动
  189. temp_list = []
  190. for i in range(1, len(col_line_list), 2):
  191. # 前一列右边线
  192. line1 = col_line_list[i]
  193. line1 = [int(x) for x in line1]
  194. # 后一列左边线
  195. if i+1 >= len(col_line_list):
  196. break
  197. line2 = col_line_list[i+1]
  198. line2 = [int(x) for x in line2]
  199. max_black_cnt = 10
  200. black_threshold = 150
  201. black_cnt1 = count_black(img[line1[1]:line1[3], line1[0]:line1[2]+1, :], threshold=black_threshold)
  202. black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold)
  203. # print('col black_cnt1', i, black_cnt1)
  204. # print('col black_cnt2', i, black_cnt2)
  205. # if black_cnt2 <= max_black_cnt and black_cnt1 <= max_black_cnt:
  206. # if black_cnt1 >= black_cnt2:
  207. # temp_list.append(line2)
  208. # else:
  209. # temp_list.append(line1)
  210. # elif black_cnt2 <= max_black_cnt:
  211. # temp_list.append(line2)
  212. # elif black_cnt1 <= max_black_cnt:
  213. # temp_list.append(line1)
  214. # 两条线都不符合
  215. # else:
  216. # 先找出最近的bbox,不能跨bbox
  217. min_distance = 100000
  218. min_dis_bbox = bbox_list[0]
  219. # for bbox in bbox_list:
  220. for bbox in first_col_row[0]:
  221. if bbox[2][0] < line2[0]:
  222. _dis = line2[0] - bbox[2][0]
  223. if _dis < min_distance:
  224. min_distance = _dis
  225. min_dis_bbox = bbox
  226. # 从右向左移寻找
  227. right_left_index_list = []
  228. right_left_cnt_list = []
  229. find_flag = False
  230. for j in range(line2[0], int(min_dis_bbox[2][0]), -1):
  231. # 需连续3个像素列满足要求
  232. if len(right_left_index_list) == 3:
  233. find_flag = True
  234. break
  235. black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
  236. # print('col black_cnt', black_cnt)
  237. right_left_cnt_list.append(black_cnt)
  238. # 直接找到无黑色像素的
  239. if black_cnt == 0:
  240. right_left_index_list.append(j)
  241. else:
  242. right_left_index_list = []
  243. if show:
  244. print('find_flag', find_flag)
  245. if find_flag:
  246. temp_list.append([right_left_index_list[1], line2[1], right_left_index_list[1], line2[3]])
  247. else:
  248. # 为0的找不到,就找最小的
  249. # 每个位置加上前后n位求平均
  250. n = 1
  251. min_cnt = 1000000.
  252. min_cnt_index = 0
  253. for j, cnt in enumerate(right_left_cnt_list):
  254. if show:
  255. print('min_cnt', min_cnt)
  256. if j < n or j > len(right_left_cnt_list) - 1 - n:
  257. continue
  258. # 小到一定程度提前结束
  259. if min_cnt <= 0.001:
  260. break
  261. last_cnt = right_left_cnt_list[j-1]
  262. next_cnt = right_left_cnt_list[j+1]
  263. avg_cnt = (last_cnt + cnt + next_cnt) / 3
  264. if avg_cnt < min_cnt:
  265. min_cnt = avg_cnt
  266. min_cnt_index = j
  267. min_cnt_index = line2[0] - min_cnt_index
  268. temp_list.append([min_cnt_index, line2[1], min_cnt_index, line2[3]])
  269. col_line_list = temp_list
  270. if show:
  271. print('len(col_line_list)', len(col_line_list))
  272. for col in col_line_list:
  273. col = [int(x) for x in col]
  274. cv2.line(img_show, col[:2], col[2:4], (0, 255, 0), 2)
  275. cv2.imshow('col_line_list', img_show)
  276. cv2.waitKey(0)
  277. # 根据列的划线对bbox分列
  278. last_line = [0, 0, 0, 0]
  279. col_bbox_list = []
  280. for line in col_line_list + [[img.shape[0], 0, img.shape[0], 0]]:
  281. col = []
  282. for bbox in bbox_list:
  283. iou = line_iou([[last_line[0], 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0)
  284. if iou >= 0.6:
  285. col.append(bbox)
  286. col.sort(key=lambda x: x[0][1])
  287. col_bbox_list.append(col)
  288. last_line = line
  289. # 判断行线
  290. temp_list = []
  291. for i in range(1, len(row_line_list), 2):
  292. # 前一行下边线
  293. line1 = row_line_list[i]
  294. line1 = [int(x) for x in line1]
  295. # 后一行上边线
  296. if i+1 >= len(row_line_list):
  297. break
  298. line2 = row_line_list[i+1]
  299. line2 = [int(x) for x in line2]
  300. # 判断行线之间的bbox分别属于哪一行
  301. sub_bbox_list = []
  302. threshold = 5
  303. for bbox in bbox_list:
  304. if line1[1] - threshold <= bbox[0][1] <= bbox[2][1] <= line2[1]+threshold:
  305. sub_bbox_list.append(bbox)
  306. # 根据行的h和分列判断bbox属于上一行还是下一行
  307. line1_bbox_list = []
  308. line2_bbox_list = []
  309. if sub_bbox_list:
  310. sub_bbox_list.sort(key=lambda x: x[0][1])
  311. min_h = sub_bbox_list[0][0][1] - 1
  312. max_h = sub_bbox_list[-1][2][1] + 1
  313. for bbox in sub_bbox_list:
  314. # 找到属于哪一列
  315. current_col = None
  316. for col in col_bbox_list:
  317. if bbox in col:
  318. current_col = copy.deepcopy(col)
  319. break
  320. if current_col:
  321. # 行做成bbox加入列作为基准
  322. line1_bbox = [[0, min_h], [], [0, min_h], []]
  323. line2_bbox = [[0, max_h], [], [0, max_h], []]
  324. current_col += [line1_bbox, line2_bbox]
  325. current_col.sort(key=lambda x: x[0][1])
  326. bbox_index = current_col.index(bbox)
  327. line1_bbox_index = current_col.index(line1_bbox)
  328. line2_bbox_index = current_col.index(line2_bbox)
  329. # print('current_col', [bbox_text_dict.get(str(x)) for x in current_col])
  330. # print('line1_bbox_index, bbox_index, line2_bbox_index', line1_bbox_index, bbox_index, line2_bbox_index)
  331. # 计算距离
  332. distance1 = 10000
  333. for index in range(line1_bbox_index, bbox_index):
  334. h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
  335. h2 = (current_col[index+1][0][1] + current_col[index+1][2][1]) / 2
  336. # print(bbox_text_dict.get())
  337. distance1 = abs(h1 - h2)
  338. distance2 = 10000
  339. for index in range(line2_bbox_index, bbox_index, -1):
  340. h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
  341. h2 = (current_col[index-1][0][1] + current_col[index-1][2][1]) / 2
  342. distance2 = abs(h1 - h2)
  343. # print(bbox_text_dict.get(str(bbox)), distance1, distance2)
  344. ratio = 1.5
  345. # 属于下一行
  346. if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
  347. line2_bbox_list.append(bbox)
  348. # 属于上一行
  349. elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
  350. line1_bbox_list.append(bbox)
  351. else:
  352. print('距离不明确,需要nsp模型介入判断')
  353. if line1_bbox_list:
  354. # print('line1_bbox_list', [bbox_text_dict.get(str(x)) for x in line1_bbox_list])
  355. line1_bbox_list.sort(key=lambda x: x[0][1])
  356. b = line1_bbox_list[-1]
  357. line1 = [line1[0], b[2][1], line1[2], b[2][1]]
  358. if line2_bbox_list:
  359. # print('line2_bbox_list', [bbox_text_dict.get(str(x)) for x in line2_bbox_list])
  360. line2_bbox_list.sort(key=lambda x: x[0][1])
  361. b = line2_bbox_list[0]
  362. line2 = [line2[0], b[0][1], line2[2], b[0][1]]
  363. _line = [line1[0], (line1[1]+line2[1])/2, line1[2], (line1[3]+line2[3])/2]
  364. _line = [int(x) for x in _line]
  365. temp_list.append(_line)
  366. row_line_list = temp_list
  367. if show:
  368. print('len(row_line_list)', len(row_line_list))
  369. print('len(col_line_list)', len(col_line_list))
  370. # 只有一行或一列的直接跳过
  371. if len(row_line_list) < 1 or len(col_line_list) < 1:
  372. return [], [], [], {}
  373. # 加上表格轮廓线
  374. threshold = 5
  375. min_w = max(table_location[0], 0+threshold)
  376. max_w = min(table_location[2], img.shape[1]-threshold)
  377. min_h = max(table_location[1], 0+threshold)
  378. max_h = min(table_location[3], img.shape[0]-threshold)
  379. row_line_list.append([min_w, min_h, max_w, min_h])
  380. row_line_list.append([min_w, max_h, max_w, max_h])
  381. col_line_list.append([min_w, min_h, min_w, max_h])
  382. col_line_list.append([max_w, min_h, max_w, max_h])
  383. # # 行线、列线两两之间没有bbox则合并
  384. # col_line_list.sort(key=lambda x: x[0])
  385. # temp_list = []
  386. # used_bbox_list = []
  387. # last_col = col_line_list[0]
  388. # for col in col_line_list[1:]:
  389. # find_flag = False
  390. # for bbox in bbox_list:
  391. # if bbox in used_bbox_list:
  392. # continue
  393. # if last_col[0] <= (bbox[0][0] + bbox[2][0]) / 2 <= col[0]:
  394. # print('bbox', bbox, bbox_text_dict.get(str(bbox)))
  395. # used_bbox_list.append(bbox)
  396. # find_flag = True
  397. # break
  398. # print('last_col, col, find_flag', last_col, col, find_flag)
  399. # if not find_flag:
  400. # new_w = int((last_col[0] + col[0])/2)
  401. # temp_list.append([new_w, col[1], new_w, col[3]])
  402. # else:
  403. # temp_list.append(last_col)
  404. # last_col = col
  405. # if find_flag:
  406. # temp_list.append(col_line_list[-1])
  407. # col_line_list = temp_list
  408. # 由线得到按行列排列的bbox
  409. row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
  410. col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
  411. table_bbox_list, table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list], [table_location], bbox_list)
  412. # 线合并
  413. line_list = row_line_list + col_line_list
  414. # show
  415. if show:
  416. for r in table_cell_list:
  417. for c in r:
  418. cv2.rectangle(img_result, c[0], c[1], (0, 255, 0), 1)
  419. cv2.namedWindow('table_cell', cv2.WINDOW_NORMAL)
  420. cv2.imshow('table_cell', img_result)
  421. for line in col_line_list:
  422. cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
  423. for line in row_line_list:
  424. cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
  425. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  426. cv2.imshow('img', cv2.resize(img_result, (768, 1024)))
  427. cv2.waitKey(0)
  428. return line_list, table_cell_list, table_location, bbox_text_dict