import copy
import cv2
import numpy as np
from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
from botr.utils import line_iou
# from format_convert.utils import log


def _spans_touch(a0, a1, b0, b1):
    """Return True if the closed spans [a0, a1] and [b0, b1] share at least one point.

    Written as four explicit endpoint-containment tests (instead of
    ``max(a0, b0) <= min(a1, b1)``) so degenerate spans with ``a0 > a1`` are
    classified exactly as the original inline checks classified them.
    """
    return (a0 <= b0 <= a1
            or a0 <= b1 <= a1
            or b0 <= a0 <= b1
            or b0 <= a1 <= b1)


def _extent(boxes, axis):
    """Return (smallest start, largest end) of ``boxes`` along ``axis`` (0 = x, 1 = y).

    Starts from the sentinels (1000000, 0), which come back unchanged for an
    empty ``boxes``.
    """
    lo, hi = 1000000, 0
    for bbox in boxes:
        if bbox[0][axis] < lo:
            lo = bbox[0][axis]
        if bbox[2][axis] > hi:
            hi = bbox[2][axis]
    return lo, hi


def _closest_to_corner(bbox_list, table_location):
    """Return the first box whose top-left point has the smallest L1 distance to the table's top-left corner."""
    corner_x, corner_y = table_location[0], table_location[1]
    best_bbox = bbox_list[0]
    best_distance = 100000000000
    for bbox in bbox_list:
        distance = abs(bbox[0][0] - corner_x) + abs(bbox[0][1] - corner_y)
        if distance < best_distance:
            best_distance = distance
            best_bbox = bbox
    return best_bbox


def _group_into_columns(boxes):
    """Greedily cluster ``boxes`` (scanned left to right) into groups whose x-spans touch the group's anchor."""
    boxes = sorted(boxes, key=lambda x: (x[0][0], x[0][1]))
    columns = []
    used = []
    for anchor in boxes:
        if anchor in used:
            continue
        column = []
        for other in boxes:
            if other in used:
                continue
            if _spans_touch(other[0][0], other[2][0], anchor[0][0], anchor[2][0]):
                column.append(other)
                used.append(other)
        columns.append(column)
    return columns


def _group_into_rows(boxes):
    """Split ``boxes`` (scanned top to bottom) into rows by y-overlap with each row's reference box."""
    if not boxes:
        return []
    boxes = sorted(boxes, key=lambda x: (x[0][1], x[0][0]))
    rows = []
    anchor = boxes[0]
    row = []
    for bbox in boxes:
        if _spans_touch(anchor[0][1], anchor[2][1], bbox[0][1], bbox[2][1]):
            row.append(bbox)
            continue
        if row:
            row.sort(key=lambda x: x[0][1])
            rows.append(row)
        # Start a new row; the reference box only advances here, so every row
        # is anchored on its first (topmost) box.
        # NOTE(review): reconstructed from whitespace-collapsed source; confirm
        # the anchor is not meant to advance on every box.
        row = [bbox]
        anchor = bbox
    if row:
        row.sort(key=lambda x: x[0][1])
        rows.append(row)
    return rows


def _find_col_separator(img, line1, line2, bbox_list, black_threshold=150):
    """Place one vertical separator between two adjacent columns.

    ``line1`` is the right edge of the left column and ``line2`` the left edge of
    the right column, both int ``[x, y_top, x, y_bottom]``. Pixel columns are
    scanned leftwards from ``line2[0]``, stopping just before the right edge of
    the nearest box lying entirely left of ``line2`` so the separator cannot cut
    through that box. The separator goes to the middle of the first run of
    three consecutive pixel columns where ``count_black`` reports 0. Otherwise
    it goes to the position with the lowest 3-wide average count.

    Returns ``[x, line2[1], x, line2[3]]``.
    """
    # Nearest box whose right edge is strictly left of line2 (defaults to the first box).
    min_distance = 100000
    nearest = bbox_list[0]
    for bbox in bbox_list:
        if bbox[2][0] < line2[0]:
            gap = line2[0] - bbox[2][0]
            if gap < min_distance:
                min_distance = gap
                nearest = bbox

    blank_run = []   # x positions of the current run of blank pixel columns
    cnt_list = []    # count_black result per scanned x, in scan order
    found = False
    for x in range(line2[0], int(nearest[2][0]), -1):
        black_cnt = count_black(img[line1[1]:line1[3], x:x + 1, :], threshold=black_threshold)
        cnt_list.append(black_cnt)
        if black_cnt == 0:
            blank_run.append(x)
            # Check right after appending so a run that ends on the last
            # scanned column is still detected.
            if len(blank_run) == 3:
                found = True
                break
        else:
            blank_run = []

    if found:
        x = blank_run[1]
        return [x, line2[1], x, line2[3]]

    # No blank run: pick the scan offset whose 3-wide window has the smallest
    # average count (edges excluded), stopping early once it is ~0.
    best_avg = 1000000.
    best_offset = 0
    for offset in range(1, len(cnt_list) - 1):
        if best_avg <= 0.001:
            break
        avg = (cnt_list[offset - 1] + cnt_list[offset] + cnt_list[offset + 1]) / 3
        if avg < best_avg:
            best_avg = avg
            best_offset = offset
    x = line2[0] - best_offset
    return [x, line2[1], x, line2[3]]


