__init__.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import subprocess
  2. import os
  3. import numpy as np
  4. import cv2
  5. import torch
  6. BASE_DIR = os.path.dirname(os.path.realpath(__file__))
  7. if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value
  8. raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR))
  9. def pse_warpper(kernals, min_area=5):
  10. '''
  11. reference https://github.com/liuheng92/tensorflow_PSENet/blob/feature_dev/pse
  12. :param kernals:
  13. :param min_area:
  14. :return:
  15. '''
  16. from .pse import pse_cpp
  17. kernal_num = len(kernals)
  18. if not kernal_num:
  19. return np.array([]), []
  20. kernals = np.array(kernals)
  21. label_num, label = cv2.connectedComponents(kernals[0].astype(np.uint8), connectivity=4)
  22. label_values = []
  23. for label_idx in range(1, label_num):
  24. if np.sum(label == label_idx) < min_area:
  25. label[label == label_idx] = 0
  26. continue
  27. label_values.append(label_idx)
  28. pred = pse_cpp(label, kernals, c=kernal_num)
  29. return np.array(pred), label_values
  30. class pse_postprocess():
  31. def __init__(self, threshold=0.7311):
  32. self.threshold = threshold
  33. def __call__(self, preds, scale):
  34. """
  35. 在输出上使用sigmoid 将值转换为置信度,并使用阈值来进行文字和背景的区分
  36. :param preds: 网络输出
  37. :param scale: 网络的scale
  38. :param threshold: sigmoid的阈值
  39. :return: 最后的输出图和文本框
  40. """
  41. preds = torch.sigmoid(preds)
  42. preds = preds.detach().cpu().numpy()
  43. score = preds[-1].astype(np.float32)
  44. preds = preds > self.threshold
  45. # preds = preds * preds[-1] # 使用最大的kernel作为其他小图的mask,不使用的话效果更好
  46. pred, label_values = pse_warpper(preds, 5)
  47. bbox_list = []
  48. for label_value in label_values:
  49. points = np.array(np.where(pred == label_value)).transpose((1, 0))[:, ::-1]
  50. if points.shape[0] < 800 / (scale * scale):
  51. continue
  52. score_i = np.mean(score[pred == label_value])
  53. if score_i < 0.93:
  54. continue
  55. rect = cv2.minAreaRect(points)
  56. bbox = cv2.boxPoints(rect)
  57. bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]])
  58. return pred, np.array(bbox_list)