get_table_by_rules.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import copy
  2. import re
  3. import cv2
  4. import numpy as np
  5. from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
  6. from botr.utils import line_iou
  7. # from format_convert.utils import log
  8. def get_table_by_rule(img, text_list, bbox_list, table_location, show=0):
  9. if show:
  10. print('get_table_by_rule bbox_list', bbox_list)
  11. if not bbox_list:
  12. return [], [], [], {}
  13. if show:
  14. img_show = copy.deepcopy(img)
  15. img_result = copy.deepcopy(img)
  16. # 处理bbox,缩小框
  17. bbox_list = shrink_bbox(img, bbox_list)
  18. # 创建对应dict
  19. bbox_text_dict = {}
  20. temp_list = []
  21. for i in range(len(text_list)):
  22. # 排除text为空的
  23. if not text_list[i]:
  24. continue
  25. if re.sub(' ', '', text_list[i]) == '':
  26. continue
  27. # text中间为空格,其实是两列的
  28. match = re.search('[ ]{3,}', text_list[i])
  29. if match:
  30. # print(text_list[i][match.span()[1]:], re.match('[((]', text_list[i][match.span()[1]:]))
  31. text = text_list[i]
  32. bbox = bbox_list[i]
  33. blank_index = (match.span()[0] + match.span()[1]) / 2
  34. chinese_cnt = len(re.findall('[\u4e00-\u9fff()?。,!【】¥《》]', text)) * 1.5
  35. char_cnt = len(re.findall('[ .?!,+*&^%$#@~=:;/<>()a-zA-Z0-9{}]', text))
  36. # print(text, match.span()[0], match.span()[1], blank_index, chinese_cnt, char_cnt)
  37. char_cnt += chinese_cnt
  38. char_pixel = abs(bbox[0][0] - bbox[2][0]) / char_cnt
  39. index_pixel = char_pixel * blank_index
  40. # print(abs(bbox[0][0] - bbox[2][0]), char_cnt, char_pixel, index_pixel)
  41. bbox1 = [bbox[0], bbox[1], [bbox[0][0] + index_pixel, bbox[2][1]], bbox[3]]
  42. bbox1 = shrink_bbox(img, [bbox1])[0]
  43. text1 = text[:match.span()[0]]
  44. bbox2 = [[bbox[0][0]+index_pixel, bbox[0][1]], bbox[1], bbox[2], bbox[3]]
  45. bbox2 = shrink_bbox(img, [bbox2])[0]
  46. text2 = text[match.span()[1]:]
  47. if re.sub(' ', '', text1) != '':
  48. bbox_text_dict[str(bbox1)] = text1
  49. temp_list.append(bbox1)
  50. if re.sub(' ', '', text2) != '':
  51. bbox_text_dict[str(bbox2)] = text2
  52. temp_list.append(bbox2)
  53. # 正常的bbox
  54. else:
  55. bbox_text_dict[str(bbox_list[i])] = text_list[i]
  56. temp_list.append(bbox_list[i])
  57. bbox_list = temp_list
  58. if show:
  59. print('bbox_text_dict', bbox_text_dict)
  60. for bbox in bbox_list:
  61. cv2.rectangle(img_show, (int(bbox[0][0]), int(bbox[0][1])),
  62. (int(bbox[2][0]), int(bbox[2][1])), (255, 0, 0), 2)
  63. cv2.namedWindow('bbox_list', cv2.WINDOW_NORMAL)
  64. cv2.imshow('bbox_list', img_show)
  65. cv2.waitKey(0)
  66. # 根据bbox_list,计算与table_location左上角坐标距离,锁定第一个bbox
  67. table_left_up_point = [table_location[0], table_location[1]]
  68. min_distance = 100000000000
  69. if not bbox_list:
  70. return [], [], [], {}
  71. first_bbox = bbox_list[0]
  72. for bbox in bbox_list:
  73. distance = abs(bbox[0][0] - table_left_up_point[0]) + abs(bbox[0][1] - table_left_up_point[1])
  74. if distance < min_distance:
  75. min_distance = distance
  76. first_bbox = bbox
  77. # print('first_bbox', first_bbox, bbox_text_dict.get(str(first_bbox)))
  78. # # 对first_bbox预处理
  79. # # 分割
  80. # new_bbox_list, bbox_text_dict = split_bbox(img, first_bbox, bbox_text_dict)
  81. # if new_bbox_list:
  82. # if first_bbox in bbox_list:
  83. # bbox_list.remove(first_bbox)
  84. # bbox_list += new_bbox_list
  85. # new_bbox_list.sort(key=lambda x: (x[0][0]))
  86. # first_bbox = new_bbox_list[0]
  87. # 根据第一个bbox,得到第一行
  88. first_row = []
  89. bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
  90. for bbox in bbox_list:
  91. # h有交集
  92. if first_bbox[0][1] <= bbox[0][1] <= first_bbox[2][1] \
  93. or first_bbox[0][1] <= bbox[2][1] <= first_bbox[2][1] \
  94. or bbox[0][1] <= first_bbox[0][1] <= bbox[2][1] \
  95. or bbox[0][1] <= first_bbox[2][1] <= bbox[2][1]:
  96. first_row.append(bbox)
  97. # h小于first_box
  98. elif bbox[2][1] <= first_bbox[0][1]:
  99. first_row.append(bbox)
  100. # 对第一行分列
  101. first_row.sort(key=lambda x: (x[0][0], x[0][1]))
  102. first_row_col = []
  103. used_bbox = []
  104. for bbox in first_row:
  105. if bbox in used_bbox:
  106. continue
  107. temp_col = []
  108. for bbox1 in first_row:
  109. if bbox1 in used_bbox:
  110. continue
  111. if bbox1[0][0] <= bbox[0][0] <= bbox1[2][0] \
  112. or bbox1[0][0] <= bbox[2][0] <= bbox1[2][0] \
  113. or bbox[0][0] <= bbox1[0][0] <= bbox[2][0] \
  114. or bbox[0][0] <= bbox1[2][0] <= bbox[2][0]:
  115. temp_col.append(bbox1)
  116. used_bbox.append(bbox1)
  117. first_row_col.append(temp_col)
  118. # 根据第一个bbox,得到第一列
  119. first_col = []
  120. bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
  121. for bbox in bbox_list:
  122. # w有交集
  123. if first_bbox[0][0] <= bbox[0][0] <= first_bbox[2][0] \
  124. or first_bbox[0][0] <= bbox[2][0] <= first_bbox[2][0] \
  125. or bbox[0][0] <= first_bbox[0][0] <= bbox[2][0] \
  126. or bbox[0][0] <= first_bbox[2][0] <= bbox[2][0]:
  127. first_col.append(bbox)
  128. # w小于first_box
  129. elif bbox[2][0] <= first_bbox[0][0]:
  130. first_col.append(bbox)
  131. # 对第一列分行
  132. first_col.sort(key=lambda x: (x[0][1], x[0][0]))
  133. first_col_row = []
  134. current_bbox = first_col[0]
  135. temp_row = []
  136. for bbox in first_col:
  137. if current_bbox[0][1] <= bbox[0][1] <= current_bbox[2][1] \
  138. or current_bbox[0][1] <= bbox[2][1] <= current_bbox[2][1] \
  139. or bbox[0][1] <= current_bbox[0][1] <= bbox[2][1] \
  140. or bbox[0][1] <= current_bbox[2][1] <= bbox[2][1]:
  141. temp_row.append(bbox)
  142. else:
  143. if temp_row:
  144. temp_row.sort(key=lambda x: x[0][1])
  145. first_col_row.append(temp_row)
  146. temp_row = [bbox]
  147. current_bbox = bbox
  148. if temp_row:
  149. temp_row.sort(key=lambda x: x[0][1])
  150. first_col_row.append(temp_row)
  151. if show:
  152. print('len(first_row)', len(first_row))
  153. print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
  154. print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
  155. print('len(first_col)', len(first_col))
  156. print('len(first_row_col)', len(first_row_col))
  157. print('len(first_col_row)', len(first_col_row))
  158. # 划线 列
  159. col_line_list = []
  160. for col in first_row_col:
  161. # 画2条线,根据左右bbox
  162. min_w, max_w = 1000000, 0
  163. # print('col', [bbox_text_dict.get(str(x)) for x in col])
  164. for bbox in col:
  165. if bbox[0][0] < min_w:
  166. min_w = bbox[0][0]
  167. if bbox[2][0] > max_w:
  168. max_w = bbox[2][0]
  169. col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
  170. col_line_list.append([max_w, table_location[1], max_w, table_location[3]])
  171. # 划线 行
  172. row_line_list = []
  173. last_max_h = None
  174. for row in first_col_row:
  175. # 画3条线,根据上下bbox
  176. min_h, max_h = 1000000, 0
  177. for bbox in row:
  178. if bbox[0][1] < min_h:
  179. min_h = bbox[0][1]
  180. if bbox[2][1] > max_h:
  181. max_h = bbox[2][1]
  182. row_line_list.append([table_location[0], min_h, table_location[2], min_h])
  183. row_line_list.append([table_location[0], max_h, table_location[2], max_h])
  184. # if last_max_h:
  185. # row_line_list.append([table_location[0], int((min_h+last_max_h)/2), table_location[2], int((min_h+last_max_h)/2)])
  186. last_max_h = max_h
  187. if show:
  188. print('len(col_line_list)', len(col_line_list))
  189. print('col_line_list', col_line_list)
  190. print('len(row_line_list)', len(row_line_list))
  191. # 判断列线有没有压在黑色像素上,若有则移动
  192. temp_list = []
  193. for i in range(1, len(col_line_list), 2):
  194. # 前一列右边线
  195. line1 = col_line_list[i]
  196. line1 = [int(x) for x in line1]
  197. # 后一列左边线
  198. if i+1 >= len(col_line_list):
  199. break
  200. line2 = col_line_list[i+1]
  201. line2 = [int(x) for x in line2]
  202. max_black_cnt = 10
  203. black_threshold = 150
  204. black_cnt1 = count_black(img[line1[1]:line1[3], line1[0]:line1[2]+1, :], threshold=black_threshold)
  205. black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold)
  206. # print('col black_cnt1', i, black_cnt1)
  207. # print('col black_cnt2', i, black_cnt2)
  208. # if black_cnt2 <= max_black_cnt and black_cnt1 <= max_black_cnt:
  209. # if black_cnt1 >= black_cnt2:
  210. # temp_list.append(line2)
  211. # else:
  212. # temp_list.append(line1)
  213. # elif black_cnt2 <= max_black_cnt:
  214. # temp_list.append(line2)
  215. # elif black_cnt1 <= max_black_cnt:
  216. # temp_list.append(line1)
  217. # 两条线都不符合
  218. # else:
  219. # 先找出最近的bbox,不能跨bbox
  220. min_distance = 100000
  221. min_dis_bbox = bbox_list[0]
  222. # for bbox in bbox_list:
  223. for bbox in first_col_row[0]:
  224. if bbox[2][0] < line2[0]:
  225. _dis = line2[0] - bbox[2][0]
  226. if _dis < min_distance:
  227. min_distance = _dis
  228. min_dis_bbox = bbox
  229. # 从右向左移寻找
  230. right_left_index_list = []
  231. right_left_cnt_list = []
  232. find_flag = False
  233. for j in range(line2[0], int(min_dis_bbox[2][0]), -1):
  234. # 需连续3个像素列满足要求
  235. if len(right_left_index_list) == 3:
  236. find_flag = True
  237. break
  238. black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
  239. # print('col black_cnt', black_cnt)
  240. right_left_cnt_list.append(black_cnt)
  241. # 直接找到无黑色像素的
  242. if black_cnt == 0:
  243. right_left_index_list.append(j)
  244. else:
  245. right_left_index_list = []
  246. if show:
  247. print('find_flag', find_flag)
  248. if find_flag:
  249. temp_list.append([right_left_index_list[1], line2[1], right_left_index_list[1], line2[3]])
  250. else:
  251. # 为0的找不到,就找最小的
  252. # 每个位置加上前后n位求平均
  253. n = 1
  254. min_cnt = 1000000.
  255. min_cnt_index = 0
  256. for j, cnt in enumerate(right_left_cnt_list):
  257. if show:
  258. print('min_cnt', min_cnt)
  259. if j < n or j > len(right_left_cnt_list) - 1 - n:
  260. continue
  261. # 小到一定程度提前结束
  262. if min_cnt <= 0.001:
  263. break
  264. last_cnt = right_left_cnt_list[j-1]
  265. next_cnt = right_left_cnt_list[j+1]
  266. avg_cnt = (last_cnt + cnt + next_cnt) / 3
  267. if avg_cnt < min_cnt:
  268. min_cnt = avg_cnt
  269. min_cnt_index = j
  270. min_cnt_index = line2[0] - min_cnt_index
  271. temp_list.append([min_cnt_index, line2[1], min_cnt_index, line2[3]])
  272. col_line_list = temp_list
  273. if show:
  274. print('len(col_line_list)', len(col_line_list))
  275. for col in col_line_list:
  276. col = [int(x) for x in col]
  277. cv2.line(img_show, col[:2], col[2:4], (0, 255, 0), 2)
  278. cv2.imshow('col_line_list', img_show)
  279. cv2.waitKey(0)
  280. # 根据列的划线对bbox分列
  281. last_line = [0, 0, 0, 0]
  282. col_bbox_list = []
  283. for line in col_line_list + [[img.shape[0], 0, img.shape[0], 0]]:
  284. col = []
  285. for bbox in bbox_list:
  286. iou = line_iou([[last_line[0], 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0)
  287. if iou >= 0.6:
  288. col.append(bbox)
  289. col.sort(key=lambda x: x[0][1])
  290. col_bbox_list.append(col)
  291. last_line = line
  292. # 判断行线
  293. temp_list = []
  294. for i in range(1, len(row_line_list), 2):
  295. # 前一行下边线
  296. line1 = row_line_list[i]
  297. line1 = [int(x) for x in line1]
  298. # 后一行上边线
  299. if i+1 >= len(row_line_list):
  300. break
  301. line2 = row_line_list[i+1]
  302. line2 = [int(x) for x in line2]
  303. # 判断行线之间的bbox分别属于哪一行
  304. sub_bbox_list = []
  305. threshold = 5
  306. for bbox in bbox_list:
  307. if line1[1] - threshold <= bbox[0][1] <= bbox[2][1] <= line2[1]+threshold:
  308. sub_bbox_list.append(bbox)
  309. # 根据行的h和分列判断bbox属于上一行还是下一行
  310. line1_bbox_list = []
  311. line2_bbox_list = []
  312. if sub_bbox_list:
  313. sub_bbox_list.sort(key=lambda x: x[0][1])
  314. min_h = sub_bbox_list[0][0][1] - 1
  315. max_h = sub_bbox_list[-1][2][1] + 1
  316. for bbox in sub_bbox_list:
  317. # 找到属于哪一列
  318. current_col = None
  319. for col in col_bbox_list:
  320. if bbox in col:
  321. current_col = copy.deepcopy(col)
  322. break
  323. if current_col:
  324. # 行做成bbox加入列作为基准
  325. line1_bbox = [[0, min_h], [], [0, min_h], []]
  326. line2_bbox = [[0, max_h], [], [0, max_h], []]
  327. current_col += [line1_bbox, line2_bbox]
  328. current_col.sort(key=lambda x: x[0][1])
  329. bbox_index = current_col.index(bbox)
  330. line1_bbox_index = current_col.index(line1_bbox)
  331. line2_bbox_index = current_col.index(line2_bbox)
  332. # print('current_col', [bbox_text_dict.get(str(x)) for x in current_col])
  333. # print('line1_bbox_index, bbox_index, line2_bbox_index', line1_bbox_index, bbox_index, line2_bbox_index)
  334. # 计算距离
  335. distance1 = 10000
  336. for index in range(line1_bbox_index, bbox_index):
  337. h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
  338. h2 = (current_col[index+1][0][1] + current_col[index+1][2][1]) / 2
  339. # print(bbox_text_dict.get())
  340. distance1 = abs(h1 - h2)
  341. distance2 = 10000
  342. for index in range(line2_bbox_index, bbox_index, -1):
  343. h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
  344. h2 = (current_col[index-1][0][1] + current_col[index-1][2][1]) / 2
  345. distance2 = abs(h1 - h2)
  346. # print(bbox_text_dict.get(str(bbox)), distance1, distance2)
  347. ratio = 1.5
  348. # 属于下一行
  349. if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
  350. line2_bbox_list.append(bbox)
  351. # 属于上一行
  352. elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
  353. line1_bbox_list.append(bbox)
  354. else:
  355. print('距离不明确,需要nsp模型介入判断')
  356. if line1_bbox_list:
  357. # print('line1_bbox_list', [bbox_text_dict.get(str(x)) for x in line1_bbox_list])
  358. line1_bbox_list.sort(key=lambda x: x[0][1])
  359. b = line1_bbox_list[-1]
  360. line1 = [line1[0], b[2][1], line1[2], b[2][1]]
  361. if line2_bbox_list:
  362. # print('line2_bbox_list', [bbox_text_dict.get(str(x)) for x in line2_bbox_list])
  363. line2_bbox_list.sort(key=lambda x: x[0][1])
  364. b = line2_bbox_list[0]
  365. line2 = [line2[0], b[0][1], line2[2], b[0][1]]
  366. _line = [line1[0], (line1[1]+line2[1])/2, line1[2], (line1[3]+line2[3])/2]
  367. _line = [int(x) for x in _line]
  368. temp_list.append(_line)
  369. row_line_list = temp_list
  370. if show:
  371. print('len(row_line_list)', len(row_line_list))
  372. print('len(col_line_list)', len(col_line_list))
  373. # 只有一行或一列的直接跳过
  374. if len(row_line_list) < 1 or len(col_line_list) < 1:
  375. return [], [], [], {}
  376. # 加上表格轮廓线
  377. threshold = 5
  378. min_w = max(table_location[0], 0+threshold)
  379. max_w = min(table_location[2], img.shape[1]-threshold)
  380. min_h = max(table_location[1], 0+threshold)
  381. max_h = min(table_location[3], img.shape[0]-threshold)
  382. row_line_list.append([min_w, min_h, max_w, min_h])
  383. row_line_list.append([min_w, max_h, max_w, max_h])
  384. col_line_list.append([min_w, min_h, min_w, max_h])
  385. col_line_list.append([max_w, min_h, max_w, max_h])
  386. # # 行线、列线两两之间没有bbox则合并
  387. # col_line_list.sort(key=lambda x: x[0])
  388. # temp_list = []
  389. # used_bbox_list = []
  390. # last_col = col_line_list[0]
  391. # for col in col_line_list[1:]:
  392. # find_flag = False
  393. # for bbox in bbox_list:
  394. # if bbox in used_bbox_list:
  395. # continue
  396. # if last_col[0] <= (bbox[0][0] + bbox[2][0]) / 2 <= col[0]:
  397. # print('bbox', bbox, bbox_text_dict.get(str(bbox)))
  398. # used_bbox_list.append(bbox)
  399. # find_flag = True
  400. # break
  401. # print('last_col, col, find_flag', last_col, col, find_flag)
  402. # if not find_flag:
  403. # new_w = int((last_col[0] + col[0])/2)
  404. # temp_list.append([new_w, col[1], new_w, col[3]])
  405. # else:
  406. # temp_list.append(last_col)
  407. # last_col = col
  408. # if find_flag:
  409. # temp_list.append(col_line_list[-1])
  410. # col_line_list = temp_list
  411. # 由线得到按行列排列的bbox
  412. row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
  413. col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
  414. table_bbox_list, table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list], [table_location], bbox_list)
  415. # 线合并
  416. line_list = row_line_list + col_line_list
  417. # show
  418. if show:
  419. for r in table_cell_list:
  420. for c in r:
  421. cv2.rectangle(img_result, c[0], c[1], (0, 255, 0), 1)
  422. cv2.namedWindow('table_cell', cv2.WINDOW_NORMAL)
  423. cv2.imshow('table_cell', img_result)
  424. for line in col_line_list:
  425. cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
  426. for line in row_line_list:
  427. cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
  428. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  429. cv2.imshow('img', cv2.resize(img_result, (768, 1024)))
  430. cv2.waitKey(0)
  431. return line_list, table_cell_list, table_location, bbox_text_dict