def _assign_bboxes_to_cols(img, col_line_list, bbox_list, iou_threshold=0.6):
    """Bucket boxes into the vertical strips delimited by consecutive column separators.

    Strips run from x=0 through each separator to the right image border. A
    box joins a strip when ``line_iou`` of its x-span with the strip is
    >= ``iou_threshold``. Each bucket is sorted by top y.
    """
    # img is indexed img[y, x], so the right border (an x value) is img.shape[1].
    boundaries = col_line_list + [[img.shape[1], 0, img.shape[1], 0]]
    col_bbox_list = []
    left = 0
    for line in boundaries:
        col = [bbox for bbox in bbox_list
               if line_iou([[left, 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0) >= iou_threshold]
        col.sort(key=lambda x: x[0][1])
        col_bbox_list.append(col)
        left = line[0]
    return col_bbox_list


def _neighbour_gaps(bbox, col_bbox_list, min_h, max_h):
    """Return (gap_above, gap_below) between ``bbox``'s vertical centre and its column neighbours.

    Neighbours come from the first column bucket containing ``bbox``. Two
    zero-height marker boxes at ``min_h`` and ``max_h`` are added so a box at the
    edge of the band measures its gap to the band boundary. A side with no
    neighbour gets 10000. Returns None when ``bbox`` is in no bucket.
    """
    column = next((col for col in col_bbox_list if bbox in col), None)
    if column is None:
        return None

    top_marker = [[0, min_h], [], [0, min_h], []]
    bottom_marker = [[0, max_h], [], [0, max_h], []]
    stack = column + [top_marker, bottom_marker]
    stack.sort(key=lambda x: x[0][1])
    i = stack.index(bbox)
    top_i = stack.index(top_marker)
    bottom_i = stack.index(bottom_marker)

    def centre(box):
        return (box[0][1] + box[2][1]) / 2

    gap_above = 10000
    if i > top_i:
        gap_above = abs(centre(stack[i - 1]) - centre(stack[i]))
    gap_below = 10000
    if bottom_i > i:
        gap_below = abs(centre(stack[i + 1]) - centre(stack[i]))
    return gap_above, gap_below


def _refine_row_lines(row_line_list, bbox_list, col_bbox_list, threshold=5, ratio=1.5, margin=8):
    """Collapse each (row bottom, next row top) line pair into one horizontal separator.

    Boxes lying between the two lines (with ``threshold`` px slack) are
    assigned to the upper or lower row by comparing their gaps to column
    neighbours. The upper line is pulled down to the lowest box of the upper
    row and the lower line up to the highest box of the lower row, and the
    separator is their midpoint. Ambiguous boxes are reported via ``print`` and
    ignored.
    """
    separators = []
    for i in range(1, len(row_line_list) - 1, 2):
        upper = [int(v) for v in row_line_list[i]]      # bottom line of the previous row
        lower = [int(v) for v in row_line_list[i + 1]]  # top line of the next row

        band = [bbox for bbox in bbox_list
                if upper[1] - threshold <= bbox[0][1] <= bbox[2][1] <= lower[1] + threshold]
        above, below = [], []
        if band:
            band.sort(key=lambda x: x[0][1])
            min_h = band[0][0][1] - 1
            max_h = band[-1][2][1] + 1
            for bbox in band:
                gaps = _neighbour_gaps(bbox, col_bbox_list, min_h, max_h)
                if gaps is None:
                    continue
                gap_above, gap_below = gaps
                # Clearly farther from the box above -> belongs to the lower row.
                if gap_above >= gap_below * ratio or gap_above >= gap_below + margin:
                    below.append(bbox)
                # Clearly farther from the box below -> belongs to the upper row.
                elif gap_below >= gap_above * ratio or gap_below >= gap_above + margin:
                    above.append(bbox)
                else:
                    print('距离不明确,需要nsp模型介入判断')

        if above:
            above.sort(key=lambda x: x[0][1])
            y = above[-1][2][1]
            upper = [upper[0], y, upper[2], y]
        if below:
            below.sort(key=lambda x: x[0][1])
            y = below[0][0][1]
            lower = [lower[0], y, lower[2], y]
        separator = [upper[0], (upper[1] + lower[1]) / 2, upper[2], (upper[3] + lower[3]) / 2]
        separators.append([int(v) for v in separator])
    return separators


def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
    """Infer table grid lines from text boxes alone (rule-based).

    The box nearest the table's top-left corner anchors a "first row" and a
    "first column". First-row boxes are clustered into columns and first-column
    boxes into rows, and the cluster extents give candidate lines. Column
    separators are then moved into a pixel gap where ``count_black`` reports
    zero, and row separators are placed between the boxes judged to belong to
    the row above vs. the row below. The table outline is added, and cells are
    obtained from ``get_table_bbox_list``.

    Args:
        img: Image array indexed as ``img[y, x, channel]``. Drawn on in place
            when ``is_test`` is truthy.
        text_list: Texts for ``bbox_list``. Not read by this implementation and
            kept for interface compatibility.
        bbox_list: Text boxes. This function directly reads only points 0 and 2
            of each box, as (left, top) and (right, bottom).
        table_location: Table region as ``[x_min, y_min, x_max, y_max]``.
        is_test: If truthy, draw the result and block on ``cv2.waitKey(0)``.

    Returns:
        ``(line_list, table_cell_list, table_location)``:
        - ``line_list`` is the row lines followed by the column lines, each an
          int ``[x1, y1, x2, y2]``.
        - ``table_cell_list`` is the second value returned by
          ``get_table_bbox_list``.
        - ``table_location`` is the argument, returned unchanged.
        Returns ``([], [], [])`` when ``bbox_list`` is empty.
    """
    if not bbox_list:
        return [], [], []

    # Shrink the boxes before any geometry is measured.
    bbox_list = shrink_bbox(img, bbox_list)

    # Anchor box: nearest to the table's top-left corner.
    first_bbox = _closest_to_corner(bbox_list, table_location)

    # First row: boxes whose y-span touches the anchor's, or that lie entirely above it.
    bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
    first_row = [bbox for bbox in bbox_list
                 if _spans_touch(first_bbox[0][1], first_bbox[2][1], bbox[0][1], bbox[2][1])
                 or bbox[2][1] <= first_bbox[0][1]]
    first_row_cols = _group_into_columns(first_row)

    # First column: boxes whose x-span touches the anchor's, or that lie entirely left of it.
    # bbox_list stays sorted by (left, top) from here on; later tie-breaking relies on it.
    bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
    first_col = [bbox for bbox in bbox_list
                 if _spans_touch(first_bbox[0][0], first_bbox[2][0], bbox[0][0], bbox[2][0])
                 or bbox[2][0] <= first_bbox[0][0]]
    first_col_rows = _group_into_rows(first_col)

    # Candidate column lines: left and right edge of every first-row column.
    col_line_list = []
    for col in first_row_cols:
        min_w, max_w = _extent(col, 0)
        col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
        col_line_list.append([max_w, table_location[1], max_w, table_location[3]])

    # Candidate row lines: top and bottom edge of every first-column row.
    row_line_list = []
    for row in first_col_rows:
        min_h, max_h = _extent(row, 1)
        row_line_list.append([table_location[0], min_h, table_location[2], min_h])
        row_line_list.append([table_location[0], max_h, table_location[2], max_h])

    # One separator per adjacent column pair: (right edge of col k, left edge of col k+1).
    separators = []
    for i in range(1, len(col_line_list) - 1, 2):
        right_of_prev = [int(v) for v in col_line_list[i]]
        left_of_next = [int(v) for v in col_line_list[i + 1]]
        separators.append(_find_col_separator(img, right_of_prev, left_of_next, bbox_list))
    col_line_list = separators

    # Bucket boxes by column, then resolve each row boundary.
    col_bbox_list = _assign_bboxes_to_cols(img, col_line_list, bbox_list)
    row_line_list = _refine_row_lines(row_line_list, bbox_list, col_bbox_list)

    # Table outline, clamped to stay `threshold` px inside the image.
    threshold = 5
    min_w = max(table_location[0], 0 + threshold)
    max_w = min(table_location[2], img.shape[1] - threshold)
    min_h = max(table_location[1], 0 + threshold)
    max_h = min(table_location[3], img.shape[0] - threshold)
    row_line_list.append([min_w, min_h, max_w, min_h])
    row_line_list.append([min_w, max_h, max_w, max_h])
    col_line_list.append([min_w, min_h, min_w, max_h])
    col_line_list.append([max_w, min_h, max_w, max_h])

    # Cells from the lines.
    row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
    col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
    table_bbox_list, table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list],
                                                           [table_location], bbox_list)

    line_list = row_line_list + col_line_list

    if is_test:
        for r in table_cell_list:
            for c in r:
                cv2.rectangle(img, c[0], c[1], (0, 255, 0), 1)
        cv2.imshow('table_cell', img)
        for line in col_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
        for line in row_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)
        cv2.imshow('img', cv2.resize(img, (768, 1024)))
        cv2.waitKey(0)

    return line_list, table_cell_list, table_location