import copy
import re

import cv2
import numpy as np

from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
from botr.utils import line_iou
# from format_convert.utils import log


def _spans_overlap(a_lo, a_hi, b_lo, b_hi):
    """Return True if closed intervals [a_lo, a_hi] and [b_lo, b_hi] touch or overlap.

    True when any endpoint of one interval lies inside the other. Used for both
    horizontal (x) and vertical (y) overlap tests between bboxes.
    """
    return (a_lo <= b_lo <= a_hi
            or a_lo <= b_hi <= a_hi
            or b_lo <= a_lo <= b_hi
            or b_lo <= a_hi <= b_hi)


def _split_wide_gap_texts(img, text_list, bbox_list):
    """Drop blank texts and split texts that contain a run of 3+ spaces.

    A run of three or more spaces inside one text is treated as a gap between
    two columns. The text is cut at the gap. Its bbox is cut at an x position
    estimated from an average per-character width; CJK characters and
    full-width punctuation count as 1.5 characters.

    Args:
        img: image passed through to ``shrink_bbox``.
        text_list: texts; ``text_list[i]`` belongs to ``bbox_list[i]``.
        bbox_list: bboxes (already shrunk), indexed like the other bboxes in
            this module (``[0]`` top-left, ``[2]`` bottom-right corner).

    Returns:
        ``(new_bbox_list, bbox_text_dict)``, where ``bbox_text_dict`` maps
        ``str(bbox)`` to that bbox's text.
    """
    bbox_text_dict = {}
    new_bbox_list = []
    for i, text in enumerate(text_list):
        # Skip empty and space-only texts.
        if not text or re.sub(' ', '', text) == '':
            continue
        bbox = bbox_list[i]

        match = re.search('[ ]{3,}', text)
        if not match:
            # Ordinary bbox: keep as-is.
            bbox_text_dict[str(bbox)] = text
            new_bbox_list.append(bbox)
            continue

        # A wide space gap inside the text: it is really two columns.
        gap_start, gap_end = match.span()
        # Character position of the middle of the gap.
        blank_index = (gap_start + gap_end) / 2
        chinese_cnt = len(re.findall('[\u4e00-\u9fff()?。,!【】¥《》]', text)) * 1.5
        char_cnt = len(re.findall('[ .?!,+*&^%$#@~=:;/<>()a-zA-Z0-9{}]', text)) + chinese_cnt
        # char_cnt >= 3 here because the matched spaces are counted themselves.
        char_pixel = abs(bbox[0][0] - bbox[2][0]) / char_cnt
        split_x = bbox[0][0] + char_pixel * blank_index

        bbox1 = [bbox[0], bbox[1], [split_x, bbox[2][1]], bbox[3]]
        bbox1 = shrink_bbox(img, [bbox1])[0]
        text1 = text[:gap_start]
        bbox2 = [[split_x, bbox[0][1]], bbox[1], bbox[2], bbox[3]]
        bbox2 = shrink_bbox(img, [bbox2])[0]
        text2 = text[gap_end:]

        if re.sub(' ', '', text1) != '':
            bbox_text_dict[str(bbox1)] = text1
            new_bbox_list.append(bbox1)
        if re.sub(' ', '', text2) != '':
            bbox_text_dict[str(bbox2)] = text2
            new_bbox_list.append(bbox2)
    return new_bbox_list, bbox_text_dict


def _group_columns(row):
    """Greedily group bboxes of one row into columns by horizontal overlap.

    Each not-yet-used bbox (in ``row`` order) seeds a column that collects
    every unused bbox whose x-span overlaps the seed's x-span. Membership is
    tested with ``==``, so equal bboxes count as the same bbox.
    """
    groups = []
    used = []
    for seed in row:
        if seed in used:
            continue
        group = []
        for bbox in row:
            if bbox in used:
                continue
            if _spans_overlap(bbox[0][0], bbox[2][0], seed[0][0], seed[2][0]):
                group.append(bbox)
                used.append(bbox)
        groups.append(group)
    return groups


