123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539 |
- from pdfminer.layout import *
class LineTable:
    """Recognise ruled tables from pdfminer layout objects.

    Two strategies are provided:

    * ``recognize_table`` / ``recognize_rect`` derive cell rectangles from
      the crossing points of line segments (``LTLine``);
    * ``recognize_table_by_rect`` groups ready-made rectangles (``LTRect``).

    Every object handed in is expected to expose ``bbox = (x0, y0, x1, y1)``.
    Rows are ordered by descending top edge (``bbox[3]``), i.e. a larger ``y``
    is treated as higher on the page (presumably pdfminer's bottom-left
    origin).  A recognised table is ``{"bbox": (x0, y0, x1, y1),
    "table": rows}`` where each row is a list of cell dicts with the keys
    ``bbox``, ``rect``, ``rowspan``, ``columnspan`` and ``text``.
    """

    def recognize_table(self, list_textbox, list_line):
        """Build tables from ruling lines and fill their cells with text.

        :param list_textbox: text objects (``bbox`` + ``get_text()``); sorted
            in place by ``rect2table``.
        :param list_line: line objects exposing ``bbox``.
        :return: ``(list_tables, in_objs, list_l_rect)`` -- the table dicts,
            the set of text objects placed in some cell, and the cell
            rectangles grouped per cluster of connected lines.
        """
        # Exposed as instance attributes.
        self.list_line = list_line
        self.list_crosspoints = self.recognize_crosspoints(list_line)
        list_l_rect = self._rects_from_crosspoints(self.list_crosspoints)
        list_tables, in_objs = self._build_tables(list_textbox, list_l_rect)
        return list_tables, in_objs, list_l_rect

    def recognize_table_by_rect(self, list_textbox, list_rect, margin=2):
        """Build tables from rectangle objects and fill their cells with text.

        Rectangles shorter than 10 or narrower than 5 units are ignored, and a
        rectangle whose four coordinates are all within 5 units of an already
        kept one is treated as a duplicate.  The remaining rectangles are
        grouped with ``_rects_touch`` (edge contact within *margin*) and every
        group is turned into a table by ``rect2table``.

        :return: ``(list_tables, in_objs, list_l_rect)`` as in
            ``recognize_table``; ``list_l_rect`` is the list of rect groups.
        """
        dump_margin = 5
        list_rect_tmp = []
        for _rect in list_rect:
            x0, y0, x1, y1 = _rect.bbox
            if y1 - y0 < 10 or abs(x1 - x0) < 5:
                continue
            is_dup = any(
                all(abs(a - b) < dump_margin for a, b in zip(_rect.bbox, _tmp.bbox))
                for _tmp in list_rect_tmp
            )
            if not is_dup:
                list_rect_tmp.append(_rect)

        # Cluster the *filtered* list so the size filter and dedup above
        # actually take effect.
        cluster_rect = []
        for _rect in list_rect_tmp:
            for cr in cluster_rect:
                if any(self._rects_touch(cr_rect, _rect, margin) for cr_rect in cr):
                    cr.append(_rect)
                    break
            else:
                cluster_rect.append([_rect])

        list_tables, in_objs = self._build_tables(list_textbox, cluster_rect)
        return list_tables, in_objs, cluster_rect

    def recognize_crosspoints(self, list_line):
        """Return the crossing points of every unordered pair of lines.

        ``cross_point`` is symmetric, so each pair is tested once; testing
        both orders would report every crossing twice.

        :param list_line: line objects exposing ``bbox``.
        :return: list of point dicts as produced by ``cross_point``.
        """
        bboxes = [_line.bbox for _line in list_line]
        list_crosspoints = []
        for _i, line1 in enumerate(bboxes):
            for line2 in bboxes[_i + 1:]:
                exists, point = self.cross_point(line1, line2)
                if exists:
                    list_crosspoints.append(point)
        return list_crosspoints

    def recognize_rect(self, _page):
        """Return cell rectangles built from the ``LTLine`` objects of a page.

        :param _page: page object whose ``_objs`` are scanned for lines.
        :return: list of rectangle lists, one per cluster of connected lines.
        """
        list_line = [_obj for _obj in _page._objs if isinstance(_obj, LTLine)]
        list_crosspoints = self.recognize_crosspoints(list_line)
        return self._rects_from_crosspoints(list_crosspoints)

    def crosspoint2rect(self, list_crosspoint, margin=4):
        """Turn one cluster of crossing points into cell rectangles.

        Each line is classified as a "row" (its points spread more in x) or a
        "column", and its points are sorted along that direction.  From every
        point whose lines extend at least *margin* towards larger x
        (``right``) and larger y (``buttom``), walk along its row line to the
        next point whose lines also extend towards larger y.  Then walk up
        that point's column line to the next point whose lines extend towards
        smaller x (``left``).  The start and end points are opposite cell
        corners.

        :param list_crosspoint: point dicts as produced by ``cross_point``.
        :param margin: minimum extension required by the tests above.
        :return: list of ``LTRect`` objects (line width 1).
        """
        dict_line_points = {}
        for _point in list_crosspoint:
            for _line in _point.get("lines"):
                entry = dict_line_points.setdefault(_line, {"direct": None, "points": []})
                entry["points"].append(_point)

        for entry in dict_line_points.values():
            xs = [p.get("point")[0] for p in entry["points"]]
            ys = [p.get("point")[1] for p in entry["points"]]
            if max(xs) - min(xs) > max(ys) - min(ys):
                entry["points"].sort(key=lambda p: p.get("point")[0])
                entry["direct"] = "row"
            else:
                entry["points"].sort(key=lambda p: p.get("point")[1])
                entry["direct"] = "column"

        list_rect = []
        for _point in list_crosspoint:
            if _point["buttom"] < margin or _point["right"] < margin:
                continue
            row_line = self._pick_line(_point, dict_line_points, avoid="column")
            if row_line is None:
                continue
            next_point = next(
                (p for p in dict_line_points[row_line]["points"]
                 if p["buttom"] >= margin and p["point"][0] > _point["point"][0]),
                None,
            )
            if next_point is None:
                continue
            column_line = self._pick_line(next_point, dict_line_points, avoid="row")
            if column_line is None:
                continue
            final_point = next(
                (p for p in dict_line_points[column_line]["points"]
                 if p["left"] >= margin and p["point"][1] > next_point["point"][1]),
                None,
            )
            if final_point is None:
                continue
            list_rect.append(LTRect(1, (_point["point"][0], _point["point"][1],
                                        final_point["point"][0], final_point["point"][1])))
        return list_rect

    def cross_point(self, line1, line2, segment=True, margin=2):
        """Intersect two lines given as ``(x1, y1, x2, y2)`` tuples.

        :param segment: if true, the intersection only counts when it lies
            within both segments' bounding ranges, widened by *margin*.
        :param margin: tolerance of the segment test.
        :return: ``(exists, info)``.  ``info["point"]`` is ``[x, y]``.
            ``left``/``right``/``top``/``buttom`` are the distances from the
            point to ``min(x1, x3)``, ``max(x2, x4)``, ``min(y1, y3)`` and
            ``max(y2, y4)``; they stay 0 unless *segment* is true and the
            crossing exists.  ``info["lines"]`` holds both lines' ``%.2f`` keys.
        """
        x1, y1, x2, y2 = line1
        x3, y3, x4, y4 = line2

        def _slope_intercept(xa, ya, xb, yb):
            # A vertical line has no slope; signalled with ``None``.
            if xb - xa == 0:
                return None, 0
            k = (yb - ya) * 1.0 / (xb - xa)
            return k, ya * 1.0 - xa * k

        k1, b1 = _slope_intercept(x1, y1, x2, y2)
        k2, b2 = _slope_intercept(x3, y3, x4, y4)

        point_is_exist = False
        x = y = 0
        if k1 is None:
            if k2 is not None:
                x, y = x1, k2 * x1 + b2
                point_is_exist = True
        elif k2 is None:
            # Mirror of the branch above; the flag must be set here too,
            # otherwise cross_point(horizontal, vertical) finds nothing.
            x, y = x3, k1 * x3 + b1
            point_is_exist = True
        elif k1 != k2:
            x = (b2 - b1) / (k1 - k2)
            y = k1 * x + b1
            point_is_exist = True

        def _on_segment(xa, ya, xb, yb):
            return (min(xa, xb) - margin <= x <= max(xa, xb) + margin
                    and min(ya, yb) - margin <= y <= max(ya, yb) + margin)

        left = right = top = buttom = 0
        if point_is_exist and segment:
            if _on_segment(x1, y1, x2, y2) and _on_segment(x3, y3, x4, y4):
                left = abs(min(x1, x3) - x)
                right = abs(max(x2, x4) - x)
                top = abs(min(y1, y3) - y)
                buttom = abs(max(y2, y4) - y)
            else:
                point_is_exist = False

        line1_key = "%.2f-%.2f-%.2f-%.2f" % (x1, y1, x2, y2)
        line2_key = "%.2f-%.2f-%.2f-%.2f" % (x3, y3, x4, y4)
        return point_is_exist, {"point": [x, y], "left": left, "right": right,
                                "top": top, "buttom": buttom,
                                "lines": set([line1_key, line2_key])}

    def unionTable(self, list_table, fixspan=True, margin=2):
        """Merge the cells of several tables into one table.

        :param list_table: iterable of tables, each a list of rows of cell
            dicts (``bbox`` required; ``rect`` and ``text`` optional).  A cell
            repeated by reference (as ``fixspan`` produces) is used once.
        :param fixspan: expand spans into repeated references (``_fix_spans``).
        :param margin: tolerance passed to ``getspan``.
        :return: a table dict, or None when there are no cells.
        """
        seen = set()
        list_cell = []
        for _t in list_table:
            for _line in _t:
                for _cell in _line:
                    if id(_cell) not in seen:
                        seen.add(id(_cell))
                        list_cell.append(_cell)
        if not list_cell:
            return None

        list_cell.sort(key=lambda c: c.get("bbox")[3])
        clusters_rects = self._cluster_rows(list_cell, lambda c: c.get("bbox"))
        list_x = sorted({v for c in list_cell for v in (c.get("bbox")[0], c.get("bbox")[2])})
        list_y = sorted({v for c in list_cell for v in (c.get("bbox")[1], c.get("bbox")[3])},
                        reverse=True)

        _table = [
            [self._make_cell(c.get("bbox"), c.get("rect"), c.get("text", ""),
                             list_x, list_y, margin) for c in row]
            for row in clusters_rects
        ]
        if fixspan:
            self._fix_spans(_table)
        return {"bbox": self._table_bbox(_table), "table": _table}

    def rect2table(self, list_textbox, list_rect, in_objs, margin=0.2, fixspan=True):
        """Arrange rectangles into a table grid and fill the cells with text.

        Rectangles whose top edges lie within 2 units form a row.  Rows are
        ordered top-down and cells left-to-right.  Spans are counted against
        the distinct cell edges, after dropping edges less than 2 units from
        their neighbour.  Each text object goes to the first cell ``inbox``
        accepts.

        :param list_textbox: text objects (``bbox`` + ``get_text()``); sorted
            in place by descending top edge, ties by left edge.
        :param list_rect: rectangle objects; sorted in place by top edge.
        :param in_objs: set updated with every text object placed in a cell.
        :param margin: tolerance passed to ``getspan``.
        :param fixspan: expand spans into repeated references (``_fix_spans``).
        :return: a table dict, or None when *list_rect* is empty.
        """
        list_rect.sort(key=lambda r: r.bbox[3])
        clusters_rects = self._cluster_rows(list_rect, lambda r: r.bbox)
        if not clusters_rects:
            return None

        list_x = self._merge_close(
            sorted({v for r in list_rect for v in (r.bbox[0], r.bbox[2])}))
        list_y = self._merge_close(
            sorted({v for r in list_rect for v in (r.bbox[1], r.bbox[3])}, reverse=True))

        _table = [
            [self._make_cell(r.bbox, r, "", list_x, list_y, margin) for r in row]
            for row in clusters_rects
        ]

        list_textbox.sort(key=lambda t: t.bbox[0])
        list_textbox.sort(key=lambda t: t.bbox[3], reverse=True)
        for textbox in list_textbox:
            target = next((c for row in _table for c in row
                           if self.inbox(textbox.bbox, c["bbox"])), None)
            if target is not None:
                target["text"] += textbox.get_text()
                in_objs.add(textbox)

        if fixspan:
            self._fix_spans(_table)
        return {"bbox": self._table_bbox(_table), "table": _table}

    def inbox(self, bbox0, bbox_g):
        """Return 1 if *bbox0* and *bbox_g* overlap by more than half of the
        smaller box's area (see ``getIOU``), else 0."""
        return 1 if self.getIOU(bbox0, bbox_g) > 0.5 else 0

    def getIOU(self, bbox0, bbox1):
        """Return the overlap area divided by the area of the smaller box.

        Despite the name this is intersection-over-minimum, not
        intersection-over-union.  Returns 0 unless the boxes overlap with
        positive width and height.
        """
        # Sum of both extents minus the extent of their union = overlap length.
        overlap_w = ((bbox0[2] - bbox0[0] + bbox1[2] - bbox1[0])
                     - (max(bbox0[2], bbox1[2]) - min(bbox0[0], bbox1[0])))
        overlap_h = ((bbox0[3] - bbox0[1] + bbox1[3] - bbox1[1])
                     - (max(bbox0[3], bbox1[3]) - min(bbox0[1], bbox1[1])))
        if overlap_w > 0 and overlap_h > 0:
            area0 = abs((bbox0[2] - bbox0[0]) * (bbox0[3] - bbox0[1]))
            area1 = abs((bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))
            return abs(overlap_w * overlap_h / min(area0, area1))
        return 0

    def getspan(self, _list, x0, x1, margin):
        """Return the number of grid intervals covered by ``[x0, x1]``.

        This is the count of grid coordinates in *_list* that fall in
        ``[min(x0, x1) - margin, max(x0, x1) + margin]``, minus one.
        """
        lo, hi = min(x0, x1), max(x0, x1)
        return sum(1 for v in _list if lo - margin <= v <= hi + margin) - 1

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _build_tables(self, list_textbox, list_l_rect):
        """Run ``rect2table`` on every rectangle group and keep the non-empty
        results; return ``(list_tables, in_objs)``."""
        in_objs = set()
        list_tables = []
        for l_rect in list_l_rect:
            _ta = self.rect2table(list_textbox, l_rect, in_objs)
            if _ta:
                list_tables.append(_ta)
        return list_tables, in_objs

    def _rects_from_crosspoints(self, list_crosspoints):
        """Cluster crossing points by shared lines; convert each cluster to cells."""
        return [self.crosspoint2rect(cluster["points"])
                for cluster in self._cluster_crosspoints(list_crosspoints)]

    @staticmethod
    def _cluster_crosspoints(list_crosspoints):
        """Group points into connected components of shared line keys.

        Each point belongs to exactly one cluster, so there are no duplicates.
        Clusters are ordered by their first point in the input, and points
        keep their input order.  Returns dicts ``{"lines": set, "points": list}``.
        """
        parent = list(range(len(list_crosspoints)))

        def find(i):
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i

        owner = {}  # line key -> index of the first point seen on that line
        for idx, point in enumerate(list_crosspoints):
            for key in point.get("lines"):
                if key in owner:
                    ra, rb = find(idx), find(owner[key])
                    if ra != rb:
                        # Keep the smaller index as root so order stays stable.
                        parent[max(ra, rb)] = min(ra, rb)
                else:
                    owner[key] = idx

        clusters = {}
        for idx, point in enumerate(list_crosspoints):
            cluster = clusters.setdefault(find(idx), {"lines": set(), "points": []})
            cluster["lines"] |= point.get("lines")
            cluster["points"].append(point)
        return list(clusters.values())

    @staticmethod
    def _pick_line(point, dict_line_points, avoid):
        """Return *point*'s first line unless its direction is *avoid*.

        In that case the second line is returned, or None when the point only
        has one line key.
        """
        lines = list(point.get("lines"))
        if dict_line_points[lines[0]]["direct"] != avoid:
            return lines[0]
        return lines[1] if len(lines) > 1 else None

    @staticmethod
    def _rects_touch(a, b, margin):
        """True when the x- or y-projections of *a* and *b* overlap (or leave
        a gap) by less than *margin*, i.e. the rectangles abut on that axis."""
        ax0, ay0, ax1, ay1 = a.bbox
        bx0, by0, bx1, by1 = b.bbox
        overlap_x = (ax1 - ax0 + bx1 - bx0) - (max(ax1, bx1) - min(ax0, bx0))
        overlap_y = (ay1 - ay0 + by1 - by0) - (max(ay1, by1) - min(ay0, by0))
        return abs(overlap_x) < margin or abs(overlap_y) < margin

    @staticmethod
    def _cluster_rows(items, get_bbox, tolerance=2):
        """Group *items* into rows by top edge (within *tolerance* of the
        row's first item).

        Rows are returned top-down (descending top edge), and items within a
        row are sorted by left edge.
        """
        rows = []
        for item in items:
            top = get_bbox(item)[3]
            for row in rows:
                if abs(get_bbox(row[0])[3] - top) < tolerance:
                    row.append(item)
                    break
            else:
                rows.append([item])
        rows.sort(key=lambda row: get_bbox(row[0])[3], reverse=True)
        for row in rows:
            row.sort(key=lambda item: get_bbox(item)[0])
        return rows

    @staticmethod
    def _merge_close(values, tolerance=2):
        """Drop each value closer than *tolerance* to its predecessor in *values*.

        The comparison is against the predecessor in the input, not the last
        kept value, so a chain of close values collapses onto its first element.
        """
        return [v for i, v in enumerate(values)
                if i == 0 or abs(v - values[i - 1]) >= tolerance]

    def _make_cell(self, bbox, rect, text, list_x, list_y, margin):
        """Build a cell dict whose spans are measured against the grid."""
        x0, y0, x1, y1 = bbox
        return {"bbox": (x0, y0, x1, y1), "rect": rect,
                "rowspan": self.getspan(list_y, y0, y1, margin),
                "columnspan": self.getspan(list_x, x0, x1, margin),
                "text": text}

    @staticmethod
    def _fix_spans(_table):
        """Expand spans into repeated cell references, in place.

        A cell spanning k columns is repeated k times in its row.  A cell
        spanning r rows, and occupying w consecutive slots after that first
        step, is inserted w times at the same column index into each of the
        next r-1 rows that exist.  Expanded cells get their span reset to 1.
        """
        for _line in _table:
            expanded = []
            for _cell in _line:
                _cospan = _cell.get("columnspan")
                if _cospan > 1:
                    _cell["columnspan"] = 1
                    expanded.extend([_cell] * _cospan)
                else:
                    expanded.append(_cell)
            _line[:] = expanded

        for l_i, _line in enumerate(_table):
            c_i = 0
            while c_i < len(_line):
                _cell = _line[c_i]
                width = 1
                while c_i + width < len(_line) and _line[c_i + width] is _cell:
                    width += 1
                _rospan = _cell.get("rowspan")
                if _rospan > 1:
                    _cell["rowspan"] = 1
                    for i in range(1, _rospan):
                        if l_i + i < len(_table):
                            _table[l_i + i][c_i:c_i] = [_cell] * width
                c_i += width

    @staticmethod
    def _table_bbox(_table):
        """Return the union ``(x0, y0, x1, y1)`` of all cell bboxes."""
        boxes = [c["bbox"] for row in _table for c in row]
        return (min(b[0] for b in boxes), min(b[1] for b in boxes),
                max(b[2] for b in boxes), max(b[3] for b in boxes))
|