import base64 import copy import json import logging import math import random import re import traceback from glob import glob import cv2 from sklearn.cluster import AffinityPropagation, DBSCAN # from tensorflow_version.table_head_predict import predict from botr.utils import request_post, line_iou, pil_resize, get_best_predict_size2, line_overlap import jieba import numpy as np from matplotlib import pyplot as plt def _plot(_line_list, mode=1): for _line in _line_list: if mode == 1: x0, y0, x1, y1 = _line.__dict__.get("bbox") elif mode == 2: x0, y0, x1, y1 = _line elif mode == 3: x0, y0 = _line[0] x1, y1 = _line[1] plt.plot([x0, x1], [y0, y1]) plt.show() return def get_table_by_rule2(img, text_list, bbox_list, table_location, is_test=0): # 处理bbox,缩小框 bbox_list = shrink_bbox(img, bbox_list) # 创建对应dict bbox_text_dict = {} for i in range(len(text_list)): bbox_text_dict[str(bbox_list[i])] = text_list[i] # 获取全局的按行排列bbox row_list = get_table_rows(bbox_list, bbox_text_dict) if len(row_list) == 0: return [], [], [] # 删除只有一个bbox的第一行和最后一行 if len(row_list[0]) == 1: table_location = [table_location[0], row_list[0][0][2][1], table_location[2], table_location[3]] row_list = row_list[1:] if len(row_list[-1]) == 1: table_location = [table_location[0], table_location[1], table_location[2], row_list[-1][0][0][1]] row_list = row_list[:-1] # 获取表格区域,以及区域里的按行排列bbox table_location_list = [[[int(table_location[0]), int(table_location[1])], [int(table_location[2]), int(table_location[3])]]] area_row_list = [row_list] area_row_list = merge_row_bbox_list(area_row_list) # 获取全局的按列排列bbox area_col_list = get_table_cols(bbox_list, table_location_list) # 获取行线、列线 area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) if is_test: _plot(area_row_lines[0] + area_col_lines[0], mode=3) # 判断列线合法 area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict) # 判断行线合法 area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict) if is_test: _plot(area_row_lines[0] + area_col_lines[0], mode=3) # 由线得到按行列排列的bbox area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list) if is_test: for a in area_table_cell_list: for r in a: for c in r: cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1) # for b in c: # cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1) cv2.imshow('table_cell', img) cv2.waitKey(0) # 展示 if is_test: show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list) if not area_row_lines or not area_col_lines: return [], [], [] line_list = [[x[0][0], x[0][1], x[1][0], x[1][1]] for x in area_row_lines[0] + area_col_lines[0]] cell_list = area_table_cell_list[0] return line_list, cell_list, table_location def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=1): # 处理bbox,缩小框 bbox_list = shrink_bbox(img, bbox_list) # 创建对应dict bbox_text_dict = {} for i in range(len(text_list)): bbox_text_dict[str(bbox_list[i])] = text_list[i] # 根据bbox_list,计算与table_location左上角坐标距离,锁定第一个bbox table_left_up_point = [table_location[0], table_location[1]] min_distance = 100000000000 first_bbox = bbox_list[0] for bbox in bbox_list: distance = abs(bbox[0][0] - table_left_up_point[0]) + abs(bbox[0][1] - table_left_up_point[1]) if distance < min_distance: min_distance = distance first_bbox = bbox # 对first_bbox预处理 # 分割 new_bbox_list, bbox_text_dict = split_bbox(img, first_bbox, bbox_text_dict) if new_bbox_list: if first_bbox in bbox_list: bbox_list.remove(first_bbox) bbox_list += new_bbox_list new_bbox_list.sort(key=lambda x: (x[0][0])) first_bbox = new_bbox_list[0] # 根据第一个bbox,得到第一行 first_row = [] bbox_list.sort(key=lambda x: (x[0][1], x[0][0])) for bbox in bbox_list: # h有交集 if first_bbox[0][1] <= bbox[0][1] <= first_bbox[2][1] \ or first_bbox[0][1] <= bbox[2][1] <= first_bbox[2][1] \ or bbox[0][1] <= first_bbox[0][1] <= bbox[2][1] \ or bbox[0][1] <= first_bbox[2][1] <= bbox[2][1]: first_row.append(bbox) # h小于first_box elif bbox[2][1] <= first_bbox[0][1]: first_row.append(bbox) # 对第一行分列 first_row.sort(key=lambda x: (x[0][0], x[0][1])) first_row_col = [] used_bbox = [] for bbox in first_row: if bbox in used_bbox: continue temp_col = [] for bbox1 in first_row: if bbox1 in used_bbox: continue if bbox1[0][0] <= bbox[0][0] <= bbox1[2][0] \ or bbox1[0][0] <= bbox[2][0] <= bbox1[2][0] \ or bbox[0][0] <= bbox1[0][0] <= bbox[2][0] \ or bbox[0][0] <= bbox1[2][0] <= bbox[2][0]: temp_col.append(bbox1) used_bbox.append(bbox1) first_row_col.append(temp_col) # 根据第一个bbox,得到第一列 first_col = [] bbox_list.sort(key=lambda x: (x[0][0], x[0][1])) for bbox in bbox_list: # w有交集 if first_bbox[0][0] <= bbox[0][0] <= first_bbox[2][0] \ or first_bbox[0][0] <= bbox[2][0] <= first_bbox[2][0] \ or bbox[0][0] <= first_bbox[0][0] <= bbox[2][0] \ or bbox[0][0] <= first_bbox[2][0] <= bbox[2][0]: first_col.append(bbox) # w小于first_box elif bbox[2][0] <= first_bbox[0][0]: first_col.append(bbox) # 对第一列分行 first_col.sort(key=lambda x: (x[0][1], x[0][0])) first_col_row = [] current_bbox = first_col[0] temp_row = [] for bbox in first_col: if current_bbox[0][1] <= bbox[0][1] <= current_bbox[2][1] \ or current_bbox[0][1] <= bbox[2][1] <= current_bbox[2][1] \ or bbox[0][1] <= current_bbox[0][1] <= bbox[2][1] \ or bbox[0][1] <= current_bbox[2][1] <= bbox[2][1]: temp_row.append(bbox) else: if temp_row: temp_row.sort(key=lambda x: x[0][1]) first_col_row.append(temp_row) temp_row = [bbox] current_bbox = bbox if temp_row: temp_row.sort(key=lambda x: x[0][1]) first_col_row.append(temp_row) print('len(first_row)', len(first_row)) print('first_row', [bbox_text_dict.get(str(x)) for x in first_row]) print('first_col', [bbox_text_dict.get(str(x)) for x in first_col]) print('len(first_col)', len(first_col)) print('len(first_row_col)', len(first_row_col)) print('len(first_col_row)', len(first_col_row)) # 划线 列 col_line_list = [] for col in first_row_col: # 画2条线,根据左右bbox min_w, max_w = 1000000, 0 print('col', [bbox_text_dict.get(str(x)) for x in col]) for bbox in col: if bbox[0][0] < min_w: min_w = bbox[0][0] if bbox[2][0] > max_w: max_w = bbox[2][0] col_line_list.append([min_w, table_location[1], min_w, table_location[3]]) col_line_list.append([max_w, table_location[1], max_w, table_location[3]]) # 划线 行 row_line_list = [] last_max_h = None for row in first_col_row: # 画3条线,根据上下bbox min_h, max_h = 1000000, 0 for bbox in row: if bbox[0][1] < min_h: min_h = bbox[0][1] if bbox[2][1] > max_h: max_h = bbox[2][1] row_line_list.append([table_location[0], min_h, table_location[2], min_h]) row_line_list.append([table_location[0], max_h, table_location[2], max_h]) # if last_max_h: # row_line_list.append([table_location[0], int((min_h+last_max_h)/2), table_location[2], int((min_h+last_max_h)/2)]) last_max_h = max_h print('len(col_line_list)', len(col_line_list)) print('col_line_list', col_line_list) print('len(row_line_list)', len(row_line_list)) # 判断列线有没有压在黑色像素上,若有则移动 temp_list = [] for i in range(1, len(col_line_list), 2): # 前一列右边线 line1 = col_line_list[i] line1 = [int(x) for x in line1] # 后一列左边线 if i+1 >= len(col_line_list): break line2 = col_line_list[i+1] line2 = [int(x) for x in line2] max_black_cnt = 10 black_threshold = 150 black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold) print('col black_cnt2', black_cnt2) if black_cnt2 <= max_black_cnt: temp_list.append(line2) else: black_cnt1 = count_black(img[line1[1]:line1[3], line1[0]:line1[2]+1, :], threshold=black_threshold) print('col black_cnt1', black_cnt1) if black_cnt1 <= max_black_cnt: temp_list.append(line1) else: # 两条线都不符合,从右向左移寻找 for j in range(line2[0], line1[0], -1): black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold) print('col black_cnt', black_cnt) if black_cnt <= max_black_cnt: temp_list.append([j, line2[1], j, line2[3]]) break col_line_list = temp_list # 根据列的划线对bbox分列 last_line = [0, 0, 0, 0] col_bbox_list = [] # used_bbox_list = [] for line in col_line_list + [[img.shape[0], 0, img.shape[0], 0]]: col = [] for bbox in bbox_list: # if bbox in used_bbox_list: # continue # print('last_line, line, bbox', last_line, line, bbox) iou = line_iou([[last_line[0], 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0) if iou >= 0.6: col.append(bbox) # used_bbox_list.append(bbox) col.sort(key=lambda x: x[0][1]) col_bbox_list.append(col) last_line = line # 判断行线 temp_list = [] for i in range(1, len(row_line_list), 2): # 前一行下边线 line1 = row_line_list[i] line1 = [int(x) for x in line1] # 后一行上边线 if i+1 >= len(row_line_list): break line2 = row_line_list[i+1] line2 = [int(x) for x in line2] # 判断行线之间的bbox分别属于哪一行 sub_bbox_list = [] threshold = 5 for bbox in bbox_list: if line1[1] - threshold <= bbox[0][1] <= bbox[2][1] <= line2[1]+threshold: sub_bbox_list.append(bbox) # 根据行的h和分列判断bbox属于上一行还是下一行 line1_bbox_list = [] line2_bbox_list = [] if sub_bbox_list: sub_bbox_list.sort(key=lambda x: x[0][1]) min_h = sub_bbox_list[0][0][1] - 1 max_h = sub_bbox_list[-1][2][1] + 1 for bbox in sub_bbox_list: # 找到属于哪一列 current_col = None for col in col_bbox_list: if bbox in col: current_col = copy.deepcopy(col) break if current_col: # 行做成bbox加入列作为基准 line1_bbox = [[0, min_h], [], [0, min_h], []] line2_bbox = [[0, max_h], [], [0, max_h], []] current_col += [line1_bbox, line2_bbox] current_col.sort(key=lambda x: x[0][1]) bbox_index = current_col.index(bbox) line1_bbox_index = current_col.index(line1_bbox) line2_bbox_index = current_col.index(line2_bbox) print('current_col', [bbox_text_dict.get(str(x)) for x in current_col]) print('line1_bbox_index, bbox_index, line2_bbox_index', line1_bbox_index, bbox_index, line2_bbox_index) # 计算距离 distance1 = 10000 for index in range(line1_bbox_index, bbox_index): h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2 h2 = (current_col[index+1][0][1] + current_col[index+1][2][1]) / 2 # print(bbox_text_dict.get()) distance1 = abs(h1 - h2) distance2 = 10000 for index in range(line2_bbox_index, bbox_index, -1): h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2 h2 = (current_col[index-1][0][1] + current_col[index-1][2][1]) / 2 distance2 = abs(h1 - h2) print(bbox_text_dict.get(str(bbox)), distance1, distance2) ratio = 1.5 # 属于下一行 if distance1 >= distance2 * ratio or distance1 >= distance2 + 8: line2_bbox_list.append(bbox) # 属于上一行 elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8: line1_bbox_list.append(bbox) else: print('距离不明确,需要nsp模型介入判断') if line1_bbox_list: print('line1_bbox_list', [bbox_text_dict.get(str(x)) for x in line1_bbox_list]) line1_bbox_list.sort(key=lambda x: x[0][1]) b = line1_bbox_list[-1] line1 = [line1[0], b[2][1], line1[2], b[2][1]] if line2_bbox_list: print('line2_bbox_list', [bbox_text_dict.get(str(x)) for x in line2_bbox_list]) line2_bbox_list.sort(key=lambda x: x[0][1]) b = line2_bbox_list[0] line2 = [line2[0], b[0][1], line2[2], b[0][1]] _line = [line1[0], (line1[1]+line2[1])/2, line1[2], (line1[3]+line2[3])/2] _line = [int(x) for x in _line] temp_list.append(_line) row_line_list = temp_list # 加上表格轮廓线 row_line_list.append([table_location[0], table_location[1], table_location[2], table_location[1]]) row_line_list.append([table_location[0], table_location[3], table_location[2], table_location[3]]) col_line_list.append([table_location[0], table_location[1], table_location[0], table_location[3]]) col_line_list.append([table_location[2], table_location[1], table_location[2], table_location[3]]) # 由线得到按行列排列的bbox area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, [row_line_list], [col_line_list], [table_location], bbox_list) # show if is_test: for line in col_line_list: cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2) for line in row_line_list: cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2) cv2.namedWindow('img', cv2.WINDOW_NORMAL) cv2.imshow('img', cv2.resize(img, (768, 1024))) cv2.waitKey(0) return [], [], [] def split_bbox_by_kmeans(img, bbox, bbox_text_dict): sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :] # 从左至右扫描 def get_table(): # 1. 一个单元格多行合并需解决 √ # 2. 一行多个单字合并 1007.jpg √ # 3. ocr识别错误bbox剔除 # 4. 上下表格合并 距离近,列数一样,或只少了第一列 1005.jpg 1014.jpg 1033.jpg √ # 5. 相近行列线合并 1020.jpg 1025.jpg 1054.jpg 1068.jpg # 6. 行线在合并bbox中间,需向上或向下移动 105.jpg 1054.jpg 1020.jpg # 7. 贴着左边框的长bbox也当做标题分开表格 1047.jpg 1059.jpg √ # 8. 判断非规整表格,单个单元格多个bbox,排除上下连接的bbox 105.jpg # 9. 判断非规整表格,ocr识别漏,黑色像素多 1050.jpg √ # 10. 第一列序号ocr识别漏 1051.jpg # 11. 用其他列作为分行标准,作为辅助,挑平均间隔最大的,行数也够的列 1085.jpg # 12. 判断表格 两个bbox靠的太近的不能作为开始行 1106.jpg √ # 13. 列中所有行间隔都很小,聚类距离统一值 1098.jpg √ # 14. 漏列(需剔除表格中非表格部分) 1059.jpg # 15. 漏行 1064.jpg 1065.jpg 1067.jpg 1085.jpg 1097.jpg 1101.jpg √ # 16. 表格分割错误 1045.jpg 1051.jpg 1078.jpg 1079.jpg √ # 17. 分列时,第一行的表头选定 1051.jpg 1106.jpg 1129.jpg # 18. 分割同一行中多个列 1093.jpg 1095.jpg 110.jpg # 19. 表格漏了 1119.jpg 1141.jpg # 20. 非规整表格判断错误,黑色像素 1122.jpg 1121.jpg √ # 21. 分列错误 1125.jpg 1158.jpg 1020.jpg √ # 22. 分行分列错误(需在第一列排除过长bbox) 1131.jpg 1132.jpg √ # 1135.jpg 1136.jpg 1147.jpg # 23. 表格范围外,与单元格内的文字上下相连 1134.jpg 1142.jpg # 24. 第一列空单元格太多可列为非规整 # 25. 竖线跨越多个bbox的较中心位置,考虑剔除 # 26. 竖线跨越bbox,考虑竖线缩短,将跨越的那一截去掉 1020.jpg # 27. 竖线插在一列中间,需调整其向右找到空白位置 1023.jpg # label_path = glob('../data/borderless_tables/*_label.jpg') # temp_label_path = [] # label_row_dict = {} # for p in label_path: # img = cv2.imread(p) # row_img, col_img = get_lines_from_img(img) # label_row_list, is_standard = get_bbox_by_img(row_img, col_img) # label_row_dict[p] = label_row_list # if is_standard: # temp_label_path.append(p) # label_path = temp_label_path # print('len(label_path)', len(label_path)) # for p in label_path: # print(p) with open('standard_table.txt', 'r') as f: label_path_list = f.readlines() # paths = glob('../data/borderless_tables/1.jpg') # merge_row # paths = glob('../data/borderless_tables/5.jpg') # title # paths = glob('../data/borderless_tables/26.jpg') # merge_col paths = glob('../data/borderless_tables/59.jpg') # split bbox paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/62.jpg') # paths = glob('../data/borderless_tables/57.jpg') paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/3.jpg') # not standard table # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png') # label_path_list.append(r'C:\Users\Administrator\Desktop\test_pdf_table\1_label.jpg\n') paths = glob('../data/borderless_tables/*.jpg') # paths = glob('../data/standard_tables/*.jpg') path_cnt = 0 all_teds = 0 all_standard_cnt = 0 for p in paths: if 'label' in p: continue label_p = p[:-4] + '_label.jpg\n' if label_p not in label_path_list: continue # if path_cnt <= 10: # path_cnt += 1 # continue path_cnt += 1 img = cv2.imread(p) result = test_ocr_model(p) print(p) # print(result) bbox_list = eval(result.get('bbox')) text_list = eval(result.get('text')) bbox_text_dict = {} for i in range(len(text_list)): bbox_text_dict[str(bbox_list[i])] = text_list[i] # split_bbox(img, text_list, bbox_list) # 获取全局的按行排列bbox row_list = get_table_rows(bbox_list) # bbox预处理 bbox_list, text_list, bbox_text_dict = bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict) # bbox处理后再按行排列bbox row_list = get_table_rows(bbox_list) # 获取表格区域,以及区域里的按行排列bbox table_location_list, area_row_list = get_table_location(row_list) # 表格分割 table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict) table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict) print('fix_table_location_list', table_location_list) # print('fix_area_row_list', area_row_list) # 获取表格区域里,按列排序bbox area_col_list = get_table_cols(bbox_list, table_location_list) # 合并一列中多行bbox area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list) # 排除非规整表格 table_standard_list = delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict) # 上下表格合并 area_row_list, area_col_list, table_location_list = merge_table(area_row_list, area_col_list, table_location_list, bbox_list) # 获取行线、列线 area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) # 根据行列线生成对应bbox行列 area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list) # 添加列线 add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict) for j in range(len(area_col_lines)): area_col_lines[j] += add_area_col_lines[j] # 判断列线合法 area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict) area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1) area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0) # # for a in area_col_list: # for c in a: # print('area_col_list', [bbox_text_dict.get(str(x)) for x in c]) # # # 合并一列中多行bbox # area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list) # # # 获取行线、列线 # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) # # add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict) # # for j in range(len(area_col_lines)): # area_col_lines[j] += add_area_col_lines[j] # # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list) # # area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1) # area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0) # # add_area_row_lines = add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines) for j in range(len(area_row_lines)): area_row_lines[j] += add_area_row_lines[j] # area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict) # 合并相近线 for j in range(len(area_col_lines)): area_col_lines[j] = merge_lines(area_col_lines[j], axis=1) area_row_lines[j] = merge_lines(area_row_lines[j], axis=0) # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list) # area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list) # # 重新生成 # table_location_list = [] # temp_area_row_list = [] # for temp_row_list in area_row_list: # location_list, temp_row_list = get_table_location(temp_row_list) # table_location_list += location_list # temp_area_row_list += temp_row_list # area_col_list = get_table_cols(bbox_list, table_location_list) # area_row_list = temp_area_row_list # # # 获取行线、列线 # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) # # print('len(table_location_list)', len(table_location_list)) # for bbox in bbox_list: # cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])), # (0, 0, 255), 1) # # for i in range(len(table_location_list)): # # location = table_location_list[i] # # cv2.rectangle(img, location[0], location[1], (0, 255, 0), 1) # # row_lines = area_row_lines[i] # col_lines = area_col_lines[i] # for r in row_lines: # cv2.line(img, r[0], r[1], (0, 255, 0), 1) # for c in col_lines: # cv2.line(img, c[0], c[1], (0, 255, 0), 1) # # cv2.imshow('img', img) # cv2.waitKey(0) # 计算标注表格和生成表格的相似度 if len(table_location_list) == 1: # if not table_standard_list[0]: # continue row_lines = area_row_lines[0] col_lines = area_col_lines[0] row_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) col_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) for r in row_lines: cv2.line(row_img, r[0], r[1], (255, 255, 255), 1) for c in col_lines: cv2.line(col_img, c[0], c[1], (255, 255, 255), 1) row_list, is_standard = get_bbox_by_img(row_img, col_img) if not is_standard: continue row_list = merge_text_and_table(bbox_list, row_list) continue_flag = 0 for row in row_list: for b in row: if len(b) > 1: continue_flag = 1 break if continue_flag: continue max_len = 1 continue_flag = 0 for row in row_list: if abs(max_len - len(row)) > 2: continue_flag = 1 break if len(row) > max_len: max_len = len(row) if continue_flag: continue img_label = cv2.imread(label_p[:-1]) row_img1, col_img1 = get_lines_from_img(img_label) label_row_list, label_is_standard = get_bbox_by_img(row_img1, col_img1) if not label_is_standard: continue label_row_list = merge_text_and_table(bbox_list, label_row_list) add_flag = 0 modify_flag = 0 for i in range(len(row_list)): if i >= len(label_row_list): continue row = row_list[i] label_row = label_row_list[i] for r in label_row: if r not in row: add_flag += 1 else: if label_row.index(r) != row.index(r): modify_flag += 1 bbox_cnt = 0 for row in row_list: for b in row: bbox_cnt += 1 label_bbox_cnt = 0 for row in label_row_list: for b in row: label_bbox_cnt += 1 teds = 1 - (add_flag + modify_flag) / max(bbox_cnt, label_bbox_cnt) print('add_flag', add_flag, 'modify_flag', modify_flag, 'bbox_cnt', bbox_cnt, 'label_bbox_cnt', label_bbox_cnt) print('TEDS:', teds, p) all_teds += teds all_standard_cnt += 1 # if teds <= 0.8: # print('row_list', [y for y in [x for x in row_list]]) # print('label_row_list', [y for y in [x for x in label_row_list]]) # cv2.imshow('model_table', row_img+col_img) # cv2.imshow('label_table', row_img1+col_img1) # cv2.waitKey(0) # for i in range(len(row_list)): try: avg_teds = all_teds / all_standard_cnt except: avg_teds = 0 print('standard table cnt', all_standard_cnt) print('Avg TEDS', avg_teds) return def get_table_new(): with open('standard_table.txt', 'r') as f: label_path_list = f.readlines() # 表格分割问题:1019.jpg, 1020.jpg, 1023.jpg, 1027.jpg, 1029.jpg, 1030.jpg, 1031.jpg, 1035.jpg, 1040.jpg, 1042.jpg, 1046.jpg, 1047.jpg, 1061.jpg, 1064.jpg, 1067.jpg, 1072.jpg # 分列问题:1059.jpg, paths = glob('../data/borderless_tables/*.jpg') # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png') paths = ['1019.jpg', '1020.jpg', '1023.jpg', '1027.jpg', '1029.jpg', '1030.jpg', '1031.jpg', '1035.jpg', '1040.jpg', '1042.jpg', '1046.jpg', '1047.jpg', '1061.jpg', '1064.jpg', '1067.jpg', '1072.jpg'] paths = ['../data/borderless_tables/' + x for x in paths] path_cnt = 0 for p in paths: if 'label' in p: continue # label_p = p[:-4] + '_label.jpg\n' # if label_p not in label_path_list: # continue # if path_cnt <= 22: # path_cnt += 1 # continue path_cnt += 1 img = cv2.imread(p) result = test_ocr_model(p) print(p) bbox_list = eval(result.get('bbox')) text_list = eval(result.get('text')) # 处理bbox,缩小框 bbox_list = shrink_bbox(img, bbox_list) # 创建对应dict bbox_text_dict = {} for i in range(len(text_list)): bbox_text_dict[str(bbox_list[i])] = text_list[i] # 获取全局的按行排列bbox row_list = get_table_rows(bbox_list, bbox_text_dict) # 获取表格区域,以及区域里的按行排列bbox table_location_list, area_row_list = get_table_location(row_list, bbox_text_dict) area_row_list = merge_row_bbox_list(area_row_list) # for a in area_row_list: # i = 0 # for r in a: # print('row', i) # i += 1 # for b in r: # print(bbox_text_dict.get(str(b))) # 获取全局的按列排列bbox area_col_list = get_table_cols(bbox_list, table_location_list) # 获取行线、列线 area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) # 判断列线合法 area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict) # # 判断行线合法 area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict) # 由线得到按行列排列的bbox area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list) for a in area_table_bbox_list: for r in a: for c in r: # cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1) for b in c: cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1) cv2.imshow('table_cell', img) # 分割表格 # table_location_list, _ = split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict) # table_location_list, _ = split_table(table_location_list, area_row_list, bbox_text_dict) # table_location_list = split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict) # table_location_list = split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict) # 重新生成按行排列bbox area_row_list = get_table_rows2(area_row_list, table_location_list) # for a in area_row_list: # for r in a: # for b in r: # cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1) # cv2.imshow('area_row_list', img) # 获取全局的按列排列bbox area_col_list = get_table_cols(bbox_list, table_location_list) # 获取行线、列线 area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list) # 判断列线合法 area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict) # 判断行线合法 area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict) # 展示 show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list) return def show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list): for bbox in bbox_list: cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])), (0, 0, 255), 1) for i in range(len(table_location_list)): # location = table_location_list[i] # cv2.rectangle(img, location[0], location[1], (0, 255, 0), 1) row_lines = area_row_lines[i] col_lines = area_col_lines[i] for r in row_lines: cv2.line(img, r[0], r[1], (0, 255, 0), 1) for c in col_lines: cv2.line(img, c[0], c[1], (0, 255, 0), 1) cv2.namedWindow('img', cv2.WINDOW_NORMAL) cv2.imshow('img', img) cv2.waitKey(0) return def get_table_borders(area_row_list, area_col_list, table_location_list): area_row_lines = [] area_col_lines = [] # 循环每个表格 for i in range(len(area_row_list)): row_list = area_row_list[i] col_list = area_col_list[i] location = table_location_list[i] # 获取行线 row_lines = [] row_lines.append([[location[0][0], location[0][1]], [location[1][0], location[0][1]]]) for row in row_list: max_h = 0 for bbox in row: if bbox[2][1] > max_h: max_h = int(bbox[2][1]) row_lines.append([[location[0][0], max_h], [location[1][0], max_h]]) row_lines[-1][0][1] = max(location[1][1], row_lines[-1][0][1]) row_lines[-1][1][1] = max(location[1][1], row_lines[-1][1][1]) # 补充表格行范围 table_location_list[i][1][1] = max(location[1][1], row_lines[-1][1][1]) location = table_location_list[i] # 获取列线 col_lines = [] col_lines.append([[location[0][0], location[0][1]], [location[0][0], location[1][1]]]) for col in col_list: max_w = 0 for bbox in col: if bbox[2][0] > max_w: max_w = int(bbox[2][0]) col_lines.append([[max_w, location[0][1]], [max_w, location[1][1]]]) # 补充表格列范围 table_location_list[i][1][0] = max(location[1][0], col_lines[-1][1][0]) location = table_location_list[i] for row in row_lines: row[0][0] = location[0][0] row[1][0] = location[1][0] area_row_lines.append(row_lines) area_col_lines.append(col_lines) return area_row_lines, area_col_lines def get_table_location(row_list, bbox_text_dict): # for r in row_list: # print('row', r) up_h = 10000 bottom_h = 0 left_w = 10000 right_w = 0 table_rows = 0 tolerance_list = [] area_row_list = [] temp_row_list = [] table_location_list = [] catalog_text_cnt = 0 for row in row_list: if len(row) >= 2: if not temp_row_list: # 第一行bbox之间需大于一定值 max_distance = 0 row.sort(key=lambda x: x[0][0]) row_text_list = [] catalog_text_cnt = 0 bbox_height_list = [abs(row[-1][0][1] - row[-1][2][1])] for i in range(1, len(row)): dis = row[i][0][0] - row[i-1][2][0] if dis >= max_distance: max_distance = dis text = bbox_text_dict.get(str(row[i-1])) row_text_list.append(bbox_text_dict.get(str(row[i-1]))) match = re.findall('\\.+\d+', text) if match and len(match[0]) == len(text): catalog_text_cnt += 1 bbox_height_list.append(abs(row[i][0][1] - row[i][2][1])) # 排除 # if len(row) == 2: # if max_distance <= abs(row[0][2][0] - row[0][0][0]): # continue # else: if max_distance <= 5: continue # 排除 '地 址', '名 称' # if len(row) == 2 and len(bbox_text_dict.get(str(row[0]))) == 1: # continue row_text_list = [] bbox_height_list = [] for i in range(len(row)): text = bbox_text_dict.get(str(row[i-1])) row_text_list.append(bbox_text_dict.get(str(row[i-1]))) match = re.findall('\\.+\d+', text) if match and len(match[0]) == len(text): catalog_text_cnt += 1 bbox_height_list.append(abs(row[i][0][1] - row[i][2][1])) # 排除height差别过大的 bbox_height_list.sort(key=lambda x: x) if bbox_height_list[-1] - bbox_height_list[0] > bbox_height_list[0]: continue # 排除目录 if catalog_text_cnt >= 3: continue # 排除水印图 if len(list(set(row_text_list))) < 2/3 * len(row): continue # 排除有下划线的 table_rows += 1 temp_row_list.append(row) for bbox in row: if up_h > bbox[0][1]: up_h = bbox[0][1] if bottom_h < bbox[2][1]: bottom_h = bbox[2][1] if left_w > bbox[0][0]: left_w = bbox[0][0] if right_w < bbox[2][0]: right_w = bbox[2][0] else: if len(tolerance_list) < 3 and table_rows > 0: tolerance_list.append(row) temp_row_list.append(row) continue if table_rows > 2 and up_h < bottom_h: table_location_list.append([[int(left_w), int(up_h)], [int(right_w), int(bottom_h)]]) if tolerance_list[-1] == temp_row_list[-1]: area_row_list.append(temp_row_list[:-1]) else: area_row_list.append(temp_row_list) up_h = 10000 bottom_h = 0 left_w = 10000 right_w = 0 table_rows = 0 tolerance_list = [] temp_row_list = [] if temp_row_list: if table_rows > 2 and up_h < bottom_h: table_location_list.append([[int(left_w), int(up_h)], [int(right_w), int(bottom_h)]]) area_row_list.append(temp_row_list) return table_location_list, area_row_list def get_table_rows(bbox_list, bbox_text_dict): bbox_list.sort(key=lambda x: (x[0][1], x[2][1], x[0][0], x[2][0])) row_list = [] used_bbox_list = [] for b1 in bbox_list: if b1 in used_bbox_list: continue temp_bbox_list = [b1] used_bbox_list.append(b1) for b2 in bbox_list: if b2 in used_bbox_list: continue if abs((b1[0][1] + b1[2][1]) / 2 - (b2[0][1] + b2[2][1]) / 2) <= 10 \ and line_overlap(b1[0][1], b1[2][1], b2[0][1], b2[2][1]) >= 1/2*min(b1[2][1]-b1[0][1], b2[2][1]-b2[0][1]): temp_bbox_list.append(b2) used_bbox_list.append(b2) row_list.append(temp_bbox_list) return row_list def get_table_rows2(area_row_list, table_location_list): temp_area_row_list = [] for area in area_row_list: temp_area_row_list += area area_row_list = [] for location in table_location_list: row_list = [] for row in temp_area_row_list: if location[0][1] <= row[0][0][1] <= row[0][2][1] <= location[1][1]: row_list.append(row) area_row_list.append(row_list) return area_row_list def get_table_bbox_row_or_col(bbox_list, axis=0): bbox_list.sort(key=lambda x: (x[0][1-axis], x[2][1-axis], x[0][axis], x[2][axis])) row_list = [] used_bbox_list = [] for b1 in bbox_list: if b1 in used_bbox_list: continue temp_bbox_list = [b1] used_bbox_list.append(b1) for b2 in bbox_list: if b2 in used_bbox_list: continue if abs((b1[0][1-axis] + b1[2][1-axis]) / 2 - (b2[0][1-axis] + b2[2][1-axis]) / 2) <= 10: temp_bbox_list.append(b2) used_bbox_list.append(b2) row_list.append(temp_bbox_list) return row_list def get_table_cols(bbox_list, table_location_list): bbox_list.sort(key=lambda x: (x[0][0], x[2][0], x[0][1], x[2][1])) all_col_list = [] used_bbox_list = [] for location in table_location_list: sub_bbox_list = [] for b in bbox_list: if location[0][1] <= (b[0][1] + b[2][1])/2 <= location[1][1]: sub_bbox_list.append(b) col_list = [] for b1 in sub_bbox_list: if b1 in used_bbox_list: continue col_width = [b1[0][0], b1[2][0]] temp_bbox_list = [b1] used_bbox_list.append(b1) for b2 in sub_bbox_list: if b2 in used_bbox_list: continue # 判断同一列 # 1. 中心点相差一定范围内 # 2. 左边点相差一定范围内 # 3. 行范围包含 # 4. iou大于一定值 if abs((b1[0][0] + b1[2][0]) / 2 - (b2[0][0] + b2[2][0]) / 2) <= 10 \ or abs(b1[0][0] - b2[0][0]) <= 10 \ or col_width[0] <= b2[0][0] <= b2[2][0] <= col_width[1] \ or b2[0][0] <= col_width[0] <= col_width[1] <= b2[2][0] \ or line_iou([[col_width[0], 0], [col_width[1], 0]], [[b2[0][0], 0], [b2[1][0], 0]], axis=0) >= 0.6: temp_bbox_list.append(b2) used_bbox_list.append(b2) if b2[0][0] < col_width[0]: col_width[0] = b2[0][0] if b2[2][0] > col_width[1]: col_width[1] = b2[2][0] col_list.append(temp_bbox_list) all_col_list.append(col_list) return all_col_list def merge_col_bbox_by_cluster(img, area_row_list, area_col_list, bbox_text_dict, all_bbox_list, table_location_list): temp_img = copy.deepcopy(img) # 循环每个表格 for i in range(len(area_row_list)): row = area_row_list[i] col = area_col_list[i] # 循环每一列,计算列中行之间的间隔距离 new_col = [] col_cnt = 0 for bbox_list in col: # 获取间隔距离 distance_list = [] bbox_list.sort(key=lambda x: (x[0][1], x[1][1])) text_list = [bbox_text_dict.get(str(x)) for x in bbox_list] for j in range(1, len(bbox_list)): dis = bbox_list[j][0][1] - bbox_list[j-1][2][1] if dis < 0: dis = 0. distance_list.append(dis) print("\n") print("distance_list", distance_list) # 聚类获取类别组 data_list = [[0, x] for x in distance_list] # 排除距离大于一定值的 data_mask_list = [] temp_data_list = [] for j in range(len(data_list)): if data_list[j][1] < 5.: data_mask_list.append(True) temp_data_list.append(data_list[j]) else: data_mask_list.append(False) data_list = temp_data_list print("data_list", data_list) cluster_list = [] if len(data_list) > 2: # 聚类 pred_list = dbscan(data_list) print('pred_list', pred_list) temp_pred_list = [] for j in data_mask_list: if j: temp_pred_list.append(pred_list.pop(0)) else: temp_pred_list.append(-1) pred_list = temp_pred_list print('pred_list', pred_list) cluster_num = len(list(set(pred_list))) for k in range(cluster_num): temp_list = [] for j in range(len(pred_list)): if pred_list[j] == k: if temp_list: if j - temp_list[-1] == 1: temp_list.append(j) else: temp_list.append(j) else: if temp_list: cluster_list.append(temp_list) temp_list = [] if temp_list: cluster_list.append(temp_list) elif len(data_list) > 0: temp_list = [] for j in range(len(distance_list)): if distance_list[j] < 5.0: temp_list.append(j) else: if temp_list: cluster_list.append(temp_list) temp_list = [] if temp_list: cluster_list.append(temp_list) # cluster_list.append([x for x in range(len(distance_list))]) print('text_list', text_list) print('cluster_list', cluster_list) # 合并bbox new_bbox_list = copy.deepcopy(bbox_list) for cluster in cluster_list: merge_flag = 1 for dis in [distance_list[x] for x in cluster]: if dis >= 5.0: merge_flag = 0 break if merge_flag: b_list = bbox_list[cluster[0]:cluster[-1]+2] t_list = text_list[cluster[0]:cluster[-1]+2] min_w = 10000 max_w = 0 min_h = 10000 max_h = 0 b_list = [eval(x) for x in list(set([str(x) for x in b_list]))] for bbox in b_list: if bbox in new_bbox_list: new_bbox_list.remove(bbox) if bbox in all_bbox_list: all_bbox_list.remove(bbox) if bbox[0][0] < min_w: min_w = bbox[0][0] if bbox[0][1] < min_h: min_h = bbox[0][1] if bbox[2][0] > max_w: max_w = bbox[2][0] if bbox[2][1] > max_h: max_h = bbox[2][1] new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]] new_bbox_list.append(new_bbox) all_bbox_list.append(new_bbox) # 根据第一列的合并结果,指导其他列合并 if col_cnt == 0: first_col_rows = get_first_col_rows(new_bbox_list, table_location_list[i]) for r in first_col_rows: cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 0, 255), 1) cv2.imshow('temp_img', temp_img) # cv2.waitKey(0) col_cnt += 1 # new_col.append(new_bbox_list) return all_bbox_list def merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list): temp_img = copy.deepcopy(img) # 循环每个表格 for i in range(len(area_row_list)): row_list = area_row_list[i] col_list = area_col_list[i] table_location = table_location_list[i] sub_bbox_list = [] for bbox in bbox_list: if table_location[0][1] <= bbox[0][1] <= table_location[1][1] \ or table_location[0][1] <= bbox[1][1] <= table_location[1][1]: sub_bbox_list.append(bbox) # 对第一列聚类,合并,再根据空白分行 first_col = col_list[0] cluster_list, distance_list = distance_cluster(first_col, axis=1) merge_first_col = merge_cluster(first_col, cluster_list, distance_list) merge_first_col.sort(key=lambda x: (x[0][1], x[0][0])) row_lines = get_first_col_rows(merge_first_col, table_location) # 对其他列聚类,合并 # merge_bbox_list = [] + first_col # for col in col_list[1:]: # cluster_list = distance_cluster(col, axis=1) # merge_col = merge_cluster(col, cluster_list) # merge_bbox_list += merge_col # 循环每一列,根据分行合并 new_row_list = [] row_lines.sort(key=lambda x: x) row_cnt = 0 need_add_bbox = [] # for c in first_col: # print('first col ', bbox_text_dict.get(str(c))) for j in range(1, len(row_lines)): print('\n') top_line = row_lines[j-1] bottom_line = row_lines[j] new_row = [] if need_add_bbox: # print('add') new_row += need_add_bbox print('add', bbox_text_dict.get(str(new_row[0]))) need_add_bbox = [] # 合并条件: # 1. 完全包含 # 2. 处在两行之间,判断bbox与第一列的这两行的bbox高度距离 for bbox in sub_bbox_list: if top_line <= bbox[0][1] <= bbox[2][1] <= bottom_line: new_row.append(bbox) # print('bbox, line', bbox_text_dict.get(str(bbox)), top_line, bottom_line) else: if bbox in first_col: continue # 如果第一列只有一行,交界处的bbox不算 if len(first_col) == 1: need_add_bbox.append(bbox) continue # 计算离该bbox最近的上下两个第一列的bbox first_col_center_h1 = 0 first_col_center_h2 = 10000 first_col_bbox1 = None first_col_bbox2 = None bbox_center_h = (bbox[0][1] + bbox[2][1]) / 2 for b in first_col: b_center_h = (b[0][1] + b[2][1]) / 2 # if bbox[0][1] <= b_center_h <= bbox[2][1]: # first_col_center_h2 = b_center_h # break if bbox_center_h >= b_center_h and bbox_center_h - b_center_h <= bbox_center_h - first_col_center_h1: first_col_center_h1 = b_center_h first_col_bbox1 = b if b_center_h >= bbox_center_h and b_center_h - bbox_center_h <= first_col_center_h2 - bbox_center_h: first_col_center_h2 = b_center_h first_col_bbox2 = b # 如果离该bbox最近的第一列的bbox,不是这一行的 if new_row and first_col_bbox1 != new_row[0] and top_line < bbox[0][1] < bottom_line: need_add_bbox.append(bbox) continue # if top_line <= bbox[2][1] <= bottom_line \ # and abs(first_col_center_h1 - bbox_center_h) >= abs(first_col_center_h2 - bbox_center_h): # new_row.append(bbox) # if first_col_bbox1 and first_col_bbox2: # print('bbox1, bbox2', bbox_text_dict[str(first_col_bbox1)], bbox_text_dict[str(first_col_bbox2)], # bbox_text_dict[str(bbox)]) if top_line < bbox[0][1] < bottom_line \ and abs(first_col_center_h1 - bbox_center_h) <= abs(first_col_center_h2 - bbox_center_h): new_row.append(bbox) elif top_line < bbox[0][1] < bottom_line: need_add_bbox.append(bbox) for r in need_add_bbox: print("next_row bbox", bbox_text_dict.get(str(r))) print('row', row_cnt, len(new_row)) for b in new_row: print(bbox_text_dict.get(str(b))) row_cnt += 1 new_row_list.append(new_row) area_row_list[i] = new_row_list # show r_cnt = 0 # for r in row_lines: # if r_cnt == 0 or r_cnt == len(row_lines) - 1: # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (255, 0, 0), 1) # else: # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 255, 0), 1) # r_cnt += 1 # for b in merge_bbox_list: # cv2.rectangle(temp_img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (0, 0, 255), 1) # cv2.imshow('temp_img', temp_img) return area_row_list def distance_cluster(bbox_list, max_distance=5., axis=1): # 获取间隔距离 distance_list = [] bbox_list.sort(key=lambda x: (x[0][1], x[1][1])) for j in range(1, len(bbox_list)): dis = bbox_list[j][0][axis] - bbox_list[j-1][2][axis] if dis < 0: dis = 0. distance_list.append(dis) print("\n") print("distance_list", distance_list) # 聚类获取类别组 data_list = [[0, x] for x in distance_list] # 排除距离大于一定值的 data_mask_list = [] temp_data_list = [] for j in range(len(data_list)): if data_list[j][1] < max_distance: data_mask_list.append(True) temp_data_list.append(data_list[j]) else: data_mask_list.append(False) data_list = temp_data_list print("data_list", data_list) cluster_list = [] if len(data_list) > 2: # 聚类 pred_list = dbscan(data_list) print('pred_list', pred_list) temp_pred_list = [] for j in data_mask_list: if j: temp_pred_list.append(pred_list.pop(0)) else: temp_pred_list.append(-1) pred_list = temp_pred_list print('pred_list', pred_list) cluster_num = len(list(set(pred_list))) for k in range(cluster_num): temp_list = [] for j in range(len(pred_list)): if pred_list[j] == k: if temp_list: if j - temp_list[-1] == 1: temp_list.append(j) else: temp_list.append(j) else: if temp_list: cluster_list.append(temp_list) temp_list = [] if temp_list: cluster_list.append(temp_list) elif len(data_list) > 0: temp_list = [] for j in range(len(distance_list)): if distance_list[j] < max_distance: temp_list.append(j) else: if temp_list: cluster_list.append(temp_list) temp_list = [] if temp_list: cluster_list.append(temp_list) print('cluster_list', cluster_list) return cluster_list, distance_list def merge_cluster(bbox_list, cluster_list, distance_list): new_bbox_list = copy.deepcopy(bbox_list) # 特殊情况:每行之间空隙小,且规律,会全被分到一个类 if len(cluster_list) == 1 and len(cluster_list[0]) >= 4: cluster_list = [[x] for x in cluster_list[0]] # 每行的空隙小且均匀 if distance_list: if max(distance_list) - min(distance_list) <= 5.5: cluster_list = [[i] for i in range(len(distance_list))] # 去掉一个最大值,其他的空隙小且均匀 if distance_list and max(distance_list) - min(distance_list) >= 10: index = distance_list.index(max(distance_list)) if index <= 2 and len(distance_list[index+1:]) >= 3 and max(distance_list[index+1:]) - min(distance_list[index+1:]) <= 5.5: if index == 0: cluster_list = [[i] for i in range(len(distance_list[index+1:]))] else: if max(distance_list[:index]) - min(distance_list[:index]) <= 5.5: cluster_list = [[i] for i in range(len(distance_list[:index]))] cluster_list += [[i] for i in range(len(distance_list[index+1:]))] for cluster in cluster_list: b_list = bbox_list[cluster[0]:cluster[-1]+2] min_w = 10000 max_w = 0 min_h = 10000 max_h = 0 b_list = [eval(x) for x in list(set([str(x) for x in b_list]))] for bbox in b_list: if bbox in new_bbox_list: new_bbox_list.remove(bbox) if bbox[0][0] < min_w: min_w = bbox[0][0] if bbox[0][1] < min_h: min_h = bbox[0][1] if bbox[2][0] > max_w: max_w = bbox[2][0] if bbox[2][1] > max_h: max_h = bbox[2][1] new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]] new_bbox_list.append(new_bbox) return new_bbox_list def get_first_col_rows(first_col, table_location): """ 根据第一列的bbox,分行 :return: """ location_top = table_location[0][1] location_bottom = table_location[1][1] row_block_list = [table_location[0][1]] for i in range(len(first_col)): bbox = first_col[i] if i + 1 < len(first_col): next_bbox = first_col[i+1] bbox_distance = abs(bbox[2][1] - next_bbox[0][1]) else: bbox_distance = 10000 if i == 0: top_block = abs(bbox[0][1] - location_top) bottom_block = min(top_block, bbox_distance) sub_row = bbox[2][1] + bottom_block else: top_block = abs(bbox[0][1] - row_block_list[-1]) bottom_block = min(top_block, bbox_distance) sub_row = bbox[2][1] + bottom_block row_block_list.append(sub_row) if len(row_block_list) == 2: row_block_list.append(location_bottom) else: row_block_list[-1] = max(row_block_list[-1], location_bottom) return row_block_list def judge_standard_table(row_list): up_h = 10000 bottom_h = 0 left_w = 10000 right_w = 0 table_rows = 0 now_row_len = 0 init_flag = 0 tolerance_list = [] area_row_list = [] temp_row_list = [] table_location_list = [] for row in row_list: if init_flag: up_h = 10000 bottom_h = 0 left_w = 10000 right_w = 0 table_rows = 0 tolerance_list = [] temp_row_list = [] init_flag = 0 if len(row) >= 2: if now_row_len == 0: now_row_len = len(row) else: if len(row) != now_row_len: init_flag = 1 continue table_rows += 1 temp_row_list.append(row) for bbox in row: if up_h > bbox[0][1]: up_h = bbox[0][1] if bottom_h < bbox[2][1]: bottom_h = bbox[2][1] if left_w > bbox[0][0]: left_w = bbox[0][0] if right_w < bbox[2][0]: right_w = bbox[2][0] else: if len(tolerance_list) < 1 and table_rows > 0: tolerance_list.append(row) temp_row_list.append(row) continue if table_rows > 1 and up_h < bottom_h: table_location_list.append([[int(left_w), int(up_h)], [int(right_w), int(bottom_h)]]) if tolerance_list[-1] == temp_row_list[-1]: area_row_list.append(temp_row_list[:-1]) else: area_row_list.append(temp_row_list) init_flag = 1 return table_location_list, area_row_list def split_bbox(img, bbox, bbox_text_dict): text = bbox_text_dict.get(str(bbox)) sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :] split_line_list = [] last_i_status = 1 # 从左到右遍历img for i in range(1, sub_img.shape[1]): # 若这一列黑色像素超过一定值 if np.where(sub_img[:, i, :] < 200)[0].size > sub_img.shape[0]/5: i_status = 0 else: i_status = 1 # 异或,上个像素列为黑且这个像素列为白,或上个像素列为白且这个像素列为黑 if last_i_status ^ i_status: split_line_list.append(int(i)) last_i_status = i_status # 两条分割线太近的去重 min_len = 5 last_l = split_line_list[0] temp_list = [split_line_list[0]] for l in split_line_list[1:]: if l - last_l > min_len: temp_list.append(l) last_l = l split_line_list = temp_list # 若两个分割线间无黑像素,则是应该分割的 split_pair_list = [] last_line = split_line_list[0] for line in split_line_list[1:]: print('last_line, line', last_line, line, np.where(sub_img[:, last_line:line, :] < 100)[0].size) if line - last_line >= 10 and np.where(sub_img[:, last_line:line, :] < 100)[0].size < 10: split_pair_list.append([last_line, line]) last_line = line print('split_pair_list', split_pair_list) for l in split_line_list: l = int(l + bbox[0][0]) cv2.line(img, (l, int(bbox[0][1])), (l, int(bbox[2][1])), (0, 255, 0), 2) cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])), (0, 0, 255), 1) cv2.imshow('img', img) cv2.waitKey(0) # 分割得到新bbox split_bbox_list = [] if split_pair_list: start_line = 0 for line1, line2 in split_pair_list: w1 = start_line + bbox[0][0] w2 = line1 + bbox[0][0] start_line = line2 split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []]) w1 = start_line + bbox[0][0] w2 = bbox[2][0] split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []]) print('split_bbox_list', split_bbox_list) # 计算每个字长度 all_len = 0 bbox_len_list = [] for bbox in split_bbox_list: _len = abs(bbox[2][0] - bbox[0][0]) all_len += _len bbox_len_list.append(_len) single_char_len = all_len / len(text) # 根据bbox长度和单字长度比例计算得到截取后的text split_text_list = [] text_start = 0 for _len in bbox_len_list: text_num = int(_len / single_char_len + 0.5) text_end = text_start+text_num if text_end >= len(text): text_end = len(text) split_text_list.append(text[text_start:text_end]) text_start = text_end print('split_text_list', split_text_list) # 更新bbox_text_dict for i, bbox in enumerate(split_bbox_list): bbox_text_dict[str(bbox)] = split_text_list[i] return split_bbox_list, bbox_text_dict def split_table(table_location_list, area_row_list, bbox_text_dict): temp_location_list = [] temp_area_row_list = [] for i in range(len(table_location_list)): location = table_location_list[i] sub_row_list = area_row_list[i] # 截断标题,对只有行中间或行开头一个bbox的行进行排除 need_split_index = [] for j in range(len(sub_row_list)): row = sub_row_list[j] if len(row) == 1: threshold = (row[0][2][0]-row[0][0][0])*1/3 if row[0][0][0] + threshold <= (location[0][0]+location[1][0])/2 <= row[0][2][0] - threshold: need_split_index.append(j) elif abs(location[0][0] - row[0][0][0]) <= 5 \ and row[0][2][0] - row[0][0][0] >= 1/5 * (location[1][0]-location[0][0]): need_split_index.append(j) print('need_split_index', need_split_index) if not need_split_index: temp_location_list.append(location) temp_area_row_list.append(sub_row_list) else: last_index = 0 need_split_index.append(len(sub_row_list)) for index in need_split_index: if index == last_index: last_index = index + 1 continue if len(sub_row_list[last_index:index]) < 2: last_index = index + 1 continue temp_area_row_list.append(sub_row_list[last_index:index]) min_w = 10000 max_w = 0 min_h = 10000 max_h = 0 for row in sub_row_list[last_index:index]: for bbox in row: if bbox[0][0] < min_w: min_w = bbox[0][0] if bbox[0][1] < min_h: min_h = bbox[0][1] if bbox[2][0] > max_w: max_w = bbox[2][0] if bbox[2][1] > max_h: max_h = bbox[2][1] temp_location_list.append([[int(min_w), int(min_h)], [int(max_w), int(max_h)]]) last_index = index+1 return temp_location_list, temp_area_row_list def split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict): for i in range(len(table_location_list)): location = table_location_list[i] table_bbox_list = area_table_bbox_list[i] # 循环每一行 split_index_list = [] for j in range(1, len(table_bbox_list)): row = table_bbox_list[j] last_row = table_bbox_list[j-1] row_bbox_cnt_list = [len(x) for x in row] last_row_bbox_cnt_list = [len(x) for x in last_row] diff_num = 0 diff_flag = 0 for k in range(len(row_bbox_cnt_list)): if row_bbox_cnt_list[k] > last_row_bbox_cnt_list[k]: if last_row_bbox_cnt_list[k] != 0: diff_flag = 1 diff_num += 1 if diff_num > 0 and diff_flag: split_index_list.append(j) continue print('split_index_list', split_index_list) return table_location_list # def split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict): # new_table_location_list = [] # for i in range(len(table_location_list)): # location = table_location_list[i] # table_bbox_list = area_table_bbox_list[i] # # # 每行单独进行表头预测 # table_head_row_list = [] # for j in range(len(table_bbox_list)): # row = table_bbox_list[j] # print('row', row) # # if row.count([]) == len(row): # table_head_row_list.append([['', 0]]) # continue # # row_bbox_list = [] # for col in row: # for b in col: # new_b = bbox_text_dict.get(str(b)) # new_b = re.sub("^[^\u4e00-\u9fa5a-zA-Z0-9]+", "", new_b) # new_b = re.sub("[^\u4e00-\u9fa5a-zA-Z0-9]+$", "", new_b) # row_bbox_list.append(new_b) # result_list = predict([row_bbox_list]) # # 组合结果 # for m in range(len(result_list)): # for n in range(len(result_list[m])): # result_list[m][n] = [row_bbox_list[n], int(result_list[m][n])] # result_list = result_list[0] # print('table_head', result_list) # table_head_row_list.append(result_list) # # # 根据表头分割 # split_index_list = [] # for j in range(1, len(table_head_row_list)): # row_head = [x[1] for x in table_head_row_list[j]] # last_row_head = [x[1] for x in table_head_row_list[j-1]] # # # [['6', 0], ['税费', 0], ['依法缴纳', 0], ['1', 0], ['次', 0], ['25000', 0], ['25000', 0]] # # [['大写', 1], ['肆抢柒万元整', 0]] # if 1 in row_head and 1 not in last_row_head: # split_index_list.append(j) # # # [['供应商', 1], ['广东一线达通网络科技有限公司', 0]] # # [['货物明细', 1], ['单价金额(元', 1], ['数量', 1], ['总计金额(元', 1]] # if 1 in row_head and 1 in last_row_head and 0 not in row_head and row_head.count(1) != last_row_head.count(1): # split_index_list.append(j) # print('split_index_list', split_index_list) # # new_location_list = table_split_by_index(location, split_index_list, table_bbox_list) # print('new_location_list, location', new_location_list, location) # new_table_location_list += new_location_list # print('new_table_location_list', new_table_location_list) # return new_table_location_list def table_split_by_index(table_location, split_index_list, table_bbox_list): if split_index_list: # 分割表格 split_index_list = [0] + split_index_list + [len(table_bbox_list)] split_index_list = list(set(split_index_list)) split_index_list.sort(key=lambda x: x) print('split_index_list', split_index_list) new_location_list = [] for l in range(1, len(split_index_list)): index = split_index_list[l] last_index = split_index_list[l-1] # if index - last_index <= 2: # continue # 获取范围 rows = table_bbox_list[last_index:index] min_h, min_w = 10000, 10000 max_h, max_w = 0, 0 for r in rows: for c in r: for b in c: if b: if b[0][0] < min_w: min_w = int(b[0][0]) if b[0][1] < min_h: min_h = int(b[0][1]) if b[2][0] > max_w: max_w = int(b[2][0]) if b[2][1] > max_h: max_h = int(b[2][1]) new_location = [[min_w, min_h], [max_w, max_h]] new_location_list.append(new_location) print('new_location', new_location) if new_location_list: return new_location_list else: return [table_location] else: return [table_location] def split_table_new(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict): temp_location_list = [] temp_area_row_list = [] for k in range(len(table_location_list)): table = area_table_bbox_list[k] location = table_location_list[k] row_list = area_row_list[k] table_cell_list = area_table_cell_list[k] split_row_index_list = [] # 遍历所有行 for i in range(len(table)): row = table[i] # print('row', i) # for j in range(len(row)): # col = row[j] # print('col', j, ';'.join([bbox_text_dict.get(str(x)) for x in col])) # 判断该行为表格分割行: # 1. 这一行只有一列有值,且这一行的这一列中的最长的bbox,包含其他行同列中的2个以上bbox # 2. 这一行只有一列有值,且这一行的这一列中的最长的bbox,跨过了多列 # 获取前后多行 n = 3 if i-n < 0: last_n_rows = table[0:i] else: last_n_rows = table[i-n:i] if i+1 >= len(table): next_n_rows = [] elif i+n+1 >= len(table): next_n_rows = table[i+1:len(table)] else: next_n_rows = table[i+1:i+n+1] # 寻找一行只有一格有数据的 not_empty_col_cnt = 0 only_one_index = -1 for j in range(len(row)): col = row[j] if col: not_empty_col_cnt += len(col) only_one_index = j if not_empty_col_cnt == 1: print('only_one_index, i', only_one_index, i) # 对比前后n行的同一列数据 for r in last_n_rows+next_n_rows: col = r[only_one_index] if len(col) > 1: print('col', [bbox_text_dict.get(str(x)) for x in col]) # 找出其他行同一单元格中包含多个横向排列bbox的 col_bbox_list = [col[0]] for bbox in col: for j in range(len(col_bbox_list)): bbox1 = col_bbox_list[j] if bbox1[0][0] <= bbox[0][0] <= bbox[2][0] <= bbox1[2][0]: col_bbox_list[j] = bbox elif bbox[0][0] <= bbox1[0][0] <= bbox1[2][0] <= bbox[2][0]: continue else: col_bbox_list.append(bbox) if len(col_bbox_list) > 1: # 找出这一行同列最长的bbox,有没有包含其他行同列的多个bbox col = row[only_one_index] print('long col', [bbox_text_dict.get(str(x)) for x in col]) col.sort(key=lambda x: abs(x[2][0]-x[0][0])) longest_bbox = col[-1] contain_cnt = 0 cross_cnt = 0 for bbox in col_bbox_list: if longest_bbox[0][0] <= bbox[0][0] <= bbox[2][0] <= longest_bbox[2][0]: contain_cnt += 1 if bbox[0][0] < longest_bbox[0][0] < bbox[2][0] or bbox[0][0] < longest_bbox[2][0] < bbox[2][0]: cross_cnt += 1 print('cross_cnt', cross_cnt) if contain_cnt >= 2 or cross_cnt >= 2: print('包含多个横向排列bbox', i) split_row_index_list.append(i) # 看这一行这一列最长bbox有无跨单元格 col = row[only_one_index] col.sort(key=lambda x: abs(x[2][0]-x[0][0])) longest_bbox = col[-1] cell_row = table_cell_list[i] cell_col = cell_row[only_one_index] threshold = 15 if cell_col[0][0]-threshold <= longest_bbox[0][0] <= longest_bbox[2][0] <= cell_col[1][0]+threshold: pass else: print('最长bbox跨单元格', i) split_row_index_list.append(i) if split_row_index_list: # 分割表格 split_row_index_list.insert(0, -1) split_row_index_list.insert(len(split_row_index_list), len(table)) split_row_index_list = list(set(split_row_index_list)) split_row_index_list.sort(key=lambda x: x) print('split_row_index_list', split_row_index_list, len(table)) for l in range(1, len(split_row_index_list)): index = split_row_index_list[l] last_index = split_row_index_list[l-1] if index - last_index <= 2: continue start_row_index = last_index+1 end_row_index = index-1 start_row = table[last_index+1] end_row = table[index-1] start_row = [x for y in start_row for x in y] end_row = [x for y in end_row for x in y] start_row = list(filter(lambda x: x != [], start_row)) end_row = list(filter(lambda x: x != [], end_row)) if not start_row: start_row_index = last_index + 2 start_row = table[start_row_index] start_row = [x for y in start_row for x in y] start_row = list(filter(lambda x: x != [], start_row)) if not end_row: end_row_index = index - 2 end_row = table[end_row_index] end_row = [x for y in end_row for x in y] end_row = list(filter(lambda x: x != [], end_row)) if not start_row or not end_row or end_row_index-start_row_index < 1: continue start_row.sort(key=lambda x: x[0][1]) min_h = int(start_row[0][0][1]) min_w = location[0][0] end_row.sort(key=lambda x: x[2][1]) max_h = int(end_row[-1][2][1]) max_w = location[1][0] new_location = [[min_w, min_h], [max_w, max_h]] temp_location_list.append(new_location) temp_area_row_list.append(row_list[last_index+1:index]) else: temp_location_list.append(location) temp_area_row_list.append(row_list) table_location_list = temp_location_list area_row_list = temp_area_row_list return table_location_list, area_row_list def split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict): temp_location_list = [] temp_area_row_list = [] for k in range(len(table_location_list)): table = area_table_bbox_list[k] location = table_location_list[k] row_list = area_row_list[k] table_cell_list = area_table_cell_list[k] split_row_index_list = [] # 遍历所有行 table_start_row_index = 0 for i in range(len(table)): row = table[i] # 判断该行为表格分割行: # 1. 这一行只有一列有值,且这一行的这一列中的最长的bbox,包含其他行同列中的2个以上bbox # 2. 这一行只有一列有值,且这一行的这一列中的最长的bbox,跨过了多列 # print(i, [bbox_text_dict.get(str(y)) for x in row for y in x]) # 每次找到分割行,更新 if table_start_row_index >= len(table): break # 获取前n行 n = 2 if i-n < table_start_row_index: last_n_rows = table[table_start_row_index:i] else: last_n_rows = table[i-n:i] # 寻找一行中最长的bbox max_len_bbox = [] for col in row: for b in col: if not max_len_bbox: max_len_bbox = b else: if b[2][0] - b[0][0] > max_len_bbox[2][0]-max_len_bbox[0][0]: max_len_bbox = b # 对比前n行的数据 for r in last_n_rows: b_list = [y for x in r for y in x] # 第n行中的非上下重合的bbox temp_b_list = [] for b in b_list: if not temp_b_list: temp_b_list.append(b) else: find_flag = 0 for tb in temp_b_list: if line_overlap(tb[0][0], tb[2][0], b[0][0], b[2][0]) > 0: find_flag = 1 break if not find_flag: temp_b_list.append(b) b_list = temp_b_list if len(b_list) > 1 and max_len_bbox: # 最长bbox是否包含第n行多个bbox contain_cnt = 0 for b in b_list: threshold = (b[2][0]-b[0][0])/4 if max_len_bbox[0][0] <= b[0][0] <= b[2][0] <= max_len_bbox[2][0]: contain_cnt += 1 if b[0][0]+threshold < max_len_bbox[0][0] < b[2][0]-threshold \ or b[0][0]+threshold < max_len_bbox[2][0] < b[2][0]-threshold: contain_cnt += 1 # print('contain_cnt', contain_cnt) if contain_cnt >= 2: # print('包含多个横向排列bbox', i) split_row_index_list.append(i) table_start_row_index = i+1 if split_row_index_list: # 分割表格 split_row_index_list.insert(0, -1) split_row_index_list.insert(len(split_row_index_list), len(table)) split_row_index_list = list(set(split_row_index_list)) split_row_index_list.sort(key=lambda x: x) print('split_row_index_list', split_row_index_list, len(table)) for l in range(1, len(split_row_index_list)): index = split_row_index_list[l] last_index = split_row_index_list[l-1] if index - last_index <= 2: continue start_row_index = last_index+1 end_row_index = index-1 start_row = table[last_index+1] end_row = table[index-1] start_row = [x for y in start_row for x in y] end_row = [x for y in end_row for x in y] start_row = list(filter(lambda x: x != [], start_row)) end_row = list(filter(lambda x: x != [], end_row)) if not start_row: start_row_index = last_index + 2 start_row = table[start_row_index] start_row = [x for y in start_row for x in y] start_row = list(filter(lambda x: x != [], start_row)) if not end_row: end_row_index = index - 2 end_row = table[end_row_index] end_row = [x for y in end_row for x in y] end_row = list(filter(lambda x: x != [], end_row)) if not start_row or not end_row or end_row_index-start_row_index < 1: continue start_row.sort(key=lambda x: x[0][1]) min_h = int(start_row[0][0][1]) min_w = location[0][0] end_row.sort(key=lambda x: x[2][1]) # print('end_row', [bbox_text_dict.get(str(x)) for x in end_row]) max_h = int(end_row[-1][2][1]) max_w = location[1][0] new_location = [[min_w, min_h], [max_w, max_h]] temp_location_list.append(new_location) temp_area_row_list.append(row_list[start_row_index:end_row_index+1]) else: temp_location_list.append(location) temp_area_row_list.append(row_list) table_location_list = temp_location_list area_row_list = temp_area_row_list return table_location_list, area_row_list def delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict): table_standard_list = [] for i in range(len(table_location_list)): row_list = area_row_list[i] col_list = area_col_list[i] location = table_location_list[i] table_standard = True # 1. 只有单行或单列 if len(row_list) <= 1 or len(col_list) <= 1: table_standard = False table_standard_list.append(table_standard) continue # 1. 单个单元格过多bbox for row in row_list: for col in col_list: inter = [j for j in row if j in col] inter = [eval(x) for x in list(set([str(x) for x in inter]))] if len(inter) >= 8: table_standard = False break # 1. 判断表格中,不在bbox中的黑色像素 table_black_cnt = count_black(img[location[0][1]:location[1][1], location[0][0]:location[1][0], :]) bbox_black_cnt = 0 for bbox in bbox_list: if location[0][1] <= bbox[0][1] <= location[1][1]: sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :] if sub_img.shape[0] >= 3 and sub_img.shape[1] >= 3: bbox_black_cnt += count_black(sub_img) print('table_black_cnt, bbox_black_cnt', table_black_cnt, bbox_black_cnt, bbox_black_cnt / table_black_cnt) if bbox_black_cnt / table_black_cnt < 0.5: table_standard = False table_standard_list.append(table_standard) print('table_standard_list', table_standard_list) return table_standard_list def bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict): # 合并同一行中多个单字bbox for row in row_list: single_bbox_list = [] row.sort(key=lambda x: x[0][0]) i = 0 for bbox in row: if len(bbox_text_dict.get(str(bbox))) == 1 and i != len(row) - 1: single_bbox_list.append(bbox) else: if len(single_bbox_list) >= 3: if len(bbox_text_dict.get(str(bbox))) == 1: single_bbox_list.append(bbox) new_bbox = single_bbox_list[0] new_text = "" single_bbox_list.sort(key=lambda x: x[0][0]) for b in single_bbox_list: new_bbox = [[new_bbox[0][0], new_bbox[0][1]], [b[2][0], new_bbox[0][1]], [b[2][0], b[2][1]], [new_bbox[0][0], b[2][1]], ] bbox_list.remove(b) new_text += bbox_text_dict.get(str(b)) text_list.remove(bbox_text_dict.get(str(b))) # print('new_bbox, new_text', new_bbox, new_text) bbox_list.append(new_bbox) text_list.append(new_text) bbox_text_dict[str(new_bbox)] = new_text single_bbox_list = [] i += 1 return bbox_list, text_list, bbox_text_dict def merge_table(area_row_list, area_col_list, table_location_list, bbox_list): table_location_list.sort(key=lambda x: x[0][1]) merge_index_list = [] temp_merge_list = [] for i in range(1, len(table_location_list)): last_col_list = area_col_list[i-1] col_list = area_col_list[i] last_location = table_location_list[i-1] location = table_location_list[i] merge_flag = 0 # 获取每个列的宽度 col_width_list = [] for col in col_list: col.sort(key=lambda x: x[0][0]) min_w = col[0][0][0] col.sort(key=lambda x: x[2][0]) max_w = col[-1][2][0] col_width_list.append([min_w, max_w]) # 获取两个表格之间的bbox,判断bbox是否跨越多列 threshold = 5 merge_flag2 = 1 for bbox in bbox_list: if last_location[1][1]-threshold <= bbox[0][1] <= bbox[2][1] <= location[0][1]+threshold: if bbox[0][0] < col_width_list[0][0] or bbox[2][0] > col_width_list[-1][1]: merge_flag2 = 0 break for w in col_width_list: if w[0] <= bbox[0][0] <= w[1] and bbox[2][0] - bbox[0][0] > w[1] - w[0]: merge_flag2 = 0 break # if location[0][1] - last_location[1][1] <= 20: if merge_flag2: if len(last_col_list) == len(col_list): temp_merge_list += [i-1, i] merge_flag = 1 if not merge_flag: if temp_merge_list: merge_index_list.append(temp_merge_list) else: merge_index_list.append([i-1]) temp_merge_list = [] if temp_merge_list: merge_index_list.append(temp_merge_list) else: merge_index_list.append([len(table_location_list)-1]) # print('merge_index_list', merge_index_list) if not merge_index_list: return area_row_list, area_col_list, table_location_list new_table_location_list = [] new_area_row_list = [] new_area_col_list = [] for index_list in merge_index_list: if not table_location_list: break index_list = list(set(index_list)) temp_table = table_location_list[index_list[0]] new_area_row_list.append(area_row_list[index_list[0]]) new_area_col_list.append(area_col_list[index_list[0]]) for index in index_list[1:]: temp_table = [[min(temp_table[0][0], table_location_list[index][0][0]), min(temp_table[0][1], table_location_list[index][0][1])], [max(temp_table[1][0], table_location_list[index][1][0]), max(temp_table[1][1], table_location_list[index][1][1])] ] new_area_row_list[-1] += area_row_list[index] new_area_col_list[-1] += area_col_list[index] new_table_location_list.append(temp_table) return new_area_row_list, new_area_col_list, new_table_location_list def add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict): """ 对单个单元格内多列的,增加列线 :return: """ add_area_col_lines = [] for i in range(len(table_location_list)): row_list = area_row_list[i] col_list = area_col_list[i] location = table_location_list[i] add_col_lines = [] new_col_list = [] for col in col_list: row_cnt = 0 new_row_list = [] cell_col_lines = [] col.sort(key=lambda x: (x[0][1], x[0][0])) # print('col') for row in row_list: row.sort(key=lambda x: (x[0][0], x[0][1])) inter = [j for j in row if j in col] inter = [eval(x) for x in list(set([str(x) for x in inter]))] inter.sort(key=lambda x: (x[0][1], x[0][0])) new_row = [] print('inter', [bbox_text_dict.get(str(x)) for x in inter]) # if inter: # # 先将同个单元格内上下重叠的bbox合并 # temp_inter = [] # # used_bbox_list = [] # for bbox1 in inter: # if bbox1 in used_bbox_list: # continue # temp_merge_bbox = [bbox1] # for bbox2 in inter: # if bbox2 in used_bbox_list: # continue # if line_overlap(bbox1[0][0], bbox1[2][0], bbox2[0][0], bbox2[2][0]) >= 2/3 * min(bbox1[2][0]-bbox1[0][0], bbox2[2][0], bbox2[0][0]) \ # and line_overlap(bbox1[0][1], bbox1[2][1], bbox2[0][1], bbox2[2][1]) > 0: # temp_merge_bbox += [bbox1, bbox2] # used_bbox_list += [bbox1, bbox2] # temp_merge_bbox = [eval(y) for y in list(set([str(x) for x in temp_merge_bbox]))] # temp_inter.append(temp_merge_bbox) # # inter = [] # for m_bbox in temp_inter: # min_w, min_h, max_w, max_h = 10000, 10000, 0, 0 # temp_text = "" # for bbox in m_bbox: # if bbox[0][0] < min_w: # min_w = bbox[0][0] # if bbox[0][1] < min_h: # min_h = bbox[0][1] # if bbox[2][0] > max_w: # max_w = bbox[2][0] # if bbox[2][1] > max_h: # max_h = bbox[2][1] # temp_text += bbox_text_dict.get(str(bbox)) + ' ' # inter.append([[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]]) # bbox_text_dict[str(inter[-1])] = temp_text # print('merge inter', [bbox_text_dict.get(str(x)) for x in inter]) # 一个单元格内多个bbox if len(inter) > 1: # 单元格内分行 cell_row = [] temp_row = [inter[0]] row_len = [inter[0][0][1], inter[0][2][1]] for bbox in inter[1:]: temp_bbox = temp_row[0] bbox_h_len = bbox[2][1] - bbox[0][1] temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1] # if temp_bbox[0][1]-5 <= bbox[0][1] <= bbox[2][1] <= temp_bbox[2][1]+5 \ # or bbox[0][1]-5 <= temp_bbox[0][1] <= temp_bbox[2][1] <= bbox[2][1]+5 \ if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len): temp_row.append(bbox) row_len[0] = min(row_len[0], bbox[0][1]) row_len[1] = max(row_len[1], bbox[2][1]) # print('in row', bbox_text_dict.get(str(bbox)), bbox_text_dict.get(str(temp_bbox)), # row_len[0], row_len[1], bbox[0][1], bbox[2][1], # line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]), # 1/3 * min(bbox_h_len, temp_bbox_h_len)) else: # print('not in row', bbox_text_dict.get(str(bbox)), bbox_text_dict.get(str(temp_bbox)), # line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]), 1/3 * min(bbox_h_len, temp_bbox_h_len)) # print(bbox_text_dict.get(str(bbox)), temp_bbox[2][1] - bbox[0][1], 1/2 * min(bbox_h_len, temp_bbox_h_len), # bbox[2][1] - temp_bbox[0][1], 1/2 * min(bbox_h_len, temp_bbox_h_len), # line_overlap(temp_bbox[0][1], temp_bbox[2][1], bbox[0][1], bbox[2][1]), 1/3 * min(bbox_h_len, temp_bbox_h_len), # temp_bbox[0][1], temp_bbox[2][1], bbox[0][1], bbox[2][1], bbox_text_dict.get(str(temp_bbox))) cell_row.append(temp_row) temp_row = [bbox] row_len = [bbox[0][1], bbox[2][1]] if temp_row: cell_row.append(temp_row) print('row_cnt', row_cnt) for c in cell_row: c.sort(key=lambda x: x[0][0]) print('cell_row', [bbox_text_dict.get(str(x)) for x in c]) if row_cnt == 0: # 获取最大列数的列 temp_cell_row = copy.deepcopy(cell_row) temp_cell_row.sort(key=lambda x: len(x)) max_cell_row = temp_cell_row[-1] # 对行内上下堆叠的进行合并 max_cell_row.sort(key=lambda x: (x[0][0], x[0][1])) used_bbox_list = [] merge_bbox_list = [] for bbox1 in max_cell_row: temp_merge_bbox = [bbox1] if bbox1 in used_bbox_list: continue for bbox2 in max_cell_row: if bbox2 in used_bbox_list: continue if line_overlap(bbox1[0][0], bbox1[2][0], bbox2[0][0], bbox2[2][0]) >= 2/3 * min(bbox1[2][0]-bbox1[0][0], bbox2[2][0], bbox2[0][0]): temp_merge_bbox.append(bbox2) used_bbox_list += [bbox1, bbox2] # 选范围最大的bbox temp_merge_bbox.sort(key=lambda x: (x[2][0], -x[0][0])) merge_bbox_list.append(temp_merge_bbox[-1]) temp_cell_row[-1] = merge_bbox_list print('temp_cell_row', [bbox_text_dict.get(str(x)) for x in temp_cell_row[-1]]) # print('temp_cell_row', temp_cell_row[-1]) for c in temp_cell_row[-1]: cell_col_lines.append([c[0][0], c[2][0]]) cell_col_lines.sort(key=lambda x: x[0]) for c in cell_col_lines: add_col_lines.append([[int(c[1]), location[0][1]], [int(c[1]), location[1][1]]]) add_area_col_lines.append(add_col_lines) # # 循环所有行,若跨行 # cell_col_lines.sort(key=lambda x: x[0]) # cell_row.sort(key=lambda x: (x[0][0], x[0][1])) # print('sorted cell_col_lines', cell_col_lines) # for r in cell_row: # right_bbox = [] # for bbox in r: # for k in range(len(cell_col_lines)): # if k == 0: # min_w = -10000 # if len(cell_col_lines) <= 1: # max_w = cell_col_lines[k][1] # else: # max_w = cell_col_lines[k+1][0] # # + 1/4*(cell_col_lines[k+1][1]-cell_col_lines[k+1][0]) # elif k == len(cell_col_lines) - 1: # max_w = 10000 # if len(cell_col_lines) <= 1: # min_w = cell_col_lines[k-1][1] # else: # min_w = cell_col_lines[k][0] # else: # if len(cell_col_lines) <= 1: # min_w = -10000 # max_w = 10000 # else: # min_w = cell_col_lines[k-1][1] # max_w = cell_col_lines[k+1][0] # # + 1/4*(cell_col_lines[k+1][1]-cell_col_lines[k+1][0]) # # # 判断跨行 # if min_w <= bbox[0][0] <= bbox[2][0] <= max_w: # new_row.append(bbox) # right_bbox.append(bbox) # # else: # # print(min_w, bbox[0][0], bbox[2][0], max_w, # # bbox_text_dict.get(str(bbox))) # # # 有跨行,该行舍弃 # if len(right_bbox) != len(r): # for r1 in r: # if r1 in new_row: # new_row.remove(r1) # # # 单元格只有一个bbox # else: # new_row = inter # print('new_row', [bbox_text_dict.get(str(x)) for x in new_row]) # new_row.sort(key=lambda x: x[0][0]) # new_row_list.append(new_row) # row_cnt += 1 # new_col_list.append(new_row_list) # # new_row_list = [x for x in new_col_list[0]] # for col in new_col_list[1:]: # for j in range(len(col)): # new_row_list[j] += col[j] # # temp_new_row_list = [] # for r in new_row_list: # if r: # temp_new_row_list.append(r) # print('new_row_list', [bbox_text_dict.get(str(x)) for x in r]) # new_row_list = temp_new_row_list # area_row_list[i] = new_row_list return add_area_col_lines def judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict): new_area_col_lines = [] for i in range(len(table_location_list)): location = table_location_list[i] col_lines = area_col_lines[i] col_lines.sort(key=lambda x: x[0][0]) sub_bbox_list = [] for bbox in bbox_list: if location[0][1] <= bbox[0][1] <= location[1][1]: sub_bbox_list.append(bbox) # 判断线穿过bbox,那一行的线去掉 # temp_col_lines = [] # for c in col_lines: # for bbox in sub_bbox_list: # 判断新增线有没有压在黑色像素上或有没有在bbox之间 # temp_col_lines = [] # for c in col_lines: # if c[1][1] >= img.shape[1] or c[0][1] <= 0: # continue # # black_cnt = count_black(img[c[0][1]:c[1][1], c[0][0]:c[1][0]+1, :]) # if black_cnt > 10: # continue # temp_col_lines.append(c) # col_lines = temp_col_lines # 判断两线之间有没有完整bbox col_lines = [eval(y) for y in list(set([str(x) for x in col_lines]))] col_lines.sort(key=lambda x: x[0][0]) threshold = 5 if not col_lines: new_area_col_lines.append([]) continue temp_col_lines = [col_lines[0]] for j in range(1, len(col_lines)): last_col_w = temp_col_lines[-1][0][0] col_w = col_lines[j][0][0] for bbox in sub_bbox_list: if last_col_w-threshold <= bbox[0][0] <= bbox[2][0] <= col_w+threshold: temp_col_lines.append(col_lines[j]) break temp_col_lines.append(col_lines[-1]) col_lines = temp_col_lines # 判断线穿过bbox,向右移动到空位 for col in col_lines: cross_bbox_list = [] for bbox in sub_bbox_list: if bbox[0][0] < col[0][0] < bbox[2][0]: cross_bbox_list.append(bbox) if cross_bbox_list: # cross_bbox_list.sort(key=lambda x: x[2][0], reverse=True) # for bbox in cross_bbox_list: # line_now_w = col[0][0] # line_move_w = bbox[2][0] # find_flag = 0 # for bbox1 in sub_bbox_list: # if bbox1 in cross_bbox_list: # continue # if line_now_w <= bbox1[0][0] <= line_move_w: # find_flag = 1 # break # # if not find_flag: # col[0][0] = int(line_move_w) # col[1][0] = int(line_move_w) # break cross_bbox_list.sort(key=lambda x: x[2][0], reverse=True) line_move_w = cross_bbox_list[0][2][0] line_now_w = col[0][0] for bbox1 in sub_bbox_list: if bbox1 in cross_bbox_list: continue if line_now_w <= bbox1[0][0] <= line_move_w: line_now_w = line_move_w line_move_w = bbox1[2][0] col[0][0] = int(line_move_w) col[1][0] = int(line_move_w) # 将边框线加上 left_col = [[location[0][0], location[0][1]], [location[0][0], location[1][1]]] right_col = [[location[1][0], location[0][1]], [location[1][0], location[1][1]]] if left_col not in col_lines: col_lines.append(left_col) if right_col not in col_lines: col_lines.append(right_col) new_area_col_lines.append(col_lines) return new_area_col_lines def add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines): add_area_row_lines = [] for i in range(len(table_location_list)): row_list = area_row_list[i] col_list = area_col_list[i] location = table_location_list[i] row_lines = area_row_lines[i] add_row_lines = [] for row in row_list: col_cnt = 0 row.sort(key=lambda x: (x[0][0], x[0][1])) # # 只以第一列为标准 # first_col = col_list[0] # first_col.sort(key=lambda x: (x[0][1], x[0][0])) # inter = [j for j in row if j in first_col] # inter = [eval(x) for x in list(set([str(x) for x in inter]))] # inter.sort(key=lambda x: (x[0][1], x[0][0])) # 所有列都参与 for col in col_list: col.sort(key=lambda x: (x[0][1], x[0][0])) inter = [j for j in row if j in col] print('col', col_cnt, [bbox_text_dict.get(str(x)) for x in col], [bbox_text_dict.get(str(x)) for x in row]) inter = [eval(x) for x in list(set([str(x) for x in inter]))] inter.sort(key=lambda x: (x[0][1], x[0][0])) print('add_row_lines inter', [bbox_text_dict.get(str(x)) for x in inter]) if len(inter) > 0: # 单元格内分行 cell_row = [] temp_row = [inter[0]] row_len = [inter[0][0][1], inter[0][2][1]] for bbox in inter[1:]: temp_bbox = temp_row[0] bbox_h_len = bbox[2][1] - bbox[0][1] temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1] if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len): temp_row.append(bbox) row_len[0] = min(row_len[0], bbox[0][1]) row_len[1] = max(row_len[1], bbox[2][1]) else: cell_row.append(temp_row) temp_row = [bbox] row_len = [bbox[0][1], bbox[2][1]] if temp_row: cell_row.append(temp_row) print('col_cnt', col_cnt) for c in cell_row: c.sort(key=lambda x: x[0][0]) print('cell_row', [bbox_text_dict.get(str(x)) for x in c]) # 对有多行的 if len(cell_row) > 0: if len(cell_row) == 1: h = int(cell_row[0][0][2][1]) add_row_lines.append([[location[0][0], h], [location[1][0], h]]) for j in range(1, len(cell_row)): last_row = cell_row[j-1] row1 = cell_row[j] last_row.sort(key=lambda x: x[2][1]) row1.sort(key=lambda x: x[0][1]) find_flag = 0 for l in row_lines: if last_row[-1][2][1] <= l[0][1] <= row1[0][0][1]: find_flag = 1 break if not find_flag: h = int(last_row[-1][2][1]) if j == 1: last_row.sort(key=lambda x: x[0][1]) h += int(last_row[0][0][1] - location[0][1]) else: last_two_row = cell_row[j-2] last_two_row.sort(key=lambda x: x[2][1]) last_row.sort(key=lambda x: x[0][1]) h += int(last_row[0][0][1] - last_two_row[-1][2][1]) add_row_lines.append([[location[0][0], h], [location[1][0], h]]) col_cnt += 1 add_area_row_lines.append(add_row_lines) return add_area_row_lines def judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict): new_area_row_lines = [] for i in range(len(table_location_list)): location = table_location_list[i] row_lines = area_row_lines[i] sub_bbox_list = [] for bbox in bbox_list: if location[0][1] <= bbox[0][1] <= location[1][1]: sub_bbox_list.append(bbox) # 判断新增线有没有压在黑色像素上或有没有在bbox之间 # temp_row_lines = [] # for c in row_lines: # if c[1][1] >= img.shape[1] or c[0][1] <= 0: # continue # if c[1][1] > location[1][1] or c[1][1] < location[0][1]: # continue # black_cnt = count_black(img[c[0][1]:c[1][1]+1, c[0][0]:c[1][0], :]) # if black_cnt > 10: # continue # temp_row_lines.append(c) # row_lines = temp_row_lines # if not row_lines: # new_area_row_lines.append([]) # continue # 判断线穿过bbox,向下移动到空位 row_lines.sort(key=lambda x: x[0][1]) for row in row_lines: # 循环找出空位 while True: cross_bbox_list = [] for bbox in sub_bbox_list: # if (bbox[0][1]+bbox[2][1])/2 < row[0][1] < bbox[2][1]: if bbox[0][1] < row[0][1] < bbox[2][1]: cross_bbox_list.append(bbox) if cross_bbox_list: # print('row cross_bbox_list', [bbox_text_dict.get(str(x)) for x in cross_bbox_list]) cross_bbox_list.sort(key=lambda x: x[2][1], reverse=True) line_move_h = cross_bbox_list[0][2][1] move_bbox = cross_bbox_list[0] row[0][1] = int(line_move_h) row[1][1] = int(line_move_h) # print('move_bbox', bbox_text_dict.get(str(move_bbox))) else: break # 判断两线之间有没有完整bbox row_lines.sort(key=lambda x: x[0][1]) threshold = 5 temp_row_lines = [row_lines[0]] for j in range(1, len(row_lines)): last_row_w = temp_row_lines[-1][0][1] row_w = row_lines[j][0][1] for bbox in sub_bbox_list: if last_row_w-threshold <= bbox[0][1] <= bbox[2][1] <= row_w+threshold: temp_row_lines.append(row_lines[j]) break temp_row_lines.append(row_lines[-1]) row_lines = temp_row_lines # 将边框线加上 up_row = [[location[0][0], location[0][1]], [location[1][0], location[0][1]]] bottom_row = [[location[0][0], location[1][1]], [location[1][0], location[1][1]]] if up_row not in row_lines: row_lines.append(up_row) if bottom_row not in row_lines: row_lines.append(bottom_row) new_area_row_lines.append(row_lines) return new_area_row_lines def merge_lines(lines, axis=0, threshold=5): lines.sort(key=lambda x: x[0][1-axis]) used_lines = [] new_lines = copy.deepcopy(lines) for line1 in lines: if line1 in used_lines: continue current_line = line1 temp_merge_lines = [line1] for line2 in lines: if line2 in used_lines: continue if abs(current_line[0][1-axis] - line2[0][1-axis]) <= threshold: temp_merge_lines.append(line2) used_lines.append(line2) current_line = line2 # 取最右或最下线 temp_merge_lines.sort(key=lambda x: x[0][1-axis]) for l in temp_merge_lines: if l in new_lines: new_lines.remove(l) new_lines.append(temp_merge_lines[-1]) new_lines.sort(key=lambda x: x[0][1-axis]) return new_lines def merge_row_bbox_list(area_row_list): new_area_row_list = [] for row_list in area_row_list: new_row_list = copy.deepcopy(row_list) # 针对表头有多行被分在不同行 for i in range(1, len(row_list)-1): last_row = row_list[i-1] row = row_list[i] next_row = row_list[i+1] merge_cnt = 0 for last_b in last_row: find_flag1 = 0 for next_b in next_row: if line_overlap(last_b[0][0], last_b[2][0], next_b[0][0], next_b[2][0]) \ >= 0.8*min(last_b[2][0] - last_b[0][0], next_b[2][0] - next_b[0][0]): find_flag1 = 1 break find_flag2 = 0 if find_flag1: for b in row: if line_overlap(last_b[0][0], last_b[2][0], b[0][0], b[2][0]) \ >= 0.8*min(last_b[2][0] - last_b[0][0], b[2][0] - b[0][0]): find_flag2 = 1 break if find_flag1 and not find_flag2: merge_cnt += 1 if merge_cnt == len(last_row) and merge_cnt > 1: new_row_list = new_row_list[:i-1] + [last_row+row+next_row] + new_row_list[i+2:] new_area_row_list.append(new_row_list) return new_area_row_list def count_black(image_np, threshold=150): lower = np.array([0, 0, 0]) upper = np.array([threshold, threshold, threshold]) mask = cv2.inRange(image_np, lower, upper) cnt = np.sum(mask != 0) # print("count color ", cnt) return cnt def get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0): area_row_list = [] for i in range(len(table_location_list)): row_lines = area_row_lines[i] col_lines = area_col_lines[i] # 求线交点 cross_points = get_points_by_line(img, row_lines, col_lines) # 交点分行 cross_points.sort(key=lambda x: (x[1-axis], x[axis])) row_point_list = [] current_row = [cross_points[0]] for p in cross_points[1:]: if current_row[0][1-axis] == p[1-axis]: current_row.append(p) else: row_point_list.append(current_row) current_row = [p] if current_row: row_point_list.append(current_row) used_bbox_list = [] row_list = [] for j in range(1, len(row_point_list)): last_row = row_point_list[j-1] row = row_point_list[j] sub_row_list = [] for k in range(1, len(row)): last_p = last_row[k-1] p = row[k] for bbox in bbox_list: if bbox in used_bbox_list: continue bbox_h_center = (bbox[0][1-axis]+bbox[2][1-axis]) / 2 bbox_w_center = (bbox[0][axis]+bbox[2][axis]) / 2 if last_p[1-axis] <= bbox_h_center <= p[1-axis] and last_p[axis] <= bbox_w_center <= p[axis]: sub_row_list.append(bbox) used_bbox_list.append(bbox) row_list.append(sub_row_list) area_row_list.append(row_list) return area_row_list def get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list): area_table_bbox_list = [] area_table_cell_list = [] for i in range(len(table_location_list)): row_lines = area_row_lines[i] col_lines = area_col_lines[i] # 求线交点 cross_points = get_points_by_line(img, row_lines, col_lines) # 交点分行 cross_points.sort(key=lambda x: (x[1], x[0])) row_point_list = [] if not cross_points: area_table_bbox_list.append([]) area_table_cell_list.append([]) continue current_row = [cross_points[0]] for p in cross_points[1:]: if current_row[0][1] == p[1]: current_row.append(p) else: row_point_list.append(current_row) current_row = [p] if current_row: row_point_list.append(current_row) # bbox以表格格式排列 used_bbox_list = [] row_list = [] row_cell_list = [] for j in range(1, len(row_point_list)): last_row = row_point_list[j-1] row = row_point_list[j] col_list = [] col_cell_list = [] for k in range(1, len(row)): last_p = last_row[k-1] p = row[k] cell = [] for bbox in bbox_list: if bbox in used_bbox_list: continue bbox_h_center = (bbox[0][1]+bbox[2][1]) / 2 bbox_w_center = (bbox[0][0]+bbox[2][0]) / 2 if last_p[1] <= bbox_h_center <= p[1] and last_p[0] <= bbox_w_center <= p[0]: cell.append(bbox) used_bbox_list.append(bbox) col_list.append(cell) col_cell_list.append([last_p, p]) row_list.append(col_list) row_cell_list.append(col_cell_list) area_table_bbox_list.append(row_list) area_table_cell_list.append(row_cell_list) return area_table_bbox_list, area_table_cell_list def get_lines_from_img(img): gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 开操作提取水平线 kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 1)) binary1 = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel) # cv2.imshow('7,1', binary1) # 开操作提取垂直线 kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 7)) binary2 = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel) # cv2.imshow('1,7', binary2) # # cv2.imshow('table', binary1+binary2) # cv2.waitKey(0) return binary1, binary2 def get_bbox_by_img(row_img, col_img): # cv2.imshow('table', row_img+col_img) # 求线交点 point_img = np.bitwise_and(row_img, col_img) # cv2.imshow('point_img', point_img) # cv2.waitKey(0) # 识别黑白图中的白色交叉点,将横纵坐标取出 ys, xs = np.where(point_img > 200) cross_points = [] for i in range(len(xs)): cross_points.append((xs[i], ys[i])) cross_points.sort(key=lambda x: (x[0], x[1])) temp_cross_points = [] for p1 in cross_points: find_flag = 0 for p2 in temp_cross_points: if abs(p1[1] - p2[1]) <= 5 and abs(p1[0] - p2[0]) <= 5: find_flag = 1 break if not find_flag: temp_cross_points.append(p1) cross_points = temp_cross_points if not cross_points: return [], False print('cross_points', len(cross_points)) axis = 0 # 交点分行 cross_points.sort(key=lambda x: (x[1-axis], x[axis])) row_point_list = [] current_row = [cross_points[0]] for p in cross_points[1:]: if abs(current_row[0][1-axis] - p[1-axis]) <= 5: current_row.append(p) else: current_row.sort(key=lambda x: x[0]) row_point_list.append(current_row) current_row = [p] if current_row: current_row.sort(key=lambda x: x[0]) row_point_list.append(current_row) row_len = len(row_point_list[0]) for row in row_point_list: # print('row_point_list row', len(row), row) if row_len != len(row): return [], False row_list = [] standard_flag = True for j in range(1, len(row_point_list)): last_row = row_point_list[j-1] row = row_point_list[j] sub_row = [] for k in range(1, len(row)): if k-1 >= len(last_row): # print(len(last_row), len(row)) standard_flag = False break last_p = last_row[k-1] p = row[k] bbox = [last_p, p] sub_row.append(bbox) row_list.append(sub_row) if not row_list: return [], False row_len = len(row_list[0]) for row in row_list: if len(row) != row_len: standard_flag = False break print('standard_flag', standard_flag) if standard_flag: new_img = np.zeros((row_img.shape[0], row_img.shape[1], 3), dtype=np.uint8) # for row in row_list: # for b in row: # print('b', b) # cv2.rectangle(new_img, [int(b[0][0]), int(b[0][1])], [int(b[1][0]), int(b[1][1])], # (0, 0, 255), 1) # cv2.imshow('new_img', new_img) # cv2.waitKey(0) return row_list, standard_flag def get_points_by_line(img, row_lines, col_lines): row_img = np.zeros_like(img[:, :, 0], dtype=np.uint8) col_img = np.zeros_like(img[:, :, 0], dtype=np.uint8) for r in row_lines: cv2.line(row_img, r[0], r[1], (255, 255, 255), 1) for c in col_lines: cv2.line(col_img, c[0], c[1], (255, 255, 255), 1) point_img = np.bitwise_and(row_img, col_img) # 识别黑白图中的白色交叉点,将横纵坐标取出 ys, xs = np.where(point_img > 0) points = [] for i in range(len(xs)): points.append((xs[i], ys[i])) points.sort(key=lambda x: (x[0], x[1])) return points def merge_text_and_table(text_bbox_list, table_row_list): used_bbox_list = [] row_list = [] for row in table_row_list: sub_row_list = [] row.sort(key=lambda x: x[0][0]) for bbox1 in row: sub_bbox_list = [] for bbox2 in text_bbox_list: if bbox2 in used_bbox_list: continue bbox_h_center = (bbox2[0][1]+bbox2[2][1]) / 2 bbox_w_center = (bbox2[0][0]+bbox2[2][0]) / 2 if bbox1[0][1] <= bbox_h_center <= bbox1[1][1] and bbox1[0][0] <= bbox_w_center <= bbox1[1][0]: sub_bbox_list.append(bbox2) used_bbox_list.append(bbox2) sub_row_list.append(sub_bbox_list) row_list.append(sub_row_list) return row_list def shrink_bbox(img, bbox_list): def return_first_black_index(image_np): lower = np.array([0, 0, 0]) upper = np.array([150, 150, 150]) mask = cv2.inRange(image_np, lower, upper) black_index_list = np.where(mask != 0) return black_index_list new_bbox_list = [] for bbox in bbox_list: img_bbox = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :] if 0 in img_bbox.shape: new_bbox_list.append(bbox) continue # 左右上下开始扫描,碰到黑像素即停 index_list = return_first_black_index(img_bbox[:, :, :]) if index_list[0].size == 0 or index_list[1].size == 0: new_bbox_list.append(bbox) continue min_h = index_list[0][0] max_h = index_list[0][-1] img_bbox1 = np.swapaxes(img_bbox, 0, 1) index_list = return_first_black_index(img_bbox1[:, :, :]) if index_list[0].size == 0 or index_list[1].size == 0: new_bbox_list.append(bbox) continue min_w = index_list[0][0] max_w = index_list[0][-1] real_min_w = bbox[0][0] + min_w real_max_w = bbox[0][0] + max_w real_min_h = bbox[0][1] + min_h real_max_h = bbox[0][1] + max_h new_bbox = [[real_min_w, real_min_h], [real_min_w, real_max_h], [real_max_w, real_max_h], [real_max_w, real_min_h]] new_bbox_list.append(new_bbox) # cv2.imshow('img', img_bbox) # cv2.imshow('shrink', img[int(new_bbox[0][1]):int(new_bbox[2][1]), int(new_bbox[0][0]):int(new_bbox[2][0]), :]) # cv2.waitKey(0) return new_bbox_list def affinity_propagation(data_list): """ 聚类:近邻传播 :return: """ data_np = np.array(data_list) random_state = 170 model = AffinityPropagation(damping=0.5, convergence_iter=15, random_state=random_state).fit(data_np) # cluster_centers_indices = model.cluster_centers_indices_ y_pred = model.labels_ if y_pred[0] == -1: print('ap dp0.5 ci50') model = AffinityPropagation(convergence_iter=50, random_state=random_state).fit(data_np) y_pred = model.labels_ if y_pred[0] == -1: print('ap dp0.7 ci15') model = AffinityPropagation(damping=0.7, convergence_iter=15, random_state=random_state).fit(data_np) y_pred = model.labels_ if y_pred[0] == -1: print('ap dp0.7 ci50') model = AffinityPropagation(damping=0.7, convergence_iter=50, random_state=random_state).fit(data_np) y_pred = model.labels_ if y_pred[0] == -1: print('all -1') y_pred = np.zeros(y_pred.shape[0]) y_pred = y_pred.tolist() return y_pred def dbscan(data_list): """ 聚类:dbscan :return: """ data_np = np.array(data_list) model = DBSCAN(eps=3, min_samples=2).fit(data_np) y_pred = model.labels_ y_pred = y_pred.tolist() return y_pred def test_ocr_model(img_path): with open(img_path, "rb") as f: file_bytes = f.read() file_base64 = base64.b64encode(file_bytes) file_json = {"data": file_base64, "md5": 0} _url = "http://192.168.2.103:17000/ocr" # _url = "http://127.0.0.1:17000/ocr" result = json.loads(request_post(_url, file_json)) return result def test_cho_model(text): # text = "巧克力" text = [x for x in text] data_json = {"data": json.dumps(text)} _url = "http://192.168.2.103:17058/cho" result = json.loads(request_post(_url, data_json)) if result.get("success"): decode_list = result.get("data") print("char_list", text) print("decode_list", decode_list) return decode_list else: print("failed!") if __name__ == '__main__': get_table_new() # _l = [[18, 0], [0, 0], [14, 0], [0, 0], [12, 0], [0, 0], [14, 0], [2, 0], [15, 0], [0, 0]] # # _l = [[27, 0], [26, 0], [17, 0]] # print(affinity_propagation(_l)) # print(dbscan(_l)) # _img = cv2.imread(r'C:\Users\Administrator\Desktop\111.jpg') # shrink_bbox(_img, [[[0, 0], [0, 0], [_img.shape[1], _img.shape[0]], [_img.shape[1], _img.shape[0]]]])