123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- import copy
- import re
- import cv2
- import numpy as np
- from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
- from botr.utils import line_iou
- # from format_convert.utils import log
def _ranges_overlap(a_start, a_end, b_start, b_end):
    """Return True when closed ranges [a_start, a_end] and [b_start, b_end] touch.

    This is the four-way endpoint test used throughout the table rules: an
    endpoint of one range lies inside the other. It is symmetric in the two
    ranges.
    """
    return (a_start <= b_start <= a_end
            or a_start <= b_end <= a_end
            or b_start <= a_start <= b_end
            or b_start <= a_end <= b_end)


def _center_y(box):
    """Vertical centre of a bbox (mean of the y of corners [0] and [2])."""
    return (box[0][1] + box[2][1]) / 2


def _split_text_bboxes(img, text_list, bbox_list):
    """Pair each OCR text with its bbox, splitting texts that span two cells.

    A run of three or more spaces inside a text is taken to mean that two cells
    were merged. The bbox is cut at an x estimated from per-character widths,
    and each half is re-shrunk with ``shrink_bbox``.

    Returns:
        ``(bbox_text_dict, boxes)``:
        - ``bbox_text_dict`` maps ``str(bbox)`` to its text.
        - ``boxes`` lists the kept bboxes in input order.
        Texts that are empty or contain only spaces are dropped.
    """
    bbox_text_dict = {}
    boxes = []
    for i, text in enumerate(text_list):
        if not text or re.sub(' ', '', text) == '':
            continue
        match = re.search('[ ]{3,}', text)
        if not match:
            bbox_text_dict[str(bbox_list[i])] = text
            boxes.append(bbox_list[i])
            continue

        bbox = bbox_list[i]
        gap_start, gap_end = match.span()
        # Estimate pixels per character. CJK glyphs count as 1.5 narrow chars.
        # char_cnt >= 3 here, because the matched blank run is itself counted.
        chinese_cnt = len(re.findall('[\u4e00-\u9fff()?。,!【】¥《》]', text)) * 1.5
        char_cnt = len(re.findall('[ .?!,+*&^%$#@~=:;/<>()a-zA-Z0-9{}]', text))
        char_cnt += chinese_cnt
        char_pixel = abs(bbox[0][0] - bbox[2][0]) / char_cnt
        # Cut at the middle of the blank run.
        split_x = bbox[0][0] + char_pixel * ((gap_start + gap_end) / 2)
        # NOTE(review): only corners [0]/[2] are moved to split_x; corners [1]
        # and [3] are passed through unchanged. Confirm that shrink_bbox derives
        # the box from [0]/[2] only.
        left_bbox = shrink_bbox(img, [[bbox[0], bbox[1], [split_x, bbox[2][1]], bbox[3]]])[0]
        right_bbox = shrink_bbox(img, [[[split_x, bbox[0][1]], bbox[1], bbox[2], bbox[3]]])[0]
        for part_bbox, part_text in ((left_bbox, text[:gap_start]),
                                     (right_bbox, text[gap_end:])):
            if re.sub(' ', '', part_text) != '':
                bbox_text_dict[str(part_bbox)] = part_text
                boxes.append(part_bbox)
    return bbox_text_dict, boxes


def _group_by_x_overlap(boxes):
    """Greedily cluster bboxes whose x ranges overlap.

    Each box joins at most one cluster. Membership is by value equality.
    """
    groups = []
    used = []
    for anchor in boxes:
        if anchor in used:
            continue
        group = []
        for candidate in boxes:
            if candidate in used:
                continue
            if _ranges_overlap(anchor[0][0], anchor[2][0], candidate[0][0], candidate[2][0]):
                group.append(candidate)
                used.append(candidate)
        groups.append(group)
    return groups


def _group_by_y_overlap(boxes):
    """Split y-sorted, non-empty ``boxes`` into rows of vertically overlapping boxes.

    Each row is returned sorted by top y.
    """
    rows = []
    row = []
    anchor = boxes[0]
    for box in boxes:
        if _ranges_overlap(anchor[0][1], anchor[2][1], box[0][1], box[2][1]):
            row.append(box)
        else:
            if row:
                row.sort(key=lambda b: b[0][1])
                rows.append(row)
            row = [box]
            # NOTE(review): the source this was restored from had no
            # indentation. The anchor is reset only when a new row starts
            # (rows are anchored on their first box); confirm it was not meant
            # to chain on every box.
            anchor = box
    if row:
        row.sort(key=lambda b: b[0][1])
        rows.append(row)
    return rows