def _group_rows(col):
    """Split bboxes of one column (sorted by top y) into rows by vertical overlap.

    Each row is anchored on its first bbox. A later bbox joins the current row
    only if its y-span overlaps that anchor; otherwise it starts a new row.
    Every row is sorted by top y. ``col`` must be non-empty.
    """
    rows = []
    anchor = col[0]
    current = []
    for bbox in col:
        if _spans_overlap(anchor[0][1], anchor[2][1], bbox[0][1], bbox[2][1]):
            current.append(bbox)
        else:
            if current:
                current.sort(key=lambda b: b[0][1])
                rows.append(current)
            current = [bbox]
            anchor = bbox
    if current:
        current.sort(key=lambda b: b[0][1])
        rows.append(current)
    return rows


def _refine_col_lines(img, col_line_list, bbox_list, first_col_row, show=0):
    """Replace each pair of adjacent column boundary lines by one separator line.

    ``col_line_list`` holds a left and a right vertical line
    ``[x, y_top, x, y_bottom]`` per column. Each gap between adjacent columns
    is bounded by the right line of one column and the left line of the next.
    For each gap, the function scans leftwards from the next column's left
    line, one pixel column at a time, down to the right edge of the nearest
    bbox. The candidates searched for that bbox are the first row of the
    first column; ``bbox_list[0]`` is the default. Dark pixels are counted
    with ``count_black``. The separator goes to the middle of the first run of
    3 consecutive pixel columns with a count of 0. If there is no such run, it
    goes to the pixel column with the smallest 3-wide moving average of counts.

    Returns the new list of separator lines, one per gap.
    """
    black_threshold = 150
    new_lines = []
    # Pairs (i, i + 1) = (right line of column k, left line of column k + 1).
    for i in range(1, len(col_line_list) - 1, 2):
        line1 = [int(v) for v in col_line_list[i]]
        line2 = [int(v) for v in col_line_list[i + 1]]

        # Find the nearest bbox to the left first; the scan must not cross it.
        min_distance = 100000
        nearest = bbox_list[0]
        for bbox in first_col_row[0]:
            if bbox[2][0] < line2[0]:
                dis = line2[0] - bbox[2][0]
                if dis < min_distance:
                    min_distance = dis
                    nearest = bbox

        # Scan right to left, recording the dark-pixel count of every pixel column.
        blank_xs = []
        cnt_list = []
        find_flag = False
        for x in range(line2[0], int(nearest[2][0]), -1):
            black_cnt = count_black(img[line1[1]:line1[3], x:x + 1, :], threshold=black_threshold)
            cnt_list.append(black_cnt)
            if black_cnt == 0:
                blank_xs.append(x)
                # Need 3 consecutive blank pixel columns. Checked right after
                # appending so a run completed on the last iteration still counts.
                if len(blank_xs) == 3:
                    find_flag = True
                    break
            else:
                blank_xs = []

        if show:
            print('find_flag', find_flag)
        if find_flag:
            sep_x = blank_xs[1]
        else:
            # No fully blank run: take the minimum of counts averaged with
            # n neighbours on each side.
            n = 1
            min_cnt = 1000000.
            min_cnt_index = 0
            for j, cnt in enumerate(cnt_list):
                if show:
                    print('min_cnt', min_cnt)
                if j < n or j > len(cnt_list) - 1 - n:
                    continue
                # Stop early once the minimum is small enough.
                if min_cnt <= 0.001:
                    break
                avg_cnt = (cnt_list[j - 1] + cnt + cnt_list[j + 1]) / 3
                if avg_cnt < min_cnt:
                    min_cnt = avg_cnt
                    min_cnt_index = j
            # cnt_list[j] was measured at x = line2[0] - j.
            sep_x = line2[0] - min_cnt_index
        new_lines.append([sep_x, line2[1], sep_x, line2[3]])
    return new_lines


