# NOTE(review): this file reached review collapsed onto a handful of physical
# lines, with two regions of text missing (marked "SOURCE IS DAMAGED HERE"
# below). Line breaks and indentation were reconstructed from the code's
# syntax; the damaged regions are left as they survived and must be recovered
# from version control -- the module does not parse until they are.
from pdfminer.layout import *


class LineTable():
    """Reconstruct ruled tables from drawn line segments.

    Line segments are intersected pairwise (`cross_point`), the crosspoints
    are grouped into clusters connected through shared lines, each cluster is
    turned into cell rectangles (`crosspoint2rect`, pdfminer `LTRect`), and the
    cells are arranged into rows with row/column spans and filled with text
    (`rect2table`, `unionTable`).
    """

    def recognize_table(self,list_textbox,list_line):
        """Detect tables formed by `list_line` and place `list_textbox` into their cells.

        Args:
            list_textbox: text objects; `rect2table` reads their `.bbox` and
                calls `.get_text()`.
            list_line: line objects handed to `recognize_crosspoints`, which
                may append synthetic `LTLine` border lines to this list.

        Returns:
            A 3-tuple ``(list_tables, in_objs, list_l_rect)``: the truthy
            results of `rect2table` (one attempt per crosspoint cluster), the
            set of text boxes that were assigned to some cell, and the list of
            cell-rect lists, one per cluster.

        Side effects:
            Stores `list_line` and the computed crosspoints on
            `self.list_line` / `self.list_crosspoints`.
        """
        self.list_line = list_line
        self.list_crosspoints = self.recognize_crosspoints(list_line)
        # Cluster the crosspoints: each crosspoint starts as its own cluster,
        # then clusters whose "lines" key sets intersect are merged, pass after
        # pass, until a full pass performs no merge.
        cluster_crosspoints = []
        for _point in self.list_crosspoints:
            cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
        while 1:
            _find = False  # did this pass merge anything?
            new_cluster_crosspoints = []
            for l_point in cluster_crosspoints:
                _flag = False  # was l_point merged into an already-collected cluster?
                for l_n_point in new_cluster_crosspoints:
                    line1 = l_point.get("lines")
                    line2 = l_n_point.get("lines")
                    if len(line1&line2)>0:
                        # NOTE(review): there is no `break`, so a cluster sharing
                        # lines with several collected clusters is appended to each
                        # of them; once the next pass merges those, its points
                        # appear more than once in the final cluster.
                        _find = True
                        _flag = True
                        l_n_point["lines"] = line1.union(line2)
                        l_n_point["points"].extend(l_point["points"])
                if not _flag:
                    new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
            cluster_crosspoints = new_cluster_crosspoints
            if not _find:
                break
        # print(len(cluster_crosspoints))
        list_l_rect = []
        for table_crosspoint in cluster_crosspoints:
            list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
            list_l_rect.append(list_rect)
        in_objs = set()
        list_tables = []
        for l_rect in list_l_rect:
            _ta = self.rect2table(list_textbox,l_rect,in_objs)
            if _ta:
                list_tables.append(_ta)
        return list_tables,in_objs,list_l_rect

    def recognize_table_by_rect(self,list_textbox,list_rect,margin=2):
        """Detect tables from ready-made rectangles (mostly lost; see note below).

        Only the opening survives. It skips rects whose height
        ``bbox[3]-bbox[1]`` is below 10 or whose width ``abs(bbox[2]-bbox[0])``
        is below 5. It then starts comparing each remaining rect,
        coordinate by coordinate, against the rects already kept in
        `list_rect_tmp`, apparently to drop near-duplicates. `dump_margin`
        looks like the per-coordinate tolerance -- TODO confirm once the lost
        code is recovered. `list_textbox` and `margin` are not used in the
        surviving part.
        """
        dump_margin = 5
        list_rect_tmp = []
        # de-duplicate the rects
        for _rect in list_rect:
            # skip rects too short or too narrow to be table cells
            if (_rect.bbox[3]-_rect.bbox[1]<10) or (abs(_rect.bbox[2]-_rect.bbox[0])<5):
                continue
            _find = False
            for _tmp in list_rect_tmp:
                for i in range(4):
                    # NOTE(review): SOURCE IS DAMAGED HERE. Text between a "<"
                    # and the next ">" appears to have been stripped (as if it
                    # were an HTML tag), presumably starting at something like
                    # `< dump_margin`. The following were lost:
                    #   - the rest of recognize_table_by_rect (including its
                    #     return);
                    #   - the `def recognize_crosspoints(...)` line;
                    #   - the start of recognize_crosspoints' body.
                    # The surviving text resumes inside recognize_crosspoints'
                    # cluster-merge loop, right after `if len(line1&line2)>`.
                    # Recover both methods from version control rather than
                    # guessing.
                    if abs(_rect.bbox[i]-_tmp.bbox[i])0:
                        _find = True
                        _flag = True
                        l_n_point["lines"] = line1.union(line2)
                        l_n_point["points"].extend(l_point["points"])
                if not _flag:
                    new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
            cluster_crosspoints = new_cluster_crosspoints
            if not _find:
                break
        # ---- Surviving tail of recognize_crosspoints ----
        # The method name comes from the calls in recognize_table and
        # recognize_rect; its parameter is presumably the `list_line` used below.
        #
        # For each crosspoint cluster: take the lines that produced its points
        # ("p_lines") and compute their overall extent, inset by 2 on every
        # side. On each side where the outermost point lies more than 10 away
        # from that frame, append a synthetic LTLine along the frame edge, so
        # the table gets a closing border. This appends to `list_line` in
        # place, i.e. to the caller's list.
        #
        # NOTE(review): cross_point (below) stores "p_lines" as the raw bbox
        # values it was given, and the crosspoints are plain dicts. Yet this
        # code reads `.bbox` on both the lines and the points. Confirm what
        # the lost code above actually builds before relying on this path.
        for list_cp in cluster_crosspoints:
            points = list_cp.get("points")
            l_lines = []
            for p in points:
                l_lines.extend(p.get("p_lines"))
            l_lines = list(set(l_lines))
            l_lines.sort(key=lambda x:x.bbox[0])
            min_x = l_lines[0].bbox[0]+2
            l_lines.sort(key=lambda x:x.bbox[1])
            min_y = l_lines[0].bbox[1]+2
            l_lines.sort(key=lambda x:x.bbox[2])
            max_x = l_lines[-1].bbox[2]-2
            l_lines.sort(key=lambda x:x.bbox[3])
            max_y = l_lines[-1].bbox[3]-2
            # add a border line on any side the cluster's points do not reach
            points.sort(key=lambda x:x.bbox[0])
            if abs(min_x-points[0].bbox[0])>10:
                list_line.append(LTLine(1,[(min_x,min_y),(min_x,max_y)]))
            points.sort(key=lambda x:x.bbox[1])
            if abs(min_y-points[0].bbox[1])>10:
                list_line.append(LTLine(1,[(min_x,min_y),(max_x,min_y)]))
            points.sort(key=lambda x:x.bbox[2])
            if abs(max_x-points[-1].bbox[2])>10:
                list_line.append(LTLine(1,[(max_x,min_y),(max_x,max_y)]))
            points.sort(key=lambda x:x.bbox[3])
            if abs(max_y-points[-1].bbox[3])>10:
                list_line.append(LTLine(1,[(min_x,max_y),(max_x,max_y)]))
        # Intersect every ordered pair of lines, including each line with
        # itself. A line never crosses itself in cross_point, because its two
        # slopes are equal or both undefined.
        list_crosspoints = []
        for _i in range(len(list_line)):
            for _j in range(len(list_line)):
                line1 = list_line[_i].__dict__.get("bbox")
                line2 = list_line[_j].__dict__.get("bbox")
                exists,point = self.cross_point(line1,line2)
                if exists:
                    list_crosspoints.append(point)
        # plt.figure()
        # for _line in list_line:
        #     x0,y0,x1,y1 = _line.__dict__.get("bbox")
        #     plt.plot([x0,x1],[y0,y1])
        # for _line in list_line:
        #     x0,y0,x1,y1 = _line.bbox
        #     plt.plot([x0,x1],[y0,y1])
        # for point in list_crosspoints:
        #     plt.scatter(point.get("point")[0],point.get("point")[1])
        # plt.show()
        # print(list_crosspoints)
        # print("points num",len(list_crosspoints))
        return list_crosspoints

    def recognize_rect(self,_page):
        """Detect cell rectangles from the `LTLine` objects in `_page._objs`.

        Uses the same crosspoint clustering as `recognize_table`, but stops
        after `crosspoint2rect`. It neither places text nor stores state on
        `self`.

        Returns:
            A list holding one list of `LTRect` cells per crosspoint cluster.
        """
        list_line = []
        for _obj in _page._objs:
            if isinstance(_obj,(LTLine)):
                list_line.append(_obj)
        list_crosspoints = self.recognize_crosspoints(list_line)
        # Cluster crosspoints that are connected through shared lines
        # (same merge loop as in recognize_table).
        cluster_crosspoints = []
        for _point in list_crosspoints:
            cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
        while 1:
            _find = False
            new_cluster_crosspoints = []
            for l_point in cluster_crosspoints:
                _flag = False
                for l_n_point in new_cluster_crosspoints:
                    line1 = l_point.get("lines")
                    line2 = l_n_point.get("lines")
                    if len(line1&line2)>0:
                        _find = True
                        _flag = True
                        l_n_point["lines"] = line1.union(line2)
                        l_n_point["points"].extend(l_point["points"])
                if not _flag:
                    new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
            cluster_crosspoints = new_cluster_crosspoints
            if not _find:
                break
        # print(len(cluster_crosspoints))
        list_l_rect = []
        for table_crosspoint in cluster_crosspoints:
            list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
            list_l_rect.append(list_rect)
        return list_l_rect

    def crosspoint2rect(self,list_crosspoint,margin=4):
        """Turn one cluster of crosspoints into cell rectangles.

        Indexing: every crosspoint is filed under both of its line keys. A
        line whose points spread more in x than in y is tagged "row" and its
        points are sorted by x. Otherwise it is tagged "column" and its points
        are sorted by y.

        Cell tracing: a crosspoint starts a cell when its "buttom" and "right"
        values are both >= `margin`. These are the distances `cross_point`
        computes to the larger-y / larger-x line ends. From the start point:
          1. walk along its row line to the first point with a larger x that
             also has "buttom" >= margin;
          2. from there, walk along that point's column line to the first
             point with a larger y that has "left" >= margin.
        The cell is an `LTRect` spanning the start point and that final point.

        Args:
            list_crosspoint: crosspoint dicts as produced by `cross_point`.
            margin: minimum line extension required at the corners.

        Returns:
            A list of `LTRect`.
        """
        dict_line_points = {}
        for _point in list_crosspoint:
            lines = list(_point.get("lines"))
            for _line in lines:
                if _line not in dict_line_points:
                    dict_line_points[_line] = {"direct":None,"points":[]}
                dict_line_points[_line]["points"].append(_point)
        # sort each line's points along its dominant direction
        for k,v in dict_line_points.items():
            list_x = []
            list_y = []
            for _p in v["points"]:
                list_x.append(_p.get("point")[0])
                list_y.append(_p.get("point")[1])
            if max(list_x)-min(list_x)>max(list_y)-min(list_y):
                v.get("points").sort(key=lambda x:x.get("point")[0])
                v["direct"] = "row"
            else:
                v.get("points").sort(key=lambda x:x.get("point")[1])
                v["direct"] = "column"
        list_rect = []
        for _point in list_crosspoint:
            if _point["buttom"]>=margin and _point["right"]>=margin:
                # Pick this point's row line. This assumes each crosspoint joins
                # exactly one row line and one column line.
                # NOTE(review): lines[1] raises IndexError when both line keys
                # are equal (two lines whose bboxes format to the same
                # "%.2f" key).
                lines = list(_point.get("lines"))
                _line = lines[0]
                if dict_line_points[_line]["direct"]=="column":
                    _line = lines[1]
                next_point = None
                for p1 in dict_line_points[_line]["points"]:
                    if p1["buttom"]>=margin and p1["point"][0]>_point["point"][0]:
                        next_point = p1
                        break
                if not next_point:
                    continue
                # from the next corner, follow its column line
                lines = list(next_point.get("lines"))
                _line = lines[0]
                if dict_line_points[_line]["direct"]=="row":
                    _line = lines[1]
                final_point = None
                for p1 in dict_line_points[_line]["points"]:
                    if p1["left"]>=margin and p1["point"][1]>next_point["point"][1]:
                        final_point = p1
                        break
                if not final_point:
                    continue
                _r = LTRect(1,(_point["point"][0],_point["point"][1],final_point["point"][0],final_point["point"][1]))
                list_rect.append(_r)
        return list_rect

    def cross_point(self,line1, line2,segment=True,margin=2):
        """Intersect two lines given as ``(x1, y1, x2, y2)`` values.

        Each line is reduced to slope/intercept form; a vertical line gets
        slope ``None``. Parallel lines (equal slopes, or both vertical) have
        no crosspoint. With `segment` true, the intersection must also lie
        within both segments' coordinate ranges, each widened by `margin`.

        Returns:
            ``(exists, info)``, where `info` is a dict with these keys:
              - "point": ``[x, y]`` (``[0, 0]`` when nothing was computed).
              - "left" / "right" / "top" / "buttom" (sic; the key name is used
                throughout the class): distances from the point to
                ``min(x1, x3)``, ``max(x2, x4)``, ``min(y1, y3)`` and
                ``max(y2, y4)``. They stay 0 unless the segment test passed.
                The names assume x1 <= x2 and y1 <= y2 (a normalized bbox) --
                TODO confirm.
              - "lines": set of the two lines' ``"%.2f-%.2f-%.2f-%.2f"`` keys.
              - "p_lines": ``[line1, line2]`` as passed in.
        """
        point_is_exist = False
        x = y = 0
        x1,y1,x2,y2 = line1
        x3,y3,x4,y4 = line2
        if (x2 - x1) == 0:
            k1 = None
            b1 = 0
        else:
            k1 = (y2 - y1) * 1.0 / (x2 - x1)  # compute k1; the points are integers, so convert to float
            b1 = y1 * 1.0 - x1 * k1 * 1.0  # the int-to-float conversion is the key point
        if (x4 - x3) == 0:  # line 2 is vertical (slope undefined)
            k2 = None
            b2 = 0
        else:
            k2 = (y4 - y3) * 1.0 / (x4 - x3)  # slope exists
            b2 = y3 * 1.0 - x3 * k2 * 1.0
        if k1 is None:
            if not k2 is None:
                x = x1
                y = k2 * x1 + b2
                point_is_exist = True
        elif k2 is None:
            # NOTE(review): x/y are computed here, but point_is_exist is never
            # set, so a non-vertical line1 crossing a vertical line2 is
            # reported as "no crosspoint". recognize_crosspoints tries every
            # ordered pair, so such a pair is still found once, with the
            # vertical line as line1. Setting the flag here would make that
            # loop emit duplicates.
            x = x3
            y = k1 * x3 + b1
        elif not k2 == k1:
            x = (b2 - b1) * 1.0 / (k1 - k2)
            y = k1 * x * 1.0 + b1 * 1.0
            point_is_exist = True
        left = 0
        right = 0
        top = 0
        buttom = 0
        if point_is_exist:
            if segment:
                # the point must lie on both segments (within margin)
                if x>=(min(x1,x2)-margin) and x<=(max(x1,x2)+margin) and y>=(min(y1,y2)-margin) and y<=(max(y1,y2)+margin):
                    if x>=(min(x3,x4)-margin) and x<=(max(x3,x4)+margin) and y>=(min(y3,y4)-margin) and y<=(max(y3,y4)+margin):
                        point_is_exist = True
                        left = abs(min(x1,x3)-x)
                        right = abs(max(x2,x4)-x)
                        top = abs(min(y1,y3)-y)
                        buttom = abs(max(y2,y4)-y)
                    else:
                        point_is_exist = False
                else:
                    point_is_exist = False
        line1_key = "%.2f-%.2f-%.2f-%.2f"%(x1,y1,x2,y2)
        line2_key = "%.2f-%.2f-%.2f-%.2f"%(x3,y3,x4,y4)
        return point_is_exist, {"point":[x, y],"left":left,"right":right,"top":top,"buttom":buttom,"lines":set([line1_key,line2_key]),"p_lines":[line1,line2]}

    def unionTable(self,list_table,fixspan=True,margin=2):
        """Merge the cells of several tables into one table.

        Each table in `list_table` is a list of rows, and each row is a list
        of cell dicts with a "bbox" and optionally "rect" and "text".
        Processing steps:
          1. Cells are de-duplicated by object identity.
          2. Cells are grouped into rows by ``bbox[3]`` (within 2 units). Rows
             are ordered by descending ``bbox[3]``; cells within a row by
             ``bbox[0]``.
          3. Spans are counted against the set of all distinct cell x / y
             coordinates, via `getspan` with `margin`. Unlike in rect2table,
             near-equal coordinates are not merged here.

        Args:
            list_table: the tables to merge, as described above.
            fixspan: if true, a cell spanning N columns/rows is repeated
                (the same dict) N times and its span is reset to 1.
            margin: tolerance passed to `getspan`.

        Returns:
            ``{"bbox": table_bbox, "table": rows}``, or None when there are
            no cells.
        """
        set_x = set()
        set_y = set()
        list_cell = []
        for _t in list_table:
            for _line in _t:
                list_cell.extend(_line)
        clusters_rects = []
        # cluster into rows by y1 (bbox[3]); first drop cells seen more than once
        set_id = set()
        list_cell_dump = []
        for _cell in list_cell:
            _id = id(_cell)
            if _id in set_id:
                continue
            set_id.add(_id)
            list_cell_dump.append(_cell)
        list_cell = list_cell_dump
        list_cell.sort(key=lambda x:x.get("bbox")[3])
        for _rect in list_cell:
            _y0 = _rect.get("bbox")[3]
            _find = False
            for l_cr in clusters_rects:
                if abs(l_cr[0].get("bbox")[3]-_y0)<2:
                    _find = True
                    l_cr.append(_rect)
                    break
            if not _find:
                clusters_rects.append([_rect])
        clusters_rects.sort(key=lambda x:x[0].get("bbox")[3],reverse=True)
        for l_cr in clusters_rects:
            l_cr.sort(key=lambda x:x.get("bbox")[0])
        # NOTE(review): leftover debug output -- prints every row's cell count to stdout.
        for l_r in clusters_rects:
            print(len(l_r))
        for _line in clusters_rects:
            for _rect in _line:
                (x0,y0,x1,y1) = _rect.get("bbox")
                set_x.add(x0)
                set_x.add(x1)
                set_y.add(y0)
                set_y.add(y1)
        if len(set_x)==0 or len(set_y)==0:
            return
        list_x = list(set_x)
        list_y = list(set_y)
        list_x.sort(key=lambda x:x)
        list_y.sort(key=lambda x:x,reverse=True)
        _table = []
        for _line in clusters_rects:
            table_line = []
            for _rect in _line:
                (x0,y0,x1,y1) = _rect.get("bbox")
                _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect.get("rect"),"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":_rect.get("text","")}
                table_line.append(_cell)
            _table.append(table_line)
        # print("=====================>>")
        # for _line in _table:
        #     for _cell in _line:
        #         print(_cell,end="\t")
        #     print("\n")
        # print("=====================>>")
        # print(_table)
        if fixspan:
            # NOTE(review): range(len(_line)) is evaluated before the inserts,
            # so cells pushed past the row's original length are never
            # visited. A later multi-column cell in the same row keeps its
            # span unexpanded.
            for _line in _table:
                for c_i in range(len(_line)):
                    _cell = _line[c_i]
                    if _cell.get("columnspan")>1:
                        _cospan = _cell.get("columnspan")
                        _cell["columnspan"] = 1
                        for i in range(1,_cospan):
                            _line.insert(c_i,_cell)
            for l_i in range(len(_table)):
                _line = _table[l_i]
                for c_i in range(len(_line)):
                    _cell = _line[c_i]
                    if _cell.get("rowspan")>1:
                        _rospan = _cell.get("rowspan")
                        _cell["rowspan"] = 1
                        for i in range(1,_rospan):
                            # NOTE(review): there is no bounds check on l_i+i,
                            # so a rowspan running past the last row raises
                            # IndexError.
                            _table[l_i+i].insert(c_i,_cell)
        # Built from the first cell's (x0, y0) and the last cell's (x1, y1).
        # NOTE(review): rows are ordered by descending bbox[3]. If y grows
        # upward, this takes the top row's lower edge and the bottom row's
        # upper edge -- confirm this is intended.
        table_bbox = (_table[0][0].get("bbox")[0],_table[0][0].get("bbox")[1],_table[-1][-1].get("bbox")[2],_table[-1][-1].get("bbox")[3])
        ta = {"bbox":table_bbox,"table":_table}
        return ta

    def rect2table(self,list_textbox,list_rect,in_objs,margin=0.2,fixspan=True):
        """Arrange cell rects into a table and fill the cells with text.

        Rows and grid:
          - Rects are grouped into rows by ``bbox[3]`` (within 2 units). Rows
            are ordered by descending ``bbox[3]``; cells within a row by
            ``bbox[0]``.
          - The distinct x / y coordinates of all cells form the grid. A
            coordinate closer than 2 to its sorted predecessor is dropped.
          - Each cell's rowspan/columnspan is the number of grid intervals it
            covers (`getspan` with `margin`).

        Text: each text box is appended to the text of the first cell, in
        row-major order, that `inbox` accepts, and the box is added to
        `in_objs`.

        Args:
            list_textbox: text objects with `.bbox` and `.get_text()`;
                sorted in place.
            list_rect: cell objects with `.bbox` (e.g. LTRect); sorted in place.
            in_objs: set that collects the placed text boxes (mutated).
            margin: tolerance for `getspan`.
            fixspan: if true, expand spans by repeating the same cell dict.

        Returns:
            None when `list_rect` is empty. The end of this method, including
            its normal return value, was lost (see the damage note at the
            end). `recognize_table` keeps only truthy results.
        """
        _table = []
        set_x = set()
        set_y = set()
        clusters_rects = []
        # cluster into rows by y1 (bbox[3])
        list_rect.sort(key=lambda x:x.bbox[3])
        for _rect in list_rect:
            _y0 = _rect.bbox[3]
            _find = False
            for l_cr in clusters_rects:
                if abs(l_cr[0].bbox[3]-_y0)<2:
                    _find = True
                    l_cr.append(_rect)
                    break
            if not _find:
                clusters_rects.append([_rect])
        clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=True)
        for l_cr in clusters_rects:
            l_cr.sort(key=lambda x:x.bbox[0])
        # calculate spans
        for _line in clusters_rects:
            for _rect in _line:
                (x0,y0,x1,y1) = _rect.bbox
                set_x.add(x0)
                set_x.add(x1)
                set_y.add(y0)
                set_y.add(y1)
        if len(set_x)==0 or len(set_y)==0:
            return
        list_x = list(set_x)
        list_y = list(set_y)
        list_x.sort(key=lambda x:x)
        list_y.sort(key=lambda x:x,reverse=True)
        # drop x coordinates within 2 of their predecessor (walking from the end)
        pop_x = []
        for i in range(len(list_x)-1):
            _i = len(list_x)-i-1
            l_i = _i-1
            if abs(list_x[_i]-list_x[l_i])<2:
                pop_x.append(_i)
        pop_x.sort(key=lambda x:x,reverse=True)
        for _x in pop_x:
            list_x.pop(_x)
        # NOTE(review): in the collapsed source the next text reads
        # "# pop_x = []". It cannot be told whether the reset is commented
        # out (kept as a comment here) or was a bare "#" line followed by
        # `pop_x = []`. If it is commented out, the indices collected for
        # list_x are also popped from list_y below.
        # pop_x = []
        for i in range(len(list_y)-1):
            _i = len(list_y)-i-1
            l_i = _i-1
            if abs(list_y[_i]-list_y[l_i])<2:
                pop_x.append(_i)
        pop_x.sort(key=lambda x:x,reverse=True)
        for _x in pop_x:
            list_y.pop(_x)
        # NOTE(review): leftover debug output to stdout.
        print(list_x)
        print(list_y)
        for _line in clusters_rects:
            table_line = []
            for _rect in _line:
                (x0,y0,x1,y1) = _rect.bbox
                _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect,"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":""}
                table_line.append(_cell)
            _table.append(table_line)
        # Order text boxes by descending bbox[3]; the stable second sort keeps
        # the bbox[0] order among ties.
        list_textbox.sort(key=lambda x:x.bbox[0])
        list_textbox.sort(key=lambda x:x.bbox[3],reverse=True)
        for textbox in list_textbox:
            (x0,y0,x1,y1) = textbox.bbox
            _text = textbox.get_text()
            _find = False
            for table_line in _table:
                for _cell in table_line:
                    if self.inbox(textbox.bbox,_cell["bbox"]):
                        _cell["text"]+= _text
                        in_objs.add(textbox)
                        _find = True
                        break
                if _find:
                    break
        if fixspan:
            # NOTE(review): same pre-computed range(len(_line)) issue as in
            # unionTable: cells shifted past the original row length are not
            # expanded.
            for _line in _table:
                for c_i in range(len(_line)):
                    _cell = _line[c_i]
                    if _cell.get("columnspan")>1:
                        _cospan = _cell.get("columnspan")
                        _cell["columnspan"] = 1
                        for i in range(1,_cospan):
                            _line.insert(c_i,_cell)
            for l_i in range(len(_table)):
                _line = _table[l_i]
                for c_i in range(len(_line)):
                    _cell = _line[c_i]
                    if _cell.get("rowspan")>1:
                        _rospan = _cell.get("rowspan")
                        _cell["rowspan"] = 1
                        for i in range(1,_rospan):
                            # NOTE(review): SOURCE IS DAMAGED HERE. Text from a
                            # "<" after `l_i+i` up to the next ">" appears to
                            # have been stripped. The following were lost:
                            #   - presumably a bounds check such as
                            #     `l_i+i < len(_table)`, then the insert and
                            #     the rest of rect2table, including its return;
                            #   - the `def inbox(...)` line;
                            #   - the start of a commented-out containment test.
                            # The fragment below is the tail of that
                            # commented-out test; the lines after it are the
                            # body of `inbox`.
                            if l_i+i=bbox0[2] and bbox_g[3]>=bbox0[3]:
        # ---- Surviving tail of inbox ----
        # rect2table calls it as self.inbox(textbox.bbox, cell_bbox); the
        # parameter names bbox0 / bbox_g come from the body below.
        # Returns 1 when getIOU's overlap ratio (intersection area over the
        # smaller box's area) exceeds 0.5, else 0.
        # return 1
        if self.getIOU(bbox0,bbox_g)>0.5:
            return 1
        return 0

    def getIOU(self,bbox0,bbox1):
        """Return the overlap ratio of two ``(x0, y0, x1, y1)`` boxes.

        Despite the name, this is not intersection-over-union. It is the
        intersection area divided by the smaller of the two box areas, so a
        box lying fully inside the other scores 1.0. Returns 0 unless the
        boxes overlap with positive extent on both axes. Assumes x0 <= x1 and
        y0 <= y1.
        """
        # combined extent minus the summed widths: equals -(overlap width)
        # when the x-ranges overlap, and is >= 0 otherwise
        width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
        # the same for the y-ranges
        height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
        if width<0 and height<0:
            return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))
        return 0

    def getspan(self,_list,x0,x1,margin):
        """Return how many grid intervals the range [x0, x1] covers.

        Counts the entries of `_list` that lie within
        ``[min(x0, x1) - margin, max(x0, x1) + margin]`` and returns the
        count minus one. The result is -1 when no entry falls inside.
        """
        _count = 0
        (x0,x1) = (min(x0,x1),max(x0,x1))
        for _x in _list:
            if _x>=(x0-margin) and _x<=(x1+margin):
                _count += 1
        return _count-1