tableutils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. from pdfminer.layout import *
  2. class LineTable():
  3. def recognize_table(self,list_textbox,list_line):
  4. self.list_line = list_line
  5. self.list_crosspoints = self.recognize_crosspoints(list_line)
  6. #聚类
  7. cluster_crosspoints = []
  8. for _point in self.list_crosspoints:
  9. cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
  10. while 1:
  11. _find = False
  12. new_cluster_crosspoints = []
  13. for l_point in cluster_crosspoints:
  14. _flag = False
  15. for l_n_point in new_cluster_crosspoints:
  16. line1 = l_point.get("lines")
  17. line2 = l_n_point.get("lines")
  18. if len(line1&line2)>0:
  19. _find = True
  20. _flag = True
  21. l_n_point["lines"] = line1.union(line2)
  22. l_n_point["points"].extend(l_point["points"])
  23. if not _flag:
  24. new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
  25. cluster_crosspoints = new_cluster_crosspoints
  26. if not _find:
  27. break
  28. # print(len(cluster_crosspoints))
  29. list_l_rect = []
  30. for table_crosspoint in cluster_crosspoints:
  31. list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
  32. list_l_rect.append(list_rect)
  33. in_objs = set()
  34. list_tables = []
  35. for l_rect in list_l_rect:
  36. _ta = self.rect2table(list_textbox,l_rect,in_objs)
  37. if _ta:
  38. list_tables.append(_ta)
  39. return list_tables,in_objs,list_l_rect
  40. def recognize_table_by_rect(self,list_textbox,list_rect,margin=2):
  41. dump_margin = 5
  42. list_rect_tmp = []
  43. #去重
  44. for _rect in list_rect:
  45. if (_rect.bbox[3]-_rect.bbox[1]<10) or (abs(_rect.bbox[2]-_rect.bbox[0])<5):
  46. continue
  47. _find = False
  48. for _tmp in list_rect_tmp:
  49. for i in range(4):
  50. if abs(_rect.bbox[i]-_tmp.bbox[i])<dump_margin:
  51. pass
  52. else:
  53. _find = False
  54. break
  55. if i==3:
  56. _find = True
  57. if _find:
  58. break
  59. if not _find:
  60. list_rect_tmp.append(_rect)
  61. # print("=====",len(list_rect),len(list_rect_tmp))
  62. # print(list_rect_tmp)
  63. # from matplotlib import pyplot as plt
  64. # plt.figure()
  65. # for _rect in list_rect_tmp:
  66. # x0,y0,x1,y1 = _rect.bbox
  67. # plt.boxplot(_rect.bbox)
  68. # plt.show()
  69. cluster_rect = []
  70. for _rect in list_rect:
  71. _find = False
  72. for cr in cluster_rect:
  73. for cr_rect in cr:
  74. if abs((cr_rect.bbox[2]-cr_rect.bbox[0]+_rect.bbox[2]-_rect.bbox[0])-(max(cr_rect.bbox[2],_rect.bbox[2])-min(cr_rect.bbox[0],_rect.bbox[0])))<margin:
  75. _find = True
  76. cr.append(_rect)
  77. break
  78. elif abs((cr_rect.bbox[3]-cr_rect.bbox[1]+_rect.bbox[3]-_rect.bbox[1])-(max(cr_rect.bbox[3],_rect.bbox[3])-min(cr_rect.bbox[1],_rect.bbox[1])))<margin:
  79. _find = True
  80. cr.append(_rect)
  81. break
  82. if _find:
  83. break
  84. if not _find:
  85. cluster_rect.append([_rect])
  86. list_l_rect = cluster_rect
  87. in_objs = set()
  88. list_tables = []
  89. for l_rect in list_l_rect:
  90. _ta = self.rect2table(list_textbox,l_rect,in_objs)
  91. if _ta:
  92. list_tables.append(_ta)
  93. return list_tables,in_objs,list_l_rect
  94. def recognize_crosspoints(self,list_line,fixLine=True):
  95. from matplotlib import pyplot as plt
  96. # print("lines num",len(list_line))
  97. list_crosspoints = []
  98. for _i in range(len(list_line)):
  99. for _j in range(len(list_line)):
  100. line1 = list_line[_i].__dict__.get("bbox")
  101. line2 = list_line[_j].__dict__.get("bbox")
  102. exists,point = self.cross_point(line1,line2)
  103. if exists:
  104. list_crosspoints.append(point)
  105. if fixLine:
  106. #聚类
  107. cluster_crosspoints = []
  108. for _point in list_crosspoints:
  109. cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
  110. while 1:
  111. _find = False
  112. new_cluster_crosspoints = []
  113. for l_point in cluster_crosspoints:
  114. _flag = False
  115. for l_n_point in new_cluster_crosspoints:
  116. line1 = l_point.get("lines")
  117. line2 = l_n_point.get("lines")
  118. if len(line1&line2)>0:
  119. _find = True
  120. _flag = True
  121. l_n_point["lines"] = line1.union(line2)
  122. l_n_point["points"].extend(l_point["points"])
  123. if not _flag:
  124. new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
  125. cluster_crosspoints = new_cluster_crosspoints
  126. if not _find:
  127. break
  128. for list_cp in cluster_crosspoints:
  129. points = list_cp.get("points")
  130. l_lines = []
  131. for p in points:
  132. l_lines.extend(p.get("p_lines"))
  133. l_lines = list(set(l_lines))
  134. l_lines.sort(key=lambda x:x.bbox[0])
  135. min_x = l_lines[0].bbox[0]+2
  136. l_lines.sort(key=lambda x:x.bbox[1])
  137. min_y = l_lines[0].bbox[1]+2
  138. l_lines.sort(key=lambda x:x.bbox[2])
  139. max_x = l_lines[-1].bbox[2]-2
  140. l_lines.sort(key=lambda x:x.bbox[3])
  141. max_y = l_lines[-1].bbox[3]-2
  142. points.sort(key=lambda x:x.bbox[0])
  143. if abs(min_x-points[0].bbox[0])>10:
  144. list_line.append(LTLine(1,[(min_x,min_y),(min_x,max_y)]))
  145. points.sort(key=lambda x:x.bbox[1])
  146. if abs(min_y-points[0].bbox[1])>10:
  147. list_line.append(LTLine(1,[(min_x,min_y),(max_x,min_y)]))
  148. points.sort(key=lambda x:x.bbox[2])
  149. if abs(max_x-points[-1].bbox[2])>10:
  150. list_line.append(LTLine(1,[(max_x,min_y),(max_x,max_y)]))
  151. points.sort(key=lambda x:x.bbox[3])
  152. if abs(max_y-points[-1].bbox[3])>10:
  153. list_line.append(LTLine(1,[(min_x,max_y),(max_x,max_y)]))
  154. list_crosspoints = []
  155. for _i in range(len(list_line)):
  156. for _j in range(len(list_line)):
  157. line1 = list_line[_i].__dict__.get("bbox")
  158. line2 = list_line[_j].__dict__.get("bbox")
  159. exists,point = self.cross_point(line1,line2)
  160. if exists:
  161. list_crosspoints.append(point)
  162. # plt.figure()
  163. # for _line in list_line:
  164. # x0,y0,x1,y1 = _line.__dict__.get("bbox")
  165. # plt.plot([x0,x1],[y0,y1])
  166. # for _line in list_line:
  167. # x0,y0,x1,y1 = _line.bbox
  168. # plt.plot([x0,x1],[y0,y1])
  169. # for point in list_crosspoints:
  170. # plt.scatter(point.get("point")[0],point.get("point")[1])
  171. # plt.show()
  172. # print(list_crosspoints)
  173. # print("points num",len(list_crosspoints))
  174. return list_crosspoints
  175. def recognize_rect(self,_page):
  176. list_line = []
  177. for _obj in _page._objs:
  178. if isinstance(_obj,(LTLine)):
  179. list_line.append(_obj)
  180. list_crosspoints = self.recognize_crosspoints(list_line)
  181. #聚类
  182. cluster_crosspoints = []
  183. for _point in list_crosspoints:
  184. cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
  185. while 1:
  186. _find = False
  187. new_cluster_crosspoints = []
  188. for l_point in cluster_crosspoints:
  189. _flag = False
  190. for l_n_point in new_cluster_crosspoints:
  191. line1 = l_point.get("lines")
  192. line2 = l_n_point.get("lines")
  193. if len(line1&line2)>0:
  194. _find = True
  195. _flag = True
  196. l_n_point["lines"] = line1.union(line2)
  197. l_n_point["points"].extend(l_point["points"])
  198. if not _flag:
  199. new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
  200. cluster_crosspoints = new_cluster_crosspoints
  201. if not _find:
  202. break
  203. # print(len(cluster_crosspoints))
  204. list_l_rect = []
  205. for table_crosspoint in cluster_crosspoints:
  206. list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
  207. list_l_rect.append(list_rect)
  208. return list_l_rect
  209. def crosspoint2rect(self,list_crosspoint,margin=4):
  210. dict_line_points = {}
  211. for _point in list_crosspoint:
  212. lines = list(_point.get("lines"))
  213. for _line in lines:
  214. if _line not in dict_line_points:
  215. dict_line_points[_line] = {"direct":None,"points":[]}
  216. dict_line_points[_line]["points"].append(_point)
  217. #排序
  218. for k,v in dict_line_points.items():
  219. list_x = []
  220. list_y = []
  221. for _p in v["points"]:
  222. list_x.append(_p.get("point")[0])
  223. list_y.append(_p.get("point")[1])
  224. if max(list_x)-min(list_x)>max(list_y)-min(list_y):
  225. v.get("points").sort(key=lambda x:x.get("point")[0])
  226. v["direct"] = "row"
  227. else:
  228. v.get("points").sort(key=lambda x:x.get("point")[1])
  229. v["direct"] = "column"
  230. list_rect = []
  231. for _point in list_crosspoint:
  232. if _point["buttom"]>=margin and _point["right"]>=margin:
  233. lines = list(_point.get("lines"))
  234. _line = lines[0]
  235. if dict_line_points[_line]["direct"]=="column":
  236. _line = lines[1]
  237. next_point = None
  238. for p1 in dict_line_points[_line]["points"]:
  239. if p1["buttom"]>=margin and p1["point"][0]>_point["point"][0]:
  240. next_point = p1
  241. break
  242. if not next_point:
  243. continue
  244. lines = list(next_point.get("lines"))
  245. _line = lines[0]
  246. if dict_line_points[_line]["direct"]=="row":
  247. _line = lines[1]
  248. final_point = None
  249. for p1 in dict_line_points[_line]["points"]:
  250. if p1["left"]>=margin and p1["point"][1]>next_point["point"][1]:
  251. final_point = p1
  252. break
  253. if not final_point:
  254. continue
  255. _r = LTRect(1,(_point["point"][0],_point["point"][1],final_point["point"][0],final_point["point"][1]))
  256. list_rect.append(_r)
  257. return list_rect
  258. def cross_point(self,line1, line2,segment=True,margin=2):
  259. point_is_exist = False
  260. x = y = 0
  261. x1,y1,x2,y2 = line1
  262. x3,y3,x4,y4 = line2
  263. if (x2 - x1) == 0:
  264. k1 = None
  265. b1 = 0
  266. else:
  267. k1 = (y2 - y1) * 1.0 / (x2 - x1) # 计算k1,由于点均为整数,需要进行浮点数转化
  268. b1 = y1 * 1.0 - x1 * k1 * 1.0 # 整型转浮点型是关键
  269. if (x4 - x3) == 0: # L2直线斜率不存在
  270. k2 = None
  271. b2 = 0
  272. else:
  273. k2 = (y4 - y3) * 1.0 / (x4 - x3) # 斜率存在
  274. b2 = y3 * 1.0 - x3 * k2 * 1.0
  275. if k1 is None:
  276. if not k2 is None:
  277. x = x1
  278. y = k2 * x1 + b2
  279. point_is_exist = True
  280. elif k2 is None:
  281. x = x3
  282. y = k1 * x3 + b1
  283. elif not k2 == k1:
  284. x = (b2 - b1) * 1.0 / (k1 - k2)
  285. y = k1 * x * 1.0 + b1 * 1.0
  286. point_is_exist = True
  287. left = 0
  288. right = 0
  289. top = 0
  290. buttom = 0
  291. if point_is_exist:
  292. if segment:
  293. if x>=(min(x1,x2)-margin) and x<=(max(x1,x2)+margin) and y>=(min(y1,y2)-margin) and y<=(max(y1,y2)+margin):
  294. if x>=(min(x3,x4)-margin) and x<=(max(x3,x4)+margin) and y>=(min(y3,y4)-margin) and y<=(max(y3,y4)+margin):
  295. point_is_exist = True
  296. left = abs(min(x1,x3)-x)
  297. right = abs(max(x2,x4)-x)
  298. top = abs(min(y1,y3)-y)
  299. buttom = abs(max(y2,y4)-y)
  300. else:
  301. point_is_exist = False
  302. else:
  303. point_is_exist = False
  304. line1_key = "%.2f-%.2f-%.2f-%.2f"%(x1,y1,x2,y2)
  305. line2_key = "%.2f-%.2f-%.2f-%.2f"%(x3,y3,x4,y4)
  306. return point_is_exist, {"point":[x, y],"left":left,"right":right,"top":top,"buttom":buttom,"lines":set([line1_key,line2_key]),"p_lines":[line1,line2]}
  307. def unionTable(self,list_table,fixspan=True,margin=2):
  308. set_x = set()
  309. set_y = set()
  310. list_cell = []
  311. for _t in list_table:
  312. for _line in _t:
  313. list_cell.extend(_line)
  314. clusters_rects = []
  315. #根据y1聚类
  316. set_id = set()
  317. list_cell_dump = []
  318. for _cell in list_cell:
  319. _id = id(_cell)
  320. if _id in set_id:
  321. continue
  322. set_id.add(_id)
  323. list_cell_dump.append(_cell)
  324. list_cell = list_cell_dump
  325. list_cell.sort(key=lambda x:x.get("bbox")[3])
  326. for _rect in list_cell:
  327. _y0 = _rect.get("bbox")[3]
  328. _find = False
  329. for l_cr in clusters_rects:
  330. if abs(l_cr[0].get("bbox")[3]-_y0)<2:
  331. _find = True
  332. l_cr.append(_rect)
  333. break
  334. if not _find:
  335. clusters_rects.append([_rect])
  336. clusters_rects.sort(key=lambda x:x[0].get("bbox")[3],reverse=True)
  337. for l_cr in clusters_rects:
  338. l_cr.sort(key=lambda x:x.get("bbox")[0])
  339. for l_r in clusters_rects:
  340. print(len(l_r))
  341. for _line in clusters_rects:
  342. for _rect in _line:
  343. (x0,y0,x1,y1) = _rect.get("bbox")
  344. set_x.add(x0)
  345. set_x.add(x1)
  346. set_y.add(y0)
  347. set_y.add(y1)
  348. if len(set_x)==0 or len(set_y)==0:
  349. return
  350. list_x = list(set_x)
  351. list_y = list(set_y)
  352. list_x.sort(key=lambda x:x)
  353. list_y.sort(key=lambda x:x,reverse=True)
  354. _table = []
  355. for _line in clusters_rects:
  356. table_line = []
  357. for _rect in _line:
  358. (x0,y0,x1,y1) = _rect.get("bbox")
  359. _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect.get("rect"),"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":_rect.get("text","")}
  360. table_line.append(_cell)
  361. _table.append(table_line)
  362. # print("=====================>>")
  363. # for _line in _table:
  364. # for _cell in _line:
  365. # print(_cell,end="\t")
  366. # print("\n")
  367. # print("=====================>>")
  368. # print(_table)
  369. if fixspan:
  370. for _line in _table:
  371. for c_i in range(len(_line)):
  372. _cell = _line[c_i]
  373. if _cell.get("columnspan")>1:
  374. _cospan = _cell.get("columnspan")
  375. _cell["columnspan"] = 1
  376. for i in range(1,_cospan):
  377. _line.insert(c_i,_cell)
  378. for l_i in range(len(_table)):
  379. _line = _table[l_i]
  380. for c_i in range(len(_line)):
  381. _cell = _line[c_i]
  382. if _cell.get("rowspan")>1:
  383. _rospan = _cell.get("rowspan")
  384. _cell["rowspan"] = 1
  385. for i in range(1,_rospan):
  386. _table[l_i+i].insert(c_i,_cell)
  387. table_bbox = (_table[0][0].get("bbox")[0],_table[0][0].get("bbox")[1],_table[-1][-1].get("bbox")[2],_table[-1][-1].get("bbox")[3])
  388. ta = {"bbox":table_bbox,"table":_table}
  389. return ta
  390. def rect2table(self,list_textbox,list_rect,in_objs,margin=0.2,fixspan=True):
  391. _table = []
  392. set_x = set()
  393. set_y = set()
  394. clusters_rects = []
  395. #根据y1聚类
  396. list_rect.sort(key=lambda x:x.bbox[3])
  397. for _rect in list_rect:
  398. _y0 = _rect.bbox[3]
  399. _find = False
  400. for l_cr in clusters_rects:
  401. if abs(l_cr[0].bbox[3]-_y0)<2:
  402. _find = True
  403. l_cr.append(_rect)
  404. break
  405. if not _find:
  406. clusters_rects.append([_rect])
  407. clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=True)
  408. for l_cr in clusters_rects:
  409. l_cr.sort(key=lambda x:x.bbox[0])
  410. #cul spans
  411. for _line in clusters_rects:
  412. for _rect in _line:
  413. (x0,y0,x1,y1) = _rect.bbox
  414. set_x.add(x0)
  415. set_x.add(x1)
  416. set_y.add(y0)
  417. set_y.add(y1)
  418. if len(set_x)==0 or len(set_y)==0:
  419. return
  420. list_x = list(set_x)
  421. list_y = list(set_y)
  422. list_x.sort(key=lambda x:x)
  423. list_y.sort(key=lambda x:x,reverse=True)
  424. pop_x = []
  425. for i in range(len(list_x)-1):
  426. _i = len(list_x)-i-1
  427. l_i = _i-1
  428. if abs(list_x[_i]-list_x[l_i])<2:
  429. pop_x.append(_i)
  430. pop_x.sort(key=lambda x:x,reverse=True)
  431. for _x in pop_x:
  432. list_x.pop(_x)
  433. #
  434. pop_x = []
  435. for i in range(len(list_y)-1):
  436. _i = len(list_y)-i-1
  437. l_i = _i-1
  438. if abs(list_y[_i]-list_y[l_i])<2:
  439. pop_x.append(_i)
  440. pop_x.sort(key=lambda x:x,reverse=True)
  441. for _x in pop_x:
  442. list_y.pop(_x)
  443. print(list_x)
  444. print(list_y)
  445. for _line in clusters_rects:
  446. table_line = []
  447. for _rect in _line:
  448. (x0,y0,x1,y1) = _rect.bbox
  449. _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect,"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":""}
  450. table_line.append(_cell)
  451. _table.append(table_line)
  452. list_textbox.sort(key=lambda x:x.bbox[0])
  453. list_textbox.sort(key=lambda x:x.bbox[3],reverse=True)
  454. for textbox in list_textbox:
  455. (x0,y0,x1,y1) = textbox.bbox
  456. _text = textbox.get_text()
  457. _find = False
  458. for table_line in _table:
  459. for _cell in table_line:
  460. if self.inbox(textbox.bbox,_cell["bbox"]):
  461. _cell["text"]+= _text
  462. in_objs.add(textbox)
  463. _find = True
  464. break
  465. if _find:
  466. break
  467. if fixspan:
  468. for _line in _table:
  469. for c_i in range(len(_line)):
  470. _cell = _line[c_i]
  471. if _cell.get("columnspan")>1:
  472. _cospan = _cell.get("columnspan")
  473. _cell["columnspan"] = 1
  474. for i in range(1,_cospan):
  475. _line.insert(c_i,_cell)
  476. for l_i in range(len(_table)):
  477. _line = _table[l_i]
  478. for c_i in range(len(_line)):
  479. _cell = _line[c_i]
  480. if _cell.get("rowspan")>1:
  481. _rospan = _cell.get("rowspan")
  482. _cell["rowspan"] = 1
  483. for i in range(1,_rospan):
  484. if l_i+i<len(_table)-1:
  485. print(len(_table),l_i+i)
  486. _table[l_i+i].insert(c_i,_cell)
  487. # print("=======")
  488. # for _line in _table:
  489. # for _cell in _line:
  490. # print("[%s]"%_cell.get("text")[:10].replace("\n",''),end="\t\t")
  491. # print("\n")
  492. # print("===========")
  493. table_bbox = (_table[0][0].get("bbox")[0],_table[0][0].get("bbox")[1],_table[-1][-1].get("bbox")[2],_table[-1][-1].get("bbox")[3])
  494. ta = {"bbox":table_bbox,"table":_table}
  495. return ta
  496. def inbox(self,bbox0,bbox_g):
  497. # if bbox_g[0]<=bbox0[0] and bbox_g[1]<=bbox0[1] and bbox_g[2]>=bbox0[2] and bbox_g[3]>=bbox0[3]:
  498. # return 1
  499. if self.getIOU(bbox0,bbox_g)>0.5:
  500. return 1
  501. return 0
  502. def getIOU(self,bbox0,bbox1):
  503. width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
  504. height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
  505. if width<0 and height<0:
  506. return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))
  507. return 0
  508. def getspan(self,_list,x0,x1,margin):
  509. _count = 0
  510. (x0,x1) = (min(x0,x1),max(x0,x1))
  511. for _x in _list:
  512. if _x>=(x0-margin) and _x<=(x1+margin):
  513. _count += 1
  514. return _count-1