iou_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from collections import namedtuple
  4. import numpy as np
  5. from shapely.geometry import Polygon
  6. import cv2
  7. def iou_rotate(box_a, box_b, method='union'):
  8. rect_a = cv2.minAreaRect(box_a)
  9. rect_b = cv2.minAreaRect(box_b)
  10. r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
  11. if r1[0] == 0:
  12. return 0
  13. else:
  14. inter_area = cv2.contourArea(r1[1])
  15. area_a = cv2.contourArea(box_a)
  16. area_b = cv2.contourArea(box_b)
  17. union_area = area_a + area_b - inter_area
  18. if union_area == 0 or inter_area == 0:
  19. return 0
  20. if method == 'union':
  21. iou = inter_area / union_area
  22. elif method == 'intersection':
  23. iou = inter_area / min(area_a, area_b)
  24. else:
  25. raise NotImplementedError
  26. return iou
  27. class DetectionIoUEvaluator(object):
  28. def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
  29. self.is_output_polygon = is_output_polygon
  30. self.iou_constraint = iou_constraint
  31. self.area_precision_constraint = area_precision_constraint
  32. def evaluate_image(self, gt, pred):
  33. def get_union(pD, pG):
  34. return Polygon(pD).union(Polygon(pG)).area
  35. def get_intersection_over_union(pD, pG):
  36. return get_intersection(pD, pG) / get_union(pD, pG)
  37. def get_intersection(pD, pG):
  38. return Polygon(pD).intersection(Polygon(pG)).area
  39. def compute_ap(confList, matchList, numGtCare):
  40. correct = 0
  41. AP = 0
  42. if len(confList) > 0:
  43. confList = np.array(confList)
  44. matchList = np.array(matchList)
  45. sorted_ind = np.argsort(-confList)
  46. confList = confList[sorted_ind]
  47. matchList = matchList[sorted_ind]
  48. for n in range(len(confList)):
  49. match = matchList[n]
  50. if match:
  51. correct += 1
  52. AP += float(correct) / (n + 1)
  53. if numGtCare > 0:
  54. AP /= numGtCare
  55. return AP
  56. perSampleMetrics = {}
  57. matchedSum = 0
  58. Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
  59. numGlobalCareGt = 0
  60. numGlobalCareDet = 0
  61. arrGlobalConfidences = []
  62. arrGlobalMatches = []
  63. recall = 0
  64. precision = 0
  65. hmean = 0
  66. detMatched = 0
  67. iouMat = np.empty([1, 1])
  68. gtPols = []
  69. detPols = []
  70. gtPolPoints = []
  71. detPolPoints = []
  72. # Array of Ground Truth Polygons' keys marked as don't Care
  73. gtDontCarePolsNum = []
  74. # Array of Detected Polygons' matched with a don't Care GT
  75. detDontCarePolsNum = []
  76. pairs = []
  77. detMatchedNums = []
  78. arrSampleConfidences = []
  79. arrSampleMatch = []
  80. evaluationLog = ""
  81. for n in range(len(gt)):
  82. points = gt[n]['points']
  83. # transcription = gt[n]['text']
  84. dontCare = gt[n]['ignore']
  85. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  86. continue
  87. gtPol = points
  88. gtPols.append(gtPol)
  89. gtPolPoints.append(points)
  90. if dontCare:
  91. gtDontCarePolsNum.append(len(gtPols) - 1)
  92. evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
  93. gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
  94. for n in range(len(pred)):
  95. points = pred[n]['points']
  96. if not Polygon(points).is_valid or not Polygon(points).is_simple:
  97. continue
  98. detPol = points
  99. detPols.append(detPol)
  100. detPolPoints.append(points)
  101. if len(gtDontCarePolsNum) > 0:
  102. for dontCarePol in gtDontCarePolsNum:
  103. dontCarePol = gtPols[dontCarePol]
  104. intersected_area = get_intersection(dontCarePol, detPol)
  105. pdDimensions = Polygon(detPol).area
  106. precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
  107. if (precision > self.area_precision_constraint):
  108. detDontCarePolsNum.append(len(detPols) - 1)
  109. break
  110. evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
  111. detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
  112. if len(gtPols) > 0 and len(detPols) > 0:
  113. # Calculate IoU and precision matrixs
  114. outputShape = [len(gtPols), len(detPols)]
  115. iouMat = np.empty(outputShape)
  116. gtRectMat = np.zeros(len(gtPols), np.int8)
  117. detRectMat = np.zeros(len(detPols), np.int8)
  118. # if self.is_output_polygon:
  119. for gtNum in range(len(gtPols)):
  120. for detNum in range(len(detPols)):
  121. pG = gtPols[gtNum]
  122. pD = detPols[detNum]
  123. iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
  124. # else:
  125. # # gtPols = np.float32(gtPols)
  126. # # detPols = np.float32(detPols)
  127. # for gtNum in range(len(gtPols)):
  128. # for detNum in range(len(detPols)):
  129. # pG = np.float32(gtPols[gtNum])
  130. # pD = np.float32(detPols[detNum])
  131. # iouMat[gtNum, detNum] = iou_rotate(pD, pG)
  132. for gtNum in range(len(gtPols)):
  133. for detNum in range(len(detPols)):
  134. if gtRectMat[gtNum] == 0 and detRectMat[
  135. detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
  136. if iouMat[gtNum, detNum] > self.iou_constraint:
  137. gtRectMat[gtNum] = 1
  138. detRectMat[detNum] = 1
  139. detMatched += 1
  140. pairs.append({'gt': gtNum, 'det': detNum})
  141. detMatchedNums.append(detNum)
  142. evaluationLog += "Match GT #" + \
  143. str(gtNum) + " with Det #" + str(detNum) + "\n"
  144. numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
  145. numDetCare = (len(detPols) - len(detDontCarePolsNum))
  146. if numGtCare == 0:
  147. recall = float(1)
  148. precision = float(0) if numDetCare > 0 else float(1)
  149. else:
  150. recall = float(detMatched) / numGtCare
  151. precision = 0 if numDetCare == 0 else float(
  152. detMatched) / numDetCare
  153. hmean = 0 if (precision + recall) == 0 else 2.0 * \
  154. precision * recall / (precision + recall)
  155. matchedSum += detMatched
  156. numGlobalCareGt += numGtCare
  157. numGlobalCareDet += numDetCare
  158. perSampleMetrics = {
  159. 'precision': precision,
  160. 'recall': recall,
  161. 'hmean': hmean,
  162. 'pairs': pairs,
  163. 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
  164. 'gtPolPoints': gtPolPoints,
  165. 'detPolPoints': detPolPoints,
  166. 'gtCare': numGtCare,
  167. 'detCare': numDetCare,
  168. 'gtDontCare': gtDontCarePolsNum,
  169. 'detDontCare': detDontCarePolsNum,
  170. 'detMatched': detMatched,
  171. 'evaluationLog': evaluationLog
  172. }
  173. return perSampleMetrics
  174. def combine_results(self, results):
  175. numGlobalCareGt = 0
  176. numGlobalCareDet = 0
  177. matchedSum = 0
  178. for result in results:
  179. numGlobalCareGt += result['gtCare']
  180. numGlobalCareDet += result['detCare']
  181. matchedSum += result['detMatched']
  182. methodRecall = 0 if numGlobalCareGt == 0 else float(
  183. matchedSum) / numGlobalCareGt
  184. methodPrecision = 0 if numGlobalCareDet == 0 else float(
  185. matchedSum) / numGlobalCareDet
  186. methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
  187. methodRecall * methodPrecision / (
  188. methodRecall + methodPrecision)
  189. methodMetrics = {'precision': methodPrecision,
  190. 'recall': methodRecall, 'hmean': methodHmean}
  191. return methodMetrics
  192. if __name__ == '__main__':
  193. evaluator = DetectionIoUEvaluator()
  194. preds = [[{
  195. 'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)],
  196. 'text': 1234,
  197. 'ignore': False,
  198. }, {
  199. 'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)],
  200. 'text': 5678,
  201. 'ignore': False,
  202. }]]
  203. gts = [[{
  204. 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
  205. 'text': 123,
  206. 'ignore': False,
  207. }]]
  208. results = []
  209. for gt, pred in zip(gts, preds):
  210. results.append(evaluator.evaluate_image(gt, pred))
  211. metrics = evaluator.combine_results(results)
  212. print(metrics)