def _neighbour_gaps(bbox, col_bbox_list, min_h, max_h):
    """Return ``(gap_above, gap_below)`` for ``bbox`` within its column, or None.

    The first column in ``col_bbox_list`` that contains ``bbox`` is copied.
    Two zero-height marker boxes at ``min_h`` and ``max_h`` are added so the
    region between two row lines is bracketed. The copy is then sorted by top
    y. Each gap is the distance between the vertical centres of ``bbox`` and
    its immediate neighbour on that side. A gap is 10000 when there is no
    neighbour on that side. Returns None when no column contains ``bbox``.
    """
    current_col = None
    for col in col_bbox_list:
        if bbox in col:
            current_col = copy.deepcopy(col)
            break
    if not current_col:
        return None

    # Turn the row boundaries into boxes and add them to the column as reference points.
    top_marker = [[0, min_h], [], [0, min_h], []]
    bottom_marker = [[0, max_h], [], [0, max_h], []]
    current_col += [top_marker, bottom_marker]
    current_col.sort(key=lambda b: b[0][1])
    bbox_index = current_col.index(bbox)
    top_index = current_col.index(top_marker)
    bottom_index = current_col.index(bottom_marker)

    def center_y(b):
        return (b[0][1] + b[2][1]) / 2

    gap_above = 10000
    if top_index < bbox_index:
        gap_above = abs(center_y(current_col[bbox_index - 1]) - center_y(current_col[bbox_index]))
    gap_below = 10000
    if bottom_index > bbox_index:
        gap_below = abs(center_y(current_col[bbox_index + 1]) - center_y(current_col[bbox_index]))
    return gap_above, gap_below


def _refine_row_lines(row_line_list, bbox_list, col_bbox_list):
    """Replace each pair of adjacent row boundary lines by one separator line.

    ``row_line_list`` holds a top and a bottom horizontal line
    ``[x0, y, x1, y]`` per row. For each gap (bottom line of row k, top line of
    row k+1), the bboxes lying between the two lines (with 5 px tolerance) are
    each assigned to the upper or the lower row. A bbox goes to the side where
    its gap to the neighbouring bbox in its own column is clearly smaller:
    1.5x smaller or 8 px smaller. The separator is placed halfway between the
    lowest bottom of the upper group and the highest top of the lower group.
    When a group is empty, the original boundary line is used for that side.
    """
    new_lines = []
    # Pairs (i, i + 1) = (bottom line of row k, top line of row k + 1).
    for i in range(1, len(row_line_list) - 1, 2):
        line1 = [int(v) for v in row_line_list[i]]
        line2 = [int(v) for v in row_line_list[i + 1]]

        # Bboxes lying between the two lines.
        threshold = 5
        sub_bbox_list = [b for b in bbox_list
                         if line1[1] - threshold <= b[0][1] <= b[2][1] <= line2[1] + threshold]

        # Decide, per bbox, whether it belongs to the upper or the lower row.
        line1_bbox_list = []
        line2_bbox_list = []
        if sub_bbox_list:
            sub_bbox_list.sort(key=lambda b: b[0][1])
            min_h = sub_bbox_list[0][0][1] - 1
            max_h = sub_bbox_list[-1][2][1] + 1
            for bbox in sub_bbox_list:
                gaps = _neighbour_gaps(bbox, col_bbox_list, min_h, max_h)
                if gaps is None:
                    continue
                distance1, distance2 = gaps
                ratio = 1.5
                if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
                    # Belongs to the lower row.
                    line2_bbox_list.append(bbox)
                elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
                    # Belongs to the upper row.
                    line1_bbox_list.append(bbox)
                else:
                    print('距离不明确,需要nsp模型介入判断')

        if line1_bbox_list:
            line1_bbox_list.sort(key=lambda b: b[0][1])
            b = line1_bbox_list[-1]
            line1 = [line1[0], b[2][1], line1[2], b[2][1]]
        if line2_bbox_list:
            line2_bbox_list.sort(key=lambda b: b[0][1])
            b = line2_bbox_list[0]
            line2 = [line2[0], b[0][1], line2[2], b[0][1]]
        mid_line = [line1[0], (line1[1] + line2[1]) / 2, line1[2], (line1[3] + line2[3]) / 2]
        new_lines.append([int(v) for v in mid_line])
    return new_lines


