table_detect.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Sep 9 23:11:51 2020
  5. table detect with yolo
  6. @author: chineseocr
  7. """
  8. import cv2
  9. import numpy as np
  10. from config import tableModelDetectPath
  11. from utils import nms_box,letterbox_image,rectangle
  12. tableDetectNet = cv2.dnn.readNetFromDarknet(tableModelDetectPath.replace('.weights', '.cfg'),
  13. tableModelDetectPath)
  14. def table_detect(img, sc=(416, 416), thresh=0.5, NMSthresh=0.3):
  15. """
  16. 表格检测
  17. img:GBR
  18. """
  19. scale =sc[0]
  20. img_height,img_width = img.shape[:2]
  21. inputBlob,fx,fy = letterbox_image(img[...,::-1],(scale,scale))
  22. inputBlob = cv2.dnn.blobFromImage(inputBlob, scalefactor=1.0, size=(scale,scale),swapRB=True ,crop=False);
  23. tableDetectNet.setInput(inputBlob/255.0)
  24. outputName = tableDetectNet.getUnconnectedOutLayersNames()
  25. outputs = tableDetectNet.forward(outputName)
  26. class_ids = []
  27. confidences = []
  28. boxes = []
  29. for output in outputs:
  30. for detection in output:
  31. scores = detection[5:]
  32. class_id = np.argmax(scores)
  33. confidence = scores[class_id]
  34. if confidence > thresh:
  35. center_x = int(detection[0] * scale/fx)
  36. center_y = int(detection[1] * scale/fy)
  37. width = int(detection[2] * scale/fx)
  38. height = int(detection[3] * scale/fy)
  39. left = int(center_x - width / 2)
  40. top = int(center_y - height / 2)
  41. if class_id == 1:
  42. class_ids.append(class_id)
  43. confidences.append(float(confidence))
  44. xmin, ymin, xmax, ymax = left, top, left+width, top+height
  45. xmin = max(xmin, 1)
  46. ymin = max(ymin, 1)
  47. xmax = min(xmax, img_width-1)
  48. ymax = min(ymax, img_height-1)
  49. boxes.append([xmin, ymin, xmax, ymax])
  50. boxes = np.array(boxes)
  51. confidences = np.array(confidences)
  52. if len(boxes)>0:
  53. boxes, confidences = nms_box(boxes, confidences, score_threshold=thresh, nms_threshold=NMSthresh)
  54. boxes, adBoxes = fix_table_box_for_table_line(boxes, confidences, img)
  55. return boxes, adBoxes, confidences
  56. def point_in_box(p, box):
  57. x,y = p
  58. xmin,ymin,xmax,ymax = box
  59. if xmin<=x<=xmin and ymin<=y<=ymax:
  60. return True
  61. else:
  62. return False
  63. def fix_table_box_for_table_line(boxes, confidences, img):
  64. # 修正表格用于表格线检测
  65. h, w = img.shape[:2]
  66. n = len(boxes)
  67. adBoxes = []
  68. for i in range(n):
  69. prob = confidences[i]
  70. xmin, ymin, xmax, ymax = boxes[i]
  71. padx = (xmax-xmin)*(1-prob)
  72. padx = padx
  73. pady = (ymax-ymin)*(1-prob)
  74. pady = pady
  75. xminNew = max(xmin-padx,1)
  76. yminNew = max(ymin-pady,1)
  77. xmaxNew = min(xmax+padx,w)
  78. ymaxNew = min(ymax+pady,h)
  79. adBoxes.append([xminNew, yminNew, xmaxNew, ymaxNew])
  80. return boxes, adBoxes
  81. if __name__ == '__main__':
  82. import time
  83. p = 'train_463.jpg'
  84. img = cv2.imread(p)
  85. t = time.time()
  86. boxes, adBoxes, scores = table_detect(img, sc=(416, 416), thresh=0.5, NMSthresh=0.3)
  87. print("time", time.time()-t)
  88. print("boxes", boxes)
  89. print("adBoxes", adBoxes)
  90. print("scores", scores)
  91. img = rectangle(img, adBoxes)
  92. # img.save('img/table-detect.png')
  93. img.show('table detect')