DetDateSetPse.py

# -*- coding: utf-8 -*-
# @Time : 2018/6/11 15:54
# @Author : zhoujun
import os
import math
import random
import numbers
import pathlib
import pyclipper
from torch.utils import data
import glob
import numpy as np
import cv2
from skimage.util import random_noise
import json
from tqdm import tqdm
from torchvision import transforms

# from utils.utils import draw_bbox


# All images in this module are read with cv2
class DataAugment():
    def __init__(self):
        pass

    def add_noise(self, im: np.ndarray):
        """
        Add Gaussian noise to the image.
        :param im: image array
        :return: the noisy image; random_noise outputs pixels in [0, 1], so the result is scaled back by 255
        """
        return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype)

    def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple:
        """
        Randomly pick one scale from scales and resize the image and text polygons by it.
        :param im: original image
        :param text_polys: text polygons
        :param scales: candidate scales
        :return: rescaled image and polygons
        """
        tmp_text_polys = text_polys.copy()
        rd_scale = float(np.random.choice(scales))
        im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
        tmp_text_polys *= rd_scale
        return im, tmp_text_polys

    def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray,
                               same_size=False):
        """
        Pick a random angle from the given range and rotate the image and text polygons.
        :param img: image
        :param text_polys: text polygons
        :param degrees: angle; a single number or a sequence of two numbers
        :param same_size: whether to keep the output the same size as the input
        :return: rotated image and polygons
        """
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = (-degrees, degrees)
        elif isinstance(degrees, (list, tuple, np.ndarray)):
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            degrees = degrees
        else:
            raise Exception('degrees must be a Number, list, tuple or np.ndarray')
        # ---------------------- rotate the image ----------------------
        w = img.shape[1]
        h = img.shape[0]
        angle = np.random.uniform(degrees[0], degrees[1])
        if same_size:
            nw = w
            nh = h
        else:
            # degrees to radians
            rangle = np.deg2rad(angle)
            # compute the w, h of the rotated image
            nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
            nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
        # build the affine matrix
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
        # offset from the original image centre to the new image centre
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        # update the affine matrix
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]
        # apply the affine transform
        rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
        # ---------------------- correct the bbox coordinates ----------------------
        # rot_mat is the final rotation matrix
        # take the four corner points of each bbox and map them into the rotated coordinate system
        rot_text_polys = list()
        for bbox in text_polys:
            point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
            point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
            point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
            point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
            rot_text_polys.append([point1, point2, point3, point4])
        return rot_img, np.array(rot_text_polys, dtype=np.float32)
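
    # A quick sanity check of the canvas formula above (illustrative numbers, not from the
    # original code): a 640x480 image rotated by 10 degrees gives
    #   nw = |480*sin(10°)| + |640*cos(10°)| ≈ 83.3 + 630.3 ≈ 713.6
    #   nh = |480*cos(10°)| + |640*sin(10°)| ≈ 472.7 + 111.1 ≈ 583.8
    # so the whole rotated image fits the enlarged canvas when same_size=False.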

    def random_crop_img_bboxes(self, im: np.ndarray, text_polys: np.ndarray, max_tries=50) -> tuple:
        """
        Crop a sub-image from the image together with the text polygons inside it.
        :param im: image
        :param text_polys: text polygons
        :param max_tries: maximum number of attempts
        :return: cropped image and polygons
        """
        h, w, _ = im.shape
        pad_h = h // 10
        pad_w = w // 10
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
        for poly in text_polys:
            poly = np.round(poly, decimals=0).astype(np.int32)  # round to integers
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w:maxx + pad_w] = 1  # mark the text span on the x axis
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h:maxy + pad_h] = 1  # mark the text span on the y axis
        # pick crop boundaries only from background positions on both axes, so the crop does not cut through text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            # the whole image is text, return as-is
            return im, text_polys
        for i in range(max_tries):
            xx = np.random.choice(w_axis, size=2)
            # clip the selected region to the image boundaries
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if xmax - xmin < 0.1 * w or ymax - ymin < 0.1 * h:
                # area too small
                continue
            if text_polys.shape[0] != 0:
                # keep only the polygons whose four corners all fall inside the crop
                poly_axis_in_area = (text_polys[:, :, 0] >= xmin) & (text_polys[:, :, 0] <= xmax) \
                                    & (text_polys[:, :, 1] >= ymin) & (text_polys[:, :, 1] <= ymax)
                selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
            else:
                selected_polys = []
            if len(selected_polys) == 0:
                # no text inside the crop
                continue
            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
            polys = text_polys[selected_polys]
            # shift the coordinates into the cropped image
            polys[:, :, 0] -= xmin
            polys[:, :, 1] -= ymin
            return im, polys
        return im, text_polys

    def random_crop_image_pse(self, im: np.ndarray, text_polys: np.ndarray, input_size) -> tuple:
        """
        Crop an input_size x input_size patch from the image together with the text polygons inside it.
        :param im: image
        :param text_polys: text polygons
        :param input_size: output image size
        :return: cropped image and polygons
        """
        h, w, _ = im.shape
        short_edge = min(h, w)
        if short_edge < input_size:
            # make sure the short edge >= input_size
            scale = input_size / short_edge
            im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
            text_polys *= scale
            h, w, _ = im.shape
        # range of the random top-left corner
        w_range = w - input_size
        h_range = h - input_size
        for _ in range(50):
            xmin = random.randint(0, w_range)
            ymin = random.randint(0, h_range)
            xmax = xmin + input_size
            ymax = ymin + input_size
            if text_polys.shape[0] != 0:
                selected_polys = []
                for poly in text_polys:
                    if poly[:, 0].max() < xmin or poly[:, 0].min() > xmax or \
                            poly[:, 1].max() < ymin or poly[:, 1].min() > ymax:
                        continue
                    # area_p = cv2.contourArea(poly)
                    poly[:, 0] -= xmin
                    poly[:, 1] -= ymin
                    poly[:, 0] = np.clip(poly[:, 0], 0, input_size)
                    poly[:, 1] = np.clip(poly[:, 1], 0, input_size)
                    # rect = cv2.minAreaRect(poly)
                    # area_n = cv2.contourArea(poly)
                    # h1, w1 = rect[1]
                    # if w1 < 10 or h1 < 10 or area_n / area_p < 0.5:
                    #     continue
                    selected_polys.append(poly)
            else:
                selected_polys = []
            # if len(selected_polys) == 0:
            #     # no text inside the crop
            #     continue
            im = im[ymin:ymax, xmin:xmax, :]
            polys = np.array(selected_polys)
            return im, polys
        return im, text_polys

    def random_crop_author(self, imgs, img_size):
        h, w = imgs[0].shape[0:2]
        th, tw = img_size
        if w == tw and h == th:
            return imgs
        # if the label contains a text instance, crop around it with probability 5/8
        if np.max(imgs[1][:, :, -1]) > 0 and random.random() > 3.0 / 8.0:
            # top-left point of the text instances
            tl = np.min(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
            tl[tl < 0] = 0
            # bottom-right point of the text instances
            br = np.max(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
            br[br < 0] = 0
            # make sure the bottom-right point leaves enough room for the crop
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)
            for _ in range(50000):
                i = random.randint(tl[0], br[0])
                j = random.randint(tl[1], br[1])
                # make sure the cropped patch contains text
                if imgs[1][:, :, 0][i:i + th, j:j + tw].sum() <= 0:
                    continue
                else:
                    break
        else:
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)
        # return i, j, th, tw
        for idx in range(len(imgs)):
            if len(imgs[idx].shape) == 3:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
            else:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
        return imgs

    def resize(self, im: np.ndarray, text_polys: np.ndarray,
               input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple:
        """
        Resize the image and text polygons.
        :param im: image
        :param text_polys: text polygons
        :param input_size: target size, a single number or a sequence [w, h]
        :param keep_ratio: whether to keep the aspect ratio
        :return: resized image and polygons
        """
        if isinstance(input_size, numbers.Number):
            if input_size < 0:
                raise ValueError("If input_size is a single number, it must be positive.")
            input_size = (input_size, input_size)
        elif isinstance(input_size, (list, tuple, np.ndarray)):
            if len(input_size) != 2:
                raise ValueError("If input_size is a sequence, it must be of len 2.")
            input_size = (input_size[0], input_size[1])
        else:
            raise Exception('input_size must be a Number, list, tuple or np.ndarray')
        if keep_ratio:
            # pad the short edge so the image matches the long edge before resizing
            h, w, c = im.shape
            max_h = max(h, input_size[0])
            max_w = max(w, input_size[1])
            im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
            im_padded[:h, :w] = im.copy()
            im = im_padded
        text_polys = text_polys.astype(np.float32)
        h, w, _ = im.shape
        im = cv2.resize(im, input_size)
        w_scale = input_size[0] / float(w)
        h_scale = input_size[1] / float(h)
        text_polys[:, :, 0] *= w_scale
        text_polys[:, :, 1] *= h_scale
        return im, text_polys

    def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
        """
        Horizontally flip the image and text polygons.
        :param im: image
        :param text_polys: text polygons
        :return: horizontally flipped image and polygons
        """
        flip_text_polys = text_polys.copy()
        flip_im = cv2.flip(im, 1)
        h, w, _ = flip_im.shape
        flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
        return flip_im, flip_text_polys

    def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
        """
        Vertically flip the image and text polygons.
        :param im: image
        :param text_polys: text polygons
        :return: vertically flipped image and polygons
        """
        flip_text_polys = text_polys.copy()
        flip_im = cv2.flip(im, 0)
        h, w, _ = flip_im.shape
        flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
        return flip_im, flip_text_polys

    def test(self, im: np.ndarray, text_polys: np.ndarray):
        # NOTE: show_pic is not defined in this file; this helper is only for manual debugging
        print('random scale')
        t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3])
        print(t_im.shape, t_text_polys.dtype)
        show_pic(t_im, t_text_polys, 'random_scale')

        print('random rotate')
        t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10)
        print(t_im.shape, t_text_polys.dtype)
        show_pic(t_im, t_text_polys, 'random_rotate_img_bbox')

        print('random crop')
        t_im, t_text_polys = self.random_crop_img_bboxes(im, text_polys)
        print(t_im.shape, t_text_polys.dtype)
        show_pic(t_im, t_text_polys, 'random_crop_img_bboxes')

        print('horizontal flip')
        t_im, t_text_polys = self.horizontal_flip(im, text_polys)
        print(t_im.shape, t_text_polys.dtype)
        show_pic(t_im, t_text_polys, 'horizontal_flip')

        print('vertical flip')
        t_im, t_text_polys = self.vertical_flip(im, text_polys)
        print(t_im.shape, t_text_polys.dtype)
        show_pic(t_im, t_text_polys, 'vertical_flip')
        show_pic(im, text_polys, 'vertical_flip_ori')

        print('add noise')
        t_im = self.add_noise(im)
        print(t_im.shape)
        show_pic(t_im, text_polys, 'add_noise')
        show_pic(im, text_polys, 'add_noise_ori')


data_aug = DataAugment()
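
# A minimal usage sketch of the module-level augmenter (hypothetical data, not part of the
# training pipeline in this file):
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   polys = np.array([[[10, 10], [100, 10], [100, 50], [10, 50]]], dtype=np.float32)
#   img, polys = data_aug.random_scale(img, polys, [0.5, 1, 2, 3])
#   img, polys = data_aug.horizontal_flip(img, polys)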


def load_json(file_path: str):
    with open(file_path, 'r', encoding='utf8') as f:
        content = json.load(f)
    return content


def check_and_validate_polys(polys, im_size):
    '''
    check so that the text poly is in the same direction,
    and also filter some invalid polygons
    :param polys: text polygons
    :param im_size: (h, w) of the image
    :return: the validated polygons
    '''
    (h, w) = im_size
    if polys.shape[0] == 0:
        return polys
    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)  # x coord not max w-1, and not min 0
    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)  # y coord not max h-1, and not min 0
    validated_polys = []
    for poly in polys:
        p_area = cv2.contourArea(poly)
        if abs(p_area) < 1:
            continue
        validated_polys.append(poly)
    return np.array(validated_polys)


def generate_rbox(im_size, text_polys, text_tags, training_mask, i, n, m):
    """
    Generate the score map for kernel i: white (1) is text, black (0) is background.
    :param im_size: (h, w) of the image
    :param text_polys: polygon coordinates
    :param text_tags: whether each text polygon is ignored during training
    :return: the score map and the updated training mask
    """
    h, w = im_size
    score_map = np.zeros((h, w), dtype=np.uint8)
    for poly, tag in zip(text_polys, text_tags):
        poly = poly.astype(np.int32)
        # PSE shrink ratio for kernel i out of n, with minimal ratio m
        r_i = 1 - (1 - m) * (n - i) / (n - 1)
        d_i = cv2.contourArea(poly) * (1 - r_i * r_i) / cv2.arcLength(poly, True)
        pco = pyclipper.PyclipperOffset()
        # pco.AddPath(pyclipper.scale_to_clipper(poly), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        # shrinked_poly = np.floor(np.array(pyclipper.scale_from_clipper(pco.Execute(-d_i)))).astype(np.int32)
        pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        shrinked_poly = np.array(pco.Execute(-d_i))
        cv2.fillPoly(score_map, shrinked_poly, 1)
        # build the training mask
        # rect = cv2.minAreaRect(shrinked_poly)
        # poly_h, poly_w = rect[1]
        # if min(poly_h, poly_w) < 10:
        #     cv2.fillPoly(training_mask, shrinked_poly, 0)
        if tag:
            cv2.fillPoly(training_mask, shrinked_poly, 0)
        # closing operation to fill small holes
        # kernel = np.ones((3, 3), np.uint8)
        # score_map = cv2.morphologyEx(score_map, cv2.MORPH_CLOSE, kernel)
    return score_map, training_mask
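
# A worked example of the shrink distance used above (illustrative numbers, assuming the
# PSENet defaults n=6 kernels and m=0.5): for the smallest kernel i=1,
#   r_1 = 1 - (1 - 0.5) * (6 - 1) / (6 - 1) = 0.5
# and for a 100x40 axis-aligned box (area 4000, perimeter 280),
#   d_1 = 4000 * (1 - 0.5**2) / 280 ≈ 10.7 pixels,
# i.e. the polygon is offset inwards by about 10.7 px via pyclipper before being filled.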


def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int, input_size: int) -> tuple:
    # the images are rescaled with ratio {0.5, 1.0, 2.0, 3.0} randomly
    im, text_polys = data_aug.random_scale(im, text_polys, scales)
    # the images are horizontally flipped and rotated in the range [-10°, 10°] randomly
    if random.random() < 0.5:
        im, text_polys = data_aug.horizontal_flip(im, text_polys)
    if random.random() < 0.5:
        im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees)
    # 640 x 640 random samples are cropped from the transformed images
    # im, text_polys = data_aug.random_crop_img_bboxes(im, text_polys)
    # im, text_polys = data_aug.resize(im, text_polys, input_size, keep_ratio=False)
    # im, text_polys = data_aug.random_crop_image_pse(im, text_polys, input_size)
    return im, text_polys


class EastRandomCropData():
    def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False,
                 keep_ratio=True):
        self.size = size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio
        self.require_original_image = require_original_image
        self.keep_ratio = keep_ratio

    def __call__(self, data: dict) -> dict:
        """
        Randomly crop a region that contains text and resize it (with optional padding) to self.size.
        :param data: {'img':, 'text_polys':, 'texts':, 'ignore_tags':, 'training_mask':, 'score_maps':}
        :return: the cropped data dict
        """
        im = data['img']
        training_mask = data['training_mask']
        score_maps = data['score_maps'].transpose((1, 2, 0))
        text_polys = data['text_polys']
        ignore_tags = data['ignore_tags']
        texts = data['texts']
        all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
        # compute the crop region
        crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
        # crop the image, keep the aspect ratio and pad the rest
        scale_w = self.size[0] / crop_w
        scale_h = self.size[1] / crop_h
        scale = min(scale_w, scale_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        try:
            if self.keep_ratio:
                if len(im.shape) == 3:
                    padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
                else:
                    padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
                padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
                img = padimg
                padimg2 = np.zeros((self.size[1], self.size[0]), im.dtype)
                padimg2[:h, :w] = cv2.resize(training_mask[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
                data['training_mask'] = padimg2
                # NOTE: 6 is the hard-coded number of score-map channels; it must match the number of PSE kernels n
                padimg2 = np.zeros((self.size[1], self.size[0], 6), im.dtype)
                padimg2[:h, :w] = cv2.resize(score_maps[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
                data['score_maps'] = padimg2.transpose((2, 0, 1))
            else:
                img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
        except Exception:
            import traceback
            traceback.print_exc()
        # crop the text polygons
        text_polys_crop = []
        ignore_tags_crop = []
        texts_crop = []
        try:
            for poly, text, tag in zip(text_polys, texts, ignore_tags):
                poly = ((np.array(poly) - (crop_x, crop_y)) * scale).astype('float32')
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    text_polys_crop.append(poly)
                    ignore_tags_crop.append(tag)
                    texts_crop.append(text)
            data['img'] = img
            data['text_polys'] = text_polys_crop
            data['ignore_tags'] = ignore_tags_crop
            data['texts'] = texts_crop
        except Exception:
            # keep the original entries if anything goes wrong while cropping the polygons
            pass
        return data

    def is_poly_in_rect(self, poly, x, y, w, h):
        # True if the whole polygon lies inside the rectangle (x, y) - (x+w, y+h)
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        # True if the whole polygon lies outside the rectangle (x, y) - (x+w, y+h)
        poly = np.array(poly)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        # split the sorted background indices into runs of consecutive positions
        regions = []
        min_axis = 0
        for i in range(1, axis.shape[0]):
            if axis[i] != axis[i - 1] + 1:
                region = axis[min_axis:i]
                min_axis = i
                regions.append(region)
        return regions

    def random_select(self, axis, max_size):
        xx = np.random.choice(axis, size=2)
        xmin = np.min(xx)
        xmax = np.max(xx)
        xmin = np.clip(xmin, 0, max_size - 1)
        xmax = np.clip(xmax, 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions, max_size):
        # pick one position from each of two randomly chosen background regions
        selected_index = list(np.random.choice(len(regions), 2))
        selected_values = []
        for index in selected_index:
            axis = regions[index]
            xx = int(np.random.choice(axis, size=1))
            selected_values.append(xx)
        xmin = min(selected_values)
        xmax = max(selected_values)
        return xmin, xmax

    def crop_area(self, im, text_polys):
        # returns (x, y, w, h) of a random crop that contains at least one text polygon
        h, w = im.shape[:2]
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in text_polys:
            points = np.round(points, decimals=0).astype(np.int32)
            minx = np.min(points[:, 0])
            maxx = np.max(points[:, 0])
            w_array[minx:maxx] = 1
            miny = np.min(points[:, 1])
            maxy = np.max(points[:, 1])
            h_array[miny:maxy] = 1
        # ensure the cropped area does not cross a text region
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h
        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)
        for i in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions, w)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions, h)
            else:
                ymin, ymax = self.random_select(h_axis, h)
            if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
                # area too small
                continue
            num_poly_in_rect = 0
            for poly in text_polys:
                if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
                    num_poly_in_rect += 1
                    break
            if num_poly_in_rect > 0:
                return xmin, ymin, xmax - xmin, ymax - ymin
        return 0, 0, w, h


erc = EastRandomCropData()
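
# A sketch of the dict that the crop transform above expects (keys taken from __call__;
# the shapes are illustrative, assuming a 3-channel image and 6 score maps):
#   data = {
#       'img':           np.ndarray (H, W, 3),
#       'text_polys':    list/array of N polygons, each (4, 2),
#       'texts':         list of N strings,
#       'ignore_tags':   list of N bools (True = do not train on this polygon),
#       'training_mask': np.ndarray (H, W),
#       'score_maps':    np.ndarray (6, H, W),
#   }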


def image_label(data, n: int, m: float, input_size: int,
                degrees: int = 10,
                scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple:
    '''
    Build the ground truth for one sample: augment the image, generate the n PSE score maps
    and the training mask, then random-crop everything to input_size x input_size.
    :return: the data dict with 'img', 'score_maps' and 'training_mask' filled in
    '''
    im = cv2.imread(data['img_path'])
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    h, w, _ = im.shape
    # clip polygons that fall outside the image
    data['text_polys'] = check_and_validate_polys(data['text_polys'], (h, w))
    data['img'], data['text_polys'] = augmentation(im, data['text_polys'], scales, degrees, input_size)
    h, w, _ = data['img'].shape
    short_edge = min(h, w)
    if isinstance(input_size, dict):
        print(input_size)
        pass
    if short_edge < input_size:
        # make sure the short edge >= input_size
        scale = input_size / short_edge
        data['img'] = cv2.resize(data['img'], dsize=None, fx=scale, fy=scale)
        data['text_polys'] *= scale
        h, w, _ = data['img'].shape
    training_mask = np.ones((h, w), dtype=np.uint8)
    score_maps = []
    for i in range(1, n + 1):
        # kernels s1 -> sn, from smallest to largest
        score_map, training_mask = generate_rbox((h, w), data['text_polys'], data['ignore_tags'], training_mask, i, n, m)
        score_maps.append(score_map)
    score_maps = np.array(score_maps, dtype=np.float32)
    data['training_mask'] = training_mask
    data['score_maps'] = score_maps
    data = erc(data)
    return data
    # imgs = data_aug.random_crop_author([im, score_maps.transpose((1, 2, 0)), training_mask], (input_size, input_size))
    # return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2], text_polys, text_tags  # im, score_maps, training_mask
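
# With the defaults above, image_label returns a dict where (assuming n == 6 kernels and the
# default 640x640 crop of EastRandomCropData):
#   data['img']           -> np.ndarray (640, 640, 3), RGB
#   data['score_maps']    -> np.ndarray (6, 640, 640), kernel s1 (smallest) ... s6 (full text)
#   data['training_mask'] -> np.ndarray (640, 640), 0 where text is ignored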


import torch


class MyDataset(data.Dataset):
    def __init__(self, config):
        self.load_char_annotation = False
        self.data_list = self.load_data(config.file)
        self.data_shape = config.data_shape
        self.filter_keys = config.filter_keys
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=config.mean, std=config.std)
        ])
        self.n = config.n
        self.m = config.m

    def __getitem__(self, index):
        # print(self.image_list[index])
        data = self.data_list[index]
        img_path, text_polys, text_tags = self.data_list[index]['img_path'], self.data_list[index]['text_polys'], \
            self.data_list[index]['ignore_tags']
        data = image_label(data, input_size=self.data_shape, n=self.n, m=self.m)
        im = cv2.imread(img_path)
        if self.transform:
            img = self.transform(data['img'])
        shape = (data['img'].shape[0], data['img'].shape[1])
        data['img'] = img
        data['shape'] = shape
        # data['score_maps'] = score_maps
        # data['training_mask'] = training_mask
        # data['text_polys'] = torch.Tensor(list(text_polys))
        # data['ignore_tags'] = [text_tags]
        # data['shape'] = shape
        # data['texts'] = [data['texts']]
        if len(self.filter_keys):
            # drop the keys listed in filter_keys before returning
            data_dict = {}
            for k, v in data.items():
                if k not in self.filter_keys:
                    data_dict[k] = v
            return data_dict
        else:
            # return {'img': img, 'score_maps': score_maps, 'training_mask': training_mask, 'shape': shape, 'text_polys': list(text_polys), 'ignore_tags': text_tags}
            return {}

    def load_data(self, path: str) -> list:
        data_list = []
        content = load_json(path)
        for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
            img_path = os.path.join(content['data_root'], gt['img_name'])
            polygons = []
            texts = []
            illegibility_list = []
            language_list = []
            for annotation in gt['annotations']:
                if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
                    continue
                polygons.append(annotation['polygon'])
                texts.append(annotation['text'])
                illegibility_list.append(annotation['illegibility'])
                language_list.append(annotation['language'])
                if self.load_char_annotation:
                    for char_annotation in annotation['chars']:
                        if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
                            continue
                        polygons.append(char_annotation['polygon'])
                        texts.append(char_annotation['char'])
                        illegibility_list.append(char_annotation['illegibility'])
                        language_list.append(char_annotation['language'])
            data_list.append({'img_path': img_path, 'img_name': gt['img_name'],
                              'text_polys': np.array(polygons, dtype=np.float32),
                              'texts': texts, 'ignore_tags': illegibility_list})
        return data_list
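
    # The JSON file loaded above is expected to look roughly like this (a sketch inferred
    # from the keys accessed in load_data; field values are illustrative):
    #   {
    #     "data_root": "/path/to/images",
    #     "data_list": [
    #       {
    #         "img_name": "img_1.jpg",
    #         "annotations": [
    #           {"polygon": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
    #            "text": "hello", "illegibility": false, "language": "Latin",
    #            "chars": [{"polygon": [...], "char": "h", "illegibility": false, "language": "Latin"}]}
    #         ]
    #       }
    #     ]
    #   }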

    def __len__(self):
        return len(self.data_list)

    def save_label(self, img_path, label):
        # NOTE: requires draw_bbox (the utils import at the top of this file is commented out)
        save_path = img_path.replace('img', 'save')
        if not os.path.exists(os.path.split(save_path)[0]):
            os.makedirs(os.path.split(save_path)[0])
        img = draw_bbox(img_path, label)
        cv2.imwrite(save_path, img)
        return img


def show_img(imgs: np.ndarray, color=False):
    if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color):
        imgs = np.expand_dims(imgs, axis=0)
    for img in imgs:
        plt.figure()
        plt.imshow(img, cmap=None if color else 'gray')


if __name__ == '__main__':
    import torch
    import config
    from config.cfg_det_pse import config
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    from torchvision import transforms

    train_data = MyDataset(config.dataset.train.dataset)
    train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, num_workers=0)
    pbar = tqdm(total=len(train_loader))
    for i, batch_data in enumerate(train_loader):
        img, label, mask = batch_data['img'], batch_data['score_maps'], batch_data['training_mask']
        print(label.shape)
        print(img.shape)
        print(label[0][-1].sum())
        print(mask[0].shape)
        pbar.update(1)
        show_img((img[0] * mask[0].to(torch.float)).numpy().transpose(1, 2, 0), color=True)
        show_img(label[0])
        show_img(mask[0])
        plt.show()
    pbar.close()