def get_table_by_rule(img, text_list, bbox_list, table_location, show=0):
    """Infer a borderless table's row/column grid from text boxes using layout rules.

    Steps:
      1. Shrink the bboxes and split texts that contain a wide space gap.
      2. Anchor on the bbox nearest the table's top-left corner.
      3. Derive candidate column lines from the first row and candidate row
         lines from the first column.
      4. Move each column separator onto blank pixels.
      5. Place each row separator based on per-column spacing.
      6. Add the table outline and build cells with ``get_table_bbox_list``.

    Args:
        img: image indexed as ``img[y0:y1, x0:x1, :]``, so 3-D with height at
            ``shape[0]`` and width at ``shape[1]``.
        text_list: texts; ``text_list[i]`` belongs to ``bbox_list[i]``.
        bbox_list: text bboxes. ``bbox[0]`` is used as the top-left corner and
            ``bbox[2]`` as the bottom-right corner, both ``[x, y]``. Presumably
            these are four corner points — TODO confirm against the OCR caller.
        table_location: ``[x0, y0, x1, y1]`` of the table region.
        show: if truthy, print debug info and open blocking cv2 windows.

    Returns:
        ``(line_list, table_cell_list, table_location, bbox_text_dict)``:
        ``line_list`` holds the row lines followed by the column lines, each
        ``[x0, y0, x1, y1]`` ints. ``table_cell_list`` comes from
        ``get_table_bbox_list``. ``bbox_text_dict`` maps ``str(bbox)`` to text.
        Returns ``([], [], [], {})`` when there are no usable bboxes, or when
        no interior row or column separator was found.
    """
    if show:
        print('get_table_by_rule bbox_list', bbox_list)

    if not bbox_list:
        return [], [], [], {}

    if show:
        img_show = copy.deepcopy(img)
        img_result = copy.deepcopy(img)

    # Tighten the bboxes around their content.
    bbox_list = shrink_bbox(img, bbox_list)

    # Drop blank texts, split wide-gap texts, build the str(bbox) -> text map.
    bbox_list, bbox_text_dict = _split_wide_gap_texts(img, text_list, bbox_list)

    if show:
        print('bbox_text_dict', bbox_text_dict)
        for bbox in bbox_list:
            cv2.rectangle(img_show, (int(bbox[0][0]), int(bbox[0][1])),
                          (int(bbox[2][0]), int(bbox[2][1])), (255, 0, 0), 2)
        cv2.namedWindow('bbox_list', cv2.WINDOW_NORMAL)
        cv2.imshow('bbox_list', img_show)
        cv2.waitKey(0)

    if not bbox_list:
        return [], [], [], {}

    # Anchor bbox: top-left corner closest (L1 distance) to the table's top-left corner.
    table_x0, table_y0 = table_location[0], table_location[1]
    first_bbox = min(bbox_list, key=lambda b: abs(b[0][0] - table_x0) + abs(b[0][1] - table_y0))

    # First row: bboxes overlapping the anchor vertically, plus any lying entirely above it.
    bbox_list.sort(key=lambda b: (b[0][1], b[0][0]))
    first_row = [b for b in bbox_list
                 if _spans_overlap(first_bbox[0][1], first_bbox[2][1], b[0][1], b[2][1])
                 or b[2][1] <= first_bbox[0][1]]
    first_row.sort(key=lambda b: (b[0][0], b[0][1]))
    first_row_col = _group_columns(first_row)

    # First column: bboxes overlapping the anchor horizontally, plus any lying entirely left of it.
    # NOTE: bbox_list stays sorted by (x, y) from here on; later steps rely on that order.
    bbox_list.sort(key=lambda b: (b[0][0], b[0][1]))
    first_col = [b for b in bbox_list
                 if _spans_overlap(first_bbox[0][0], first_bbox[2][0], b[0][0], b[2][0])
                 or b[2][0] <= first_bbox[0][0]]
    first_col.sort(key=lambda b: (b[0][1], b[0][0]))
    first_col_row = _group_rows(first_col)

    if show:
        print('len(first_row)', len(first_row))
        print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
        print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
        print('len(first_col)', len(first_col))
        print('len(first_row_col)', len(first_row_col))
        print('len(first_col_row)', len(first_col_row))

    # Candidate column lines: left and right edge of every first-row column.
    col_line_list = []
    for col in first_row_col:
        min_w, max_w = 1000000, 0
        for bbox in col:
            min_w = min(min_w, bbox[0][0])
            max_w = max(max_w, bbox[2][0])
        col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
        col_line_list.append([max_w, table_location[1], max_w, table_location[3]])

    # Candidate row lines: top and bottom edge of every first-column row.
    row_line_list = []
    for row in first_col_row:
        min_h, max_h = 1000000, 0
        for bbox in row:
            min_h = min(min_h, bbox[0][1])
            max_h = max(max_h, bbox[2][1])
        row_line_list.append([table_location[0], min_h, table_location[2], min_h])
        row_line_list.append([table_location[0], max_h, table_location[2], max_h])

    if show:
        print('len(col_line_list)', len(col_line_list))
        print('col_line_list', col_line_list)
        print('len(row_line_list)', len(row_line_list))

    # One column separator per gap, moved off dark pixels.
    col_line_list = _refine_col_lines(img, col_line_list, bbox_list, first_col_row, show)

    if show:
        print('len(col_line_list)', len(col_line_list))
        for col in col_line_list:
            col = [int(x) for x in col]
            cv2.line(img_show, tuple(col[:2]), tuple(col[2:4]), (0, 255, 0), 2)
        cv2.imshow('col_line_list', img_show)
        cv2.waitKey(0)

    # Assign bboxes to the columns delimited by the separators. The image's
    # right edge (x = width = shape[1]) closes the last column.
    img_width = img.shape[1]
    col_bbox_list = []
    last_x = 0
    for line in col_line_list + [[img_width, 0, img_width, 0]]:
        col = [b for b in bbox_list
               if line_iou([[last_x, 0], [line[0], 0]], [[b[0][0], 0], [b[2][0], 0]], axis=0) >= 0.6]
        col.sort(key=lambda b: b[0][1])
        col_bbox_list.append(col)
        last_x = line[0]

    # One row separator per gap, placed using per-column spacing.
    row_line_list = _refine_row_lines(row_line_list, bbox_list, col_bbox_list)

    if show:
        print('len(row_line_list)', len(row_line_list))
        print('len(col_line_list)', len(col_line_list))

    # A single row or a single column (no interior separator): give up.
    if len(row_line_list) < 1 or len(col_line_list) < 1:
        return [], [], [], {}

    # Add the table outline, kept at least `threshold` px inside the image.
    threshold = 5
    min_w = max(table_location[0], 0 + threshold)
    max_w = min(table_location[2], img.shape[1] - threshold)
    min_h = max(table_location[1], 0 + threshold)
    max_h = min(table_location[3], img.shape[0] - threshold)
    row_line_list.append([min_w, min_h, max_w, min_h])
    row_line_list.append([min_w, max_h, max_w, max_h])
    col_line_list.append([min_w, min_h, min_w, max_h])
    col_line_list.append([max_w, min_h, max_w, max_h])

    # Build the cell grid from the lines.
    row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
    col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
    _, table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list],
                                             [table_location], bbox_list)

    line_list = row_line_list + col_line_list

    if show:
        for r in table_cell_list:
            for c in r:
                cv2.rectangle(img_result, c[0], c[1], (0, 255, 0), 1)
        cv2.namedWindow('table_cell', cv2.WINDOW_NORMAL)
        cv2.imshow('table_cell', img_result)
        for line in col_line_list:
            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
        for line in row_line_list:
            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)
        cv2.imshow('img', cv2.resize(img_result, (768, 1024)))
        cv2.waitKey(0)

    return line_list, table_cell_list, table_location, bbox_text_dict