def _find_column_gap(img, line1, line2, stop_x, black_threshold, show=0):
    """Return the x of a separator between two adjacent columns.

    Scans pixel columns from ``line2[0]`` leftward down to, but excluding,
    ``int(stop_x)``. Dark pixels are counted with ``count_black`` over the rows
    ``line1[1]:line1[3]``.

    Returns:
        - The middle x of the first run of three consecutive columns with no
          dark pixel, if such a run exists.
        - Otherwise, the x whose 3-column moving average of dark-pixel counts
          is smallest.
        - ``line2[0]`` when fewer than three columns were scanned.
    """
    zero_run = []
    counts = []
    found_x = None
    for x in range(line2[0], int(stop_x), -1):
        black_cnt = count_black(img[line1[1]:line1[3], x:x + 1, :], threshold=black_threshold)
        counts.append(black_cnt)
        if black_cnt == 0:
            zero_run.append(x)
            # Check immediately, so a run ending on the last scanned column is
            # not missed.
            if len(zero_run) == 3:
                found_x = zero_run[1]
                break
        else:
            zero_run = []
    if show:
        print('find_flag', found_x is not None)
    if found_x is not None:
        return found_x

    # No fully blank gap: choose the position whose average with its n
    # neighbours on each side is lowest.
    n = 1
    min_cnt = 1000000.
    min_cnt_index = 0
    for j, cnt in enumerate(counts):
        if show:
            print('min_cnt', min_cnt)
        if j < n or j > len(counts) - 1 - n:
            continue
        # Stop early once an (almost) clean window has been found.
        if min_cnt <= 0.001:
            break
        avg_cnt = (counts[j - 1] + cnt + counts[j + 1]) / 3
        if avg_cnt < min_cnt:
            min_cnt = avg_cnt
            min_cnt_index = j
    # counts[j] was measured at x = line2[0] - j.
    return line2[0] - min_cnt_index


def _locate_row_boundary(line1, line2, bbox_list, col_bbox_list, threshold=5):
    """Return the integer line ``[x0, y, x1, y]`` separating two first-column rows.

    ``line1`` is the bottom edge of the upper row, and ``line2`` is the top edge
    of the lower row.

    Bboxes lying between the two edges (with ``threshold`` px slack) are
    assigned to the upper or lower row. The assignment compares, within each
    box's own column, the centre distance to the neighbour above and to the
    neighbour below. The edges are then tightened to the assigned boxes, and
    their midpoint is returned.
    """
    sub_bbox_list = [b for b in bbox_list
                     if line1[1] - threshold <= b[0][1] <= b[2][1] <= line2[1] + threshold]
    upper_rows = []
    lower_rows = []
    if sub_bbox_list:
        sub_bbox_list.sort(key=lambda b: b[0][1])
        min_h = sub_bbox_list[0][0][1] - 1
        max_h = sub_bbox_list[-1][2][1] + 1
        for bbox in sub_bbox_list:
            own_col = next((col for col in col_bbox_list if bbox in col), None)
            if not own_col:
                continue
            # Pseudo-boxes for the two candidate edges. They are merged into
            # the column by y, so the box's neighbours can be located.
            top_marker = [[0, min_h], [], [0, min_h], []]
            bottom_marker = [[0, max_h], [], [0, max_h], []]
            column = sorted(own_col + [top_marker, bottom_marker], key=lambda b: b[0][1])
            bbox_index = column.index(bbox)
            # Centre distance to the neighbour just above (10000 if none).
            distance1 = 10000
            if column.index(top_marker) < bbox_index:
                distance1 = abs(_center_y(column[bbox_index - 1]) - _center_y(column[bbox_index]))
            # Centre distance to the neighbour just below (10000 if none).
            distance2 = 10000
            if column.index(bottom_marker) > bbox_index:
                distance2 = abs(_center_y(column[bbox_index + 1]) - _center_y(column[bbox_index]))
            ratio = 1.5
            if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
                # Clearly closer to the box below: belongs to the lower row.
                lower_rows.append(bbox)
            elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
                # Clearly closer to the box above: belongs to the upper row.
                upper_rows.append(bbox)
            else:
                print('距离不明确,需要nsp模型介入判断')
    if upper_rows:
        upper_rows.sort(key=lambda b: b[0][1])
        y = upper_rows[-1][2][1]
        line1 = [line1[0], y, line1[2], y]
    if lower_rows:
        lower_rows.sort(key=lambda b: b[0][1])
        y = lower_rows[0][0][1]
        line2 = [line2[0], y, line2[2], y]
    _line = [line1[0], (line1[1] + line2[1]) / 2, line1[2], (line1[3] + line2[3]) / 2]
    return [int(x) for x in _line]


def get_table_by_rule(img, text_list, bbox_list, table_location, show=0):
    """Infer table row/column lines from OCR text boxes alone.

    The box nearest the table's top-left corner anchors a "first row" and a
    "first column":
    - Gaps between first-row column clusters become column separators, each
      moved onto a blank pixel column.
    - Gaps between first-column rows become row separators.
    The table outline is then added, and cells are built with
    ``get_table_bbox_list``.

    Args:
        img: image array indexed ``img[y, x, channel]``.
        text_list: OCR texts, parallel to ``bbox_list``.
        bbox_list: one bbox per text. ``bbox[0]`` is used as the top-left
            ``(x, y)`` and ``bbox[2]`` as the bottom-right ``(x, y)``.
        table_location: table region ``[x0, y0, x1, y1]``.
        show: when truthy, print debug info and open OpenCV windows. This
            blocks on ``cv2.waitKey``.

    Returns:
        ``(line_list, table_cell_list, table_location, bbox_text_dict)``:
        - ``line_list``: the integer row lines followed by the column lines,
          each ``[x0, y0, x1, y1]``.
        - ``table_cell_list``: the second value returned by
          ``get_table_bbox_list``.
        - ``bbox_text_dict``: maps ``str(bbox)`` to text.
        Returns ``([], [], [], {})`` when no usable boxes remain, or when no
        inner row or column separator could be derived.
    """
    if show:
        print('get_table_by_rule bbox_list', bbox_list)
    if not bbox_list:
        return [], [], [], {}
    if show:
        img_show = copy.deepcopy(img)
        img_result = copy.deepcopy(img)

    # Shrink boxes to their content, then pair them with texts (splitting
    # texts that merged two cells).
    bbox_list = shrink_bbox(img, bbox_list)
    bbox_text_dict, bbox_list = _split_text_bboxes(img, text_list, bbox_list)
    if show:
        print('bbox_text_dict', bbox_text_dict)
        for bbox in bbox_list:
            cv2.rectangle(img_show, (int(bbox[0][0]), int(bbox[0][1])),
                          (int(bbox[2][0]), int(bbox[2][1])), (255, 0, 0), 2)
        cv2.imshow('bbox_list', img_show)
        cv2.waitKey(0)
    if not bbox_list:
        return [], [], [], {}

    # Anchor box: top-left corner closest (L1 distance) to the table's
    # top-left corner.
    table_x0, table_y0 = table_location[0], table_location[1]
    first_bbox = min(bbox_list,
                     key=lambda b: abs(b[0][0] - table_x0) + abs(b[0][1] - table_y0))

    # First row: boxes vertically overlapping the anchor, or entirely above it.
    bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
    first_row = [b for b in bbox_list
                 if _ranges_overlap(first_bbox[0][1], first_bbox[2][1], b[0][1], b[2][1])
                 or b[2][1] <= first_bbox[0][1]]
    first_row.sort(key=lambda x: (x[0][0], x[0][1]))
    first_row_col = _group_by_x_overlap(first_row)

    # First column: boxes horizontally overlapping the anchor, or entirely
    # left of it. It always contains the anchor, so it is never empty.
    bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
    first_col = [b for b in bbox_list
                 if _ranges_overlap(first_bbox[0][0], first_bbox[2][0], b[0][0], b[2][0])
                 or b[2][0] <= first_bbox[0][0]]
    first_col.sort(key=lambda x: (x[0][1], x[0][0]))
    first_col_row = _group_by_y_overlap(first_col)
    if show:
        print('len(first_row)', len(first_row))
        print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
        print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
        print('len(first_col)', len(first_col))
        print('len(first_row_col)', len(first_row_col))
        print('len(first_col_row)', len(first_col_row))

    # Candidate column lines: left and right edge of every first-row cluster.
    col_line_list = []
    for col in first_row_col:
        min_w = min([1000000] + [b[0][0] for b in col])
        max_w = max([0] + [b[2][0] for b in col])
        col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
        col_line_list.append([max_w, table_location[1], max_w, table_location[3]])

    # Candidate row lines: top and bottom edge of every first-column row.
    row_line_list = []
    for row in first_col_row:
        min_h = min([1000000] + [b[0][1] for b in row])
        max_h = max([0] + [b[2][1] for b in row])
        row_line_list.append([table_location[0], min_h, table_location[2], min_h])
        row_line_list.append([table_location[0], max_h, table_location[2], max_h])
    if show:
        print('len(col_line_list)', len(col_line_list))
        print('col_line_list', col_line_list)
        print('len(row_line_list)', len(row_line_list))

    # Replace each (right edge, next left edge) pair with one separator placed
    # on a blank pixel column, so the line does not cut through text.
    black_threshold = 150
    temp_list = []
    for i in range(1, len(col_line_list) - 1, 2):
        line1 = [int(x) for x in col_line_list[i]]
        line2 = [int(x) for x in col_line_list[i + 1]]
        # Leftmost stop for the scan: right edge of the nearest box left of
        # line2 (default bbox_list[0], the leftmost box after the x sort).
        # NOTE(review): candidates come only from the first row of the first
        # column, not from all boxes; confirm this is intended.
        stop_bbox = bbox_list[0]
        best_gap = 100000
        for bbox in first_col_row[0]:
            gap = line2[0] - bbox[2][0]
            if 0 < gap < best_gap:
                best_gap = gap
                stop_bbox = bbox
        sep_x = _find_column_gap(img, line1, line2, stop_bbox[2][0], black_threshold, show)
        temp_list.append([sep_x, line2[1], sep_x, line2[3]])
    col_line_list = temp_list
    if show:
        print('len(col_line_list)', len(col_line_list))
        for col in col_line_list:
            col = [int(x) for x in col]
            cv2.line(img_show, tuple(col[:2]), tuple(col[2:4]), (0, 255, 0), 2)
        cv2.imshow('col_line_list', img_show)
        cv2.waitKey(0)

    # Assign boxes to columns. A box joins the span between consecutive
    # separators when its x-IoU with the span is >= 0.6. The image's right
    # edge (shape[1] is the width) closes the last span.
    right_edge = img.shape[1]
    col_bbox_list = []
    left_x = 0
    for line in col_line_list + [[right_edge, 0, right_edge, 0]]:
        col = [b for b in bbox_list
               if line_iou([[left_x, 0], [line[0], 0]], [[b[0][0], 0], [b[2][0], 0]], axis=0) >= 0.6]
        col.sort(key=lambda b: b[0][1])
        col_bbox_list.append(col)
        left_x = line[0]

    # Replace each (bottom edge, next top edge) pair with one row separator.
    temp_list = []
    for i in range(1, len(row_line_list) - 1, 2):
        line1 = [int(x) for x in row_line_list[i]]
        line2 = [int(x) for x in row_line_list[i + 1]]
        temp_list.append(_locate_row_boundary(line1, line2, bbox_list, col_bbox_list))
    row_line_list = temp_list
    if show:
        print('len(row_line_list)', len(row_line_list))
        print('len(col_line_list)', len(col_line_list))

    # No inner separator in one direction (a single row or column): skip.
    if len(row_line_list) < 1 or len(col_line_list) < 1:
        return [], [], [], {}

    # Add the table outline, kept at least `threshold` px inside the image.
    threshold = 5
    min_w = max(table_location[0], threshold)
    max_w = min(table_location[2], img.shape[1] - threshold)
    min_h = max(table_location[1], threshold)
    max_h = min(table_location[3], img.shape[0] - threshold)
    row_line_list.append([min_w, min_h, max_w, min_h])
    row_line_list.append([min_w, max_h, max_w, max_h])
    col_line_list.append([min_w, min_h, min_w, max_h])
    col_line_list.append([max_w, min_h, max_w, max_h])

    # Build cells from the lines.
    row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
    col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
    _, table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list],
                                             [table_location], bbox_list)
    line_list = row_line_list + col_line_list

    if show:
        for r in table_cell_list:
            for c in r:
                cv2.rectangle(img_result, c[0], c[1], (0, 255, 0), 1)
        cv2.namedWindow('table_cell', cv2.WINDOW_NORMAL)
        cv2.imshow('table_cell', img_result)
        for line in col_line_list:
            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
        for line in row_line_list:
            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)
        cv2.imshow('img', cv2.resize(img_result, (768, 1024)))
        cv2.waitKey(0)
    return line_list, table_cell_list, table_location, bbox_text_dict
|