tableutils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. from pdfminer.layout import *
  2. class LineTable():
  3. def recognize_table(self,list_textbox,list_line):
  4. self.list_line = list_line
  5. self.list_crosspoints = self.recognize_crosspoints(list_line)
  6. #聚类
  7. cluster_crosspoints = []
  8. for _point in self.list_crosspoints:
  9. cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
  10. while 1:
  11. _find = False
  12. new_cluster_crosspoints = []
  13. for l_point in cluster_crosspoints:
  14. _flag = False
  15. for l_n_point in new_cluster_crosspoints:
  16. line1 = l_point.get("lines")
  17. line2 = l_n_point.get("lines")
  18. if len(line1&line2)>0:
  19. _find = True
  20. _flag = True
  21. l_n_point["lines"] = line1.union(line2)
  22. l_n_point["points"].extend(l_point["points"])
  23. if not _flag:
  24. new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
  25. cluster_crosspoints = new_cluster_crosspoints
  26. if not _find:
  27. break
  28. # print(len(cluster_crosspoints))
  29. list_l_rect = []
  30. for table_crosspoint in cluster_crosspoints:
  31. list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
  32. list_l_rect.append(list_rect)
  33. in_objs = set()
  34. list_tables = []
  35. for l_rect in list_l_rect:
  36. _ta = self.rect2table(list_textbox,l_rect,in_objs)
  37. if _ta:
  38. list_tables.append(_ta)
  39. return list_tables,in_objs,list_l_rect
  40. def recognize_table_by_rect(self,list_textbox,list_rect,margin=2):
  41. dump_margin = 5
  42. list_rect_tmp = []
  43. #去重
  44. for _rect in list_rect:
  45. if (_rect.bbox[3]-_rect.bbox[1]<10) or (abs(_rect.bbox[2]-_rect.bbox[0])<5):
  46. continue
  47. _find = False
  48. for _tmp in list_rect_tmp:
  49. for i in range(4):
  50. if abs(_rect.bbox[i]-_tmp.bbox[i])<dump_margin:
  51. pass
  52. else:
  53. _find = False
  54. break
  55. if i==3:
  56. _find = True
  57. if _find:
  58. break
  59. if not _find:
  60. list_rect_tmp.append(_rect)
  61. # print("=====",len(list_rect),len(list_rect_tmp))
  62. # print(list_rect_tmp)
  63. # from matplotlib import pyplot as plt
  64. # plt.figure()
  65. # for _rect in list_rect_tmp:
  66. # x0,y0,x1,y1 = _rect.bbox
  67. # plt.boxplot(_rect.bbox)
  68. # plt.show()
  69. cluster_rect = []
  70. for _rect in list_rect:
  71. _find = False
  72. for cr in cluster_rect:
  73. for cr_rect in cr:
  74. if abs((cr_rect.bbox[2]-cr_rect.bbox[0]+_rect.bbox[2]-_rect.bbox[0])-(max(cr_rect.bbox[2],_rect.bbox[2])-min(cr_rect.bbox[0],_rect.bbox[0])))<margin:
  75. _find = True
  76. cr.append(_rect)
  77. break
  78. elif abs((cr_rect.bbox[3]-cr_rect.bbox[1]+_rect.bbox[3]-_rect.bbox[1])-(max(cr_rect.bbox[3],_rect.bbox[3])-min(cr_rect.bbox[1],_rect.bbox[1])))<margin:
  79. _find = True
  80. cr.append(_rect)
  81. break
  82. if _find:
  83. break
  84. if not _find:
  85. cluster_rect.append([_rect])
  86. list_l_rect = cluster_rect
  87. in_objs = set()
  88. list_tables = []
  89. for l_rect in list_l_rect:
  90. _ta = self.rect2table(list_textbox,l_rect,in_objs)
  91. if _ta:
  92. list_tables.append(_ta)
  93. return list_tables,in_objs,list_l_rect
  94. def recognize_crosspoints(self,list_line):
  95. from matplotlib import pyplot as plt
  96. list_crosspoints = []
  97. # print("lines num",len(list_line))
  98. for _i in range(len(list_line)):
  99. for _j in range(len(list_line)):
  100. line1 = list_line[_i].__dict__.get("bbox")
  101. line2 = list_line[_j].__dict__.get("bbox")
  102. exists,point = self.cross_point(line1,line2)
  103. if exists:
  104. list_crosspoints.append(point)
  105. # plt.figure()
  106. # for _line in list_line:
  107. # x0,y0,x1,y1 = _line.__dict__.get("bbox")
  108. # plt.plot([x0,x1],[y0,y1])
  109. # for _line in list_line:
  110. # x0,y0,x1,y1 = _line.bbox
  111. # plt.plot([x0,x1],[y0,y1])
  112. # for point in list_crosspoints:
  113. # plt.scatter(point.get("point")[0],point.get("point")[1])
  114. # plt.show()
  115. # print(list_crosspoints)
  116. # print("points num",len(list_crosspoints))
  117. return list_crosspoints
  118. def recognize_rect(self,_page):
  119. list_line = []
  120. for _obj in _page._objs:
  121. if isinstance(_obj,(LTLine)):
  122. list_line.append(_obj)
  123. list_crosspoints = self.recognize_crosspoints(list_line)
  124. #聚类
  125. cluster_crosspoints = []
  126. for _point in list_crosspoints:
  127. cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
  128. while 1:
  129. _find = False
  130. new_cluster_crosspoints = []
  131. for l_point in cluster_crosspoints:
  132. _flag = False
  133. for l_n_point in new_cluster_crosspoints:
  134. line1 = l_point.get("lines")
  135. line2 = l_n_point.get("lines")
  136. if len(line1&line2)>0:
  137. _find = True
  138. _flag = True
  139. l_n_point["lines"] = line1.union(line2)
  140. l_n_point["points"].extend(l_point["points"])
  141. if not _flag:
  142. new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
  143. cluster_crosspoints = new_cluster_crosspoints
  144. if not _find:
  145. break
  146. # print(len(cluster_crosspoints))
  147. list_l_rect = []
  148. for table_crosspoint in cluster_crosspoints:
  149. list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
  150. list_l_rect.append(list_rect)
  151. return list_l_rect
  def crosspoint2rect(self,list_crosspoint,margin=4):
      """Turn the crossing points of one cluster into cell rectangles.

      Each point is treated as a candidate corner. From it the method walks
      along the point's "row" line to the next point with a larger x, then
      along that point's "column" line to the next point with a larger y,
      and builds an ``LTRect`` from the start point to that final point.

      Args:
          list_crosspoint: point dictionaries as produced by
              ``cross_point`` (keys ``point``, ``left``, ``right``,
              ``top``, ``buttom``, ``lines``).
          margin: minimum ``buttom``/``right``/``left`` value a point needs
              to be used at the respective step.

      Returns:
          The list of ``LTRect`` objects that could be built.
      """
      # Index the points by the key of each line they lie on.
      by_line = {}
      for crosspoint in list_crosspoint:
          for line_key in list(crosspoint.get("lines")):
              entry = by_line.setdefault(line_key, {"direct": None, "points": []})
              entry["points"].append(crosspoint)

      # A line is a "row" when its points spread more in x than in y,
      # otherwise a "column"; its points are sorted along that axis.
      for entry in by_line.values():
          xs = [p.get("point")[0] for p in entry["points"]]
          ys = [p.get("point")[1] for p in entry["points"]]
          if max(xs) - min(xs) > max(ys) - min(ys):
              axis, direction = 0, "row"
          else:
              axis, direction = 1, "column"
          entry["points"].sort(key=lambda p, axis=axis: p.get("point")[axis])
          entry["direct"] = direction

      rects = []
      for corner in list_crosspoint:
          if not (corner["buttom"] >= margin and corner["right"] >= margin):
              continue

          # Walk along the row line through this corner.
          keys = list(corner.get("lines"))
          row_key = keys[0]
          if by_line[row_key]["direct"] == "column":
              row_key = keys[1]
          neighbour = next(
              (
                  p for p in by_line[row_key]["points"]
                  if p["buttom"] >= margin and p["point"][0] > corner["point"][0]
              ),
              None,
          )
          if not neighbour:
              continue

          # Walk along the column line through the neighbour.
          keys = list(neighbour.get("lines"))
          column_key = keys[0]
          if by_line[column_key]["direct"] == "row":
              column_key = keys[1]
          opposite = next(
              (
                  p for p in by_line[column_key]["points"]
                  if p["left"] >= margin and p["point"][1] > neighbour["point"][1]
              ),
              None,
          )
          if not opposite:
              continue

          rects.append(
              LTRect(
                  1,
                  (
                      corner["point"][0],
                      corner["point"][1],
                      opposite["point"][0],
                      opposite["point"][1],
                  ),
              )
          )
      return rects
  def cross_point(self,line1, line2,segment=True,margin=2):
      """Compute the crossing point of two lines given as (x0, y0, x1, y1).

      Args:
          line1: first line as four coordinates.
          line2: second line as four coordinates.
          segment: when true, the crossing must also lie within both
              segments' extents, each widened by ``margin``.
          margin: tolerance applied to the segment extents.

      Returns:
          A tuple ``(exists, info)``. ``info`` holds ``point`` ([x, y]),
          the distances ``left``/``right``/``top``/``buttom`` (all 0 unless
          the segment check passed) and ``lines``, a set of the two line
          keys formatted with two decimals.
      """
      x1, y1, x2, y2 = line1
      x3, y3, x4, y4 = line2

      # Slope and intercept of each line; a slope of None marks a vertical
      # line. The * 1.0 factors force float arithmetic.
      if (x2 - x1) == 0:
          k1, b1 = None, 0
      else:
          k1 = (y2 - y1) * 1.0 / (x2 - x1)
          b1 = y1 * 1.0 - x1 * k1 * 1.0
      if (x4 - x3) == 0:
          k2, b2 = None, 0
      else:
          k2 = (y4 - y3) * 1.0 / (x4 - x3)
          b2 = y3 * 1.0 - x3 * k2 * 1.0

      x = y = 0
      found = False
      if k1 is None:
          if k2 is not None:
              x = x1
              y = k2 * x1 + b2
              found = True
      elif k2 is None:
          # NOTE(review): the coordinates are computed but the crossing is
          # not reported when only line2 is vertical, so the result depends
          # on argument order. Callers that test both orders rely on this.
          x = x3
          y = k1 * x3 + b1
      elif k1 != k2:
          x = (b2 - b1) * 1.0 / (k1 - k2)
          y = k1 * x * 1.0 + b1 * 1.0
          found = True

      left = right = top = buttom = 0
      if found and segment:
          on_first = (
              (min(x1, x2) - margin) <= x <= (max(x1, x2) + margin)
              and (min(y1, y2) - margin) <= y <= (max(y1, y2) + margin)
          )
          on_second = on_first and (
              (min(x3, x4) - margin) <= x <= (max(x3, x4) + margin)
              and (min(y3, y4) - margin) <= y <= (max(y3, y4) + margin)
          )
          if on_second:
              left = abs(min(x1, x3) - x)
              right = abs(max(x2, x4) - x)
              top = abs(min(y1, y3) - y)
              buttom = abs(max(y2, y4) - y)
          else:
              found = False

      line1_key = "%.2f-%.2f-%.2f-%.2f"%(x1,y1,x2,y2)
      line2_key = "%.2f-%.2f-%.2f-%.2f"%(x3,y3,x4,y4)
      info = {
          "point": [x, y],
          "left": left,
          "right": right,
          "top": top,
          "buttom": buttom,
          "lines": {line1_key, line2_key},
      }
      return found, info
  250. def unionTable(self,list_table,fixspan=True,margin=2):
  251. set_x = set()
  252. set_y = set()
  253. list_cell = []
  254. for _t in list_table:
  255. for _line in _t:
  256. list_cell.extend(_line)
  257. clusters_rects = []
  258. #根据y1聚类
  259. set_id = set()
  260. list_cell_dump = []
  261. for _cell in list_cell:
  262. _id = id(_cell)
  263. if _id in set_id:
  264. continue
  265. set_id.add(_id)
  266. list_cell_dump.append(_cell)
  267. list_cell = list_cell_dump
  268. list_cell.sort(key=lambda x:x.get("bbox")[3])
  269. for _rect in list_cell:
  270. _y0 = _rect.get("bbox")[3]
  271. _find = False
  272. for l_cr in clusters_rects:
  273. if abs(l_cr[0].get("bbox")[3]-_y0)<2:
  274. _find = True
  275. l_cr.append(_rect)
  276. break
  277. if not _find:
  278. clusters_rects.append([_rect])
  279. clusters_rects.sort(key=lambda x:x[0].get("bbox")[3],reverse=True)
  280. for l_cr in clusters_rects:
  281. l_cr.sort(key=lambda x:x.get("bbox")[0])
  282. print("=============:")
  283. for l_r in clusters_rects:
  284. print(len(l_r))
  285. for _line in clusters_rects:
  286. for _rect in _line:
  287. (x0,y0,x1,y1) = _rect.get("bbox")
  288. set_x.add(x0)
  289. set_x.add(x1)
  290. set_y.add(y0)
  291. set_y.add(y1)
  292. if len(set_x)==0 or len(set_y)==0:
  293. return
  294. list_x = list(set_x)
  295. list_y = list(set_y)
  296. list_x.sort(key=lambda x:x)
  297. list_y.sort(key=lambda x:x,reverse=True)
  298. _table = []
  299. for _line in clusters_rects:
  300. table_line = []
  301. for _rect in _line:
  302. (x0,y0,x1,y1) = _rect.get("bbox")
  303. _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect.get("rect"),"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":_rect.get("text","")}
  304. table_line.append(_cell)
  305. _table.append(table_line)
  306. # print("=====================>>")
  307. # for _line in _table:
  308. # for _cell in _line:
  309. # print(_cell,end="\t")
  310. # print("\n")
  311. # print("=====================>>")
  312. # print(_table)
  313. if fixspan:
  314. for _line in _table:
  315. for c_i in range(len(_line)):
  316. _cell = _line[c_i]
  317. if _cell.get("columnspan")>1:
  318. _cospan = _cell.get("columnspan")
  319. _cell["columnspan"] = 1
  320. for i in range(1,_cospan):
  321. _line.insert(c_i,_cell)
  322. for l_i in range(len(_table)):
  323. _line = _table[l_i]
  324. for c_i in range(len(_line)):
  325. _cell = _line[c_i]
  326. if _cell.get("rowspan")>1:
  327. _rospan = _cell.get("rowspan")
  328. _cell["rowspan"] = 1
  329. for i in range(1,_rospan):
  330. _table[l_i+i].insert(c_i,_cell)
  331. table_bbox = (_table[0][0].get("bbox")[0],_table[0][0].get("bbox")[1],_table[-1][-1].get("bbox")[2],_table[-1][-1].get("bbox")[3])
  332. ta = {"bbox":table_bbox,"table":_table}
  333. return ta
  334. def rect2table(self,list_textbox,list_rect,in_objs,margin=0.2,fixspan=True):
  335. _table = []
  336. set_x = set()
  337. set_y = set()
  338. clusters_rects = []
  339. #根据y1聚类
  340. list_rect.sort(key=lambda x:x.bbox[3])
  341. for _rect in list_rect:
  342. _y0 = _rect.bbox[3]
  343. _find = False
  344. for l_cr in clusters_rects:
  345. if abs(l_cr[0].bbox[3]-_y0)<2:
  346. _find = True
  347. l_cr.append(_rect)
  348. break
  349. if not _find:
  350. clusters_rects.append([_rect])
  351. clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=True)
  352. for l_cr in clusters_rects:
  353. l_cr.sort(key=lambda x:x.bbox[0])
  354. #cul spans
  355. for _line in clusters_rects:
  356. for _rect in _line:
  357. (x0,y0,x1,y1) = _rect.bbox
  358. set_x.add(x0)
  359. set_x.add(x1)
  360. set_y.add(y0)
  361. set_y.add(y1)
  362. if len(set_x)==0 or len(set_y)==0:
  363. return
  364. list_x = list(set_x)
  365. list_y = list(set_y)
  366. list_x.sort(key=lambda x:x)
  367. list_y.sort(key=lambda x:x,reverse=True)
  368. pop_x = []
  369. for i in range(len(list_x)-1):
  370. _i = len(list_x)-i-1
  371. l_i = _i-1
  372. if abs(list_x[_i]-list_x[l_i])<2:
  373. pop_x.append(_i)
  374. pop_x.sort(key=lambda x:x,reverse=True)
  375. for _x in pop_x:
  376. list_x.pop(_x)
  377. #
  378. pop_x = []
  379. for i in range(len(list_y)-1):
  380. _i = len(list_y)-i-1
  381. l_i = _i-1
  382. if abs(list_y[_i]-list_y[l_i])<2:
  383. pop_x.append(_i)
  384. pop_x.sort(key=lambda x:x,reverse=True)
  385. for _x in pop_x:
  386. list_y.pop(_x)
  387. print(list_x)
  388. print(list_y)
  389. for _line in clusters_rects:
  390. table_line = []
  391. for _rect in _line:
  392. (x0,y0,x1,y1) = _rect.bbox
  393. _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect,"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":""}
  394. table_line.append(_cell)
  395. _table.append(table_line)
  396. list_textbox.sort(key=lambda x:x.bbox[0])
  397. list_textbox.sort(key=lambda x:x.bbox[3],reverse=True)
  398. for textbox in list_textbox:
  399. (x0,y0,x1,y1) = textbox.bbox
  400. _text = textbox.get_text()
  401. _find = False
  402. for table_line in _table:
  403. for _cell in table_line:
  404. if self.inbox(textbox.bbox,_cell["bbox"]):
  405. _cell["text"]+= _text
  406. in_objs.add(textbox)
  407. _find = True
  408. break
  409. if _find:
  410. break
  411. if fixspan:
  412. for _line in _table:
  413. for c_i in range(len(_line)):
  414. _cell = _line[c_i]
  415. if _cell.get("columnspan")>1:
  416. _cospan = _cell.get("columnspan")
  417. _cell["columnspan"] = 1
  418. for i in range(1,_cospan):
  419. _line.insert(c_i,_cell)
  420. for l_i in range(len(_table)):
  421. _line = _table[l_i]
  422. for c_i in range(len(_line)):
  423. _cell = _line[c_i]
  424. if _cell.get("rowspan")>1:
  425. _rospan = _cell.get("rowspan")
  426. _cell["rowspan"] = 1
  427. for i in range(1,_rospan):
  428. if l_i+i<len(_table)-1:
  429. print(len(_table),l_i+i)
  430. _table[l_i+i].insert(c_i,_cell)
  431. # print("=======")
  432. # for _line in _table:
  433. # for _cell in _line:
  434. # print("[%s]"%_cell.get("text")[:10].replace("\n",''),end="\t\t")
  435. # print("\n")
  436. # print("===========")
  437. table_bbox = (_table[0][0].get("bbox")[0],_table[0][0].get("bbox")[1],_table[-1][-1].get("bbox")[2],_table[-1][-1].get("bbox")[3])
  438. ta = {"bbox":table_bbox,"table":_table}
  439. return ta
  440. def inbox(self,bbox0,bbox_g):
  441. # if bbox_g[0]<=bbox0[0] and bbox_g[1]<=bbox0[1] and bbox_g[2]>=bbox0[2] and bbox_g[3]>=bbox0[3]:
  442. # return 1
  443. if self.getIOU(bbox0,bbox_g)>0.5:
  444. return 1
  445. return 0
  def getIOU(self,bbox0,bbox1):
      """Return the overlap of two boxes relative to the smaller box.

      Despite the name, the value is the intersection area divided by the
      area of the smaller of the two boxes, not by their union.

      Args:
          bbox0: first box as (x0, y0, x1, y1).
          bbox1: second box as (x0, y0, x1, y1).

      Returns:
          The ratio described above, or 0 when the boxes do not overlap on
          both axes.
      """
      span_x = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])
      span_y = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])
      # Combined span minus the summed extents is negative exactly when
      # the extents overlap; its magnitude is the overlap length.
      width = span_x-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
      height = span_y-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
      if not (width<0 and height<0):
          return 0
      area0 = abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1]))
      area1 = abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))
      return abs(width*height/min(area0,area1))
  def getspan(self,_list,x0,x1,margin):
      """Count how many grid intervals the range between x0 and x1 covers.

      Args:
          _list: grid coordinates.
          x0: one end of the range (the ends may be given in either order).
          x1: the other end of the range.
          margin: tolerance added on both sides of the range.

      Returns:
          The number of values of ``_list`` inside the widened range,
          minus one.
      """
      low = min(x0,x1)
      high = max(x0,x1)
      inside = sum(1 for value in _list if value>=(low-margin) and value<=(high+margin))
      return inside-1