pre_process.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Jun 21 10:53:51 2022
  5. pre_process.py
  6. @author: fangjiasheng
  7. """
  8. import copy
  9. import json
  10. import base64
  11. import math
  12. import os
  13. import random
  14. import sys
  15. import traceback
  16. from glob import glob
  17. import numpy as np
  18. import six
  19. import cv2
  20. from PIL import Image
  21. import fitz
  22. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  23. from idc.utils import pil_resize, pil_rotate
  24. Image.MAX_IMAGE_PIXELS = 2300000000
  25. def get_img_label(img_np, size, cls_num=4):
  26. height, width = size
  27. h, w = get_best_predict_size2(img_np, threshold=1080)
  28. img_np = pil_resize(img_np, h, w)
  29. # cv2.namedWindow("origin", 0)
  30. # cv2.resizeWindow("origin", 1000, 800)
  31. # cv2.imshow("origin", img_np)
  32. # 获取合适的文字区域
  33. result_list, img_np = get_text_region(img_np, size, 1)
  34. # print(len(result_list), img_np.shape)
  35. if not result_list:
  36. return []
  37. if img_np.shape[0] != height or img_np.shape[1] != width:
  38. img_np = pil_resize(img_np, height, width)
  39. # 生成旋转后的图片及其角度
  40. img_label_list = [[img_np, 0]]
  41. # 图片旋转
  42. angle_first = int(360/cls_num)
  43. i = 1
  44. for angle in range(angle_first, 360, angle_first):
  45. img_rotate = pil_rotate(img_np, angle)
  46. img_label_list.append([img_rotate, i])
  47. i += 1
  48. # for _img, _label in img_label_list:
  49. # cv2.imshow("img", _img)
  50. # cv2.waitKey(0)
  51. return img_label_list
  52. def get_text_region(img_np, size, top_n=1):
  53. img_np = remove_black_border(img_np)
  54. origin_h, origin_w = img_np.shape[:2]
  55. gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  56. h, w = get_best_predict_size2(img_np, threshold=640)
  57. img_np = pil_resize(img_np, h, w)
  58. # 1. 转化成灰度图
  59. img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  60. # 2. 二值化
  61. ret, binary = cv2.threshold(img_np, 125, 255, cv2.THRESH_BINARY_INV)
  62. # 3. 膨胀和腐蚀操作的核函数
  63. kernal = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  64. # 4. 膨胀一次,让轮廓突出
  65. dilation = cv2.dilate(binary, kernal, iterations=2)
  66. dilation1 = copy.deepcopy(dilation)
  67. dilation1 = np.add(np.int0(np.full(dilation1.shape, 255)), -1 * np.int0(dilation))
  68. dilation1 = np.uint8(dilation1)
  69. # cv2.imshow("dilation1", dilation1)
  70. # 5. 腐蚀一次,去掉细节,如表格线等。注意这里去掉的是竖直的线
  71. erosion = cv2.erode(dilation, kernal, iterations=1)
  72. dilation = cv2.dilate(erosion, kernal, iterations=3)
  73. # 颜色反转
  74. dilation = np.add(np.int0(np.full(dilation.shape, 255)), -1 * np.int0(dilation))
  75. dilation = np.uint8(dilation)
  76. # 1. 查找轮廓
  77. contours, hierarchy = cv2.findContours(dilation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
  78. region_list = []
  79. for cont in contours:
  80. rect = cv2.minAreaRect(cont)
  81. # box是四个点的坐标
  82. box = cv2.boxPoints(rect)
  83. box = np.int0(box)
  84. box = box.tolist()
  85. box.sort(key=lambda x: (x[0], x[1]))
  86. if box[0][0] > box[3][0]:
  87. temp = box[0][0]
  88. box[0][0] = box[3][0]
  89. box[3][0] = temp
  90. if box[0][1] > box[3][1]:
  91. temp = box[0][1]
  92. box[0][1] = box[3][1]
  93. box[3][1] = temp
  94. region_list.append(box)
  95. # dilation = np.expand_dims(dilation, axis=-1)
  96. # dilation = np.concatenate([dilation, dilation, dilation], axis=-1)
  97. # for box in region_list:
  98. # # cv2.drawContours(dilation, [box], 0, (0, 255, 0), 2)
  99. # cv2.rectangle(dilation, (box[0][0], box[0][1]), (box[3][0], box[3][1]), (0, 255, 0), 2)
  100. region_list.sort(key=lambda x: abs((x[3][0] - x[0][0])*(x[3][1] - x[0][1])), reverse=True)
  101. # 筛选文字区域
  102. result_list = []
  103. h_scale = origin_h / h
  104. w_scale = origin_w / w
  105. i = 0
  106. for box in region_list:
  107. if i >= top_n:
  108. break
  109. p1 = box[0]
  110. p2 = box[3]
  111. # print(p1, p2, abs((p2[1]-p1[1])*(p2[1]-p1[1])), h*w)
  112. # 旋转的box忽略
  113. if p1[0] >= p2[0] or p1[1] >= p2[1]:
  114. # print(p1[0], ">=", p2[0], p1[1], ">=", p2[1])
  115. continue
  116. # 太大的box忽略
  117. if abs(p2[0] - p1[0]) >= 0.7 * w and abs(p2[1] - p1[1]) >= 0.7 * h:
  118. # print("too large", abs(p2[0] - p1[0]), abs(p2[1] - p1[1]), 0.7 * w, 0.7 * h)
  119. continue
  120. # 黑色点不够的忽略
  121. if not count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]):
  122. # print("black not enough")
  123. continue
  124. p1[1] = int(p1[1] * h_scale)
  125. p1[0] = int(p1[0] * w_scale)
  126. p2[1] = int(p2[1] * h_scale)
  127. p2[0] = int(p2[0] * w_scale)
  128. result_list.append([p1, p2])
  129. i += 1
  130. # cv2.imshow("result", dilation)
  131. # cv2.waitKey(0)
  132. if not result_list:
  133. return [], None
  134. # 裁剪
  135. if top_n > 1:
  136. result = random.sample(result_list, 1)[0]
  137. height_start = result[0][1]
  138. width_start = result[0][0]
  139. else:
  140. height_start = result_list[0][0][1]
  141. width_start = result_list[0][0][0]
  142. height, width = size
  143. gray = gray[height_start:height_start+height, width_start:width_start+width]
  144. # cv2.imshow("gray", gray)
  145. # cv2.waitKey(0)
  146. return result_list, gray
  147. def gen(paths, batch_size=2, shape=(640, 640), cls_num=4, is_test=False):
  148. def choose(_paths, _i):
  149. while True:
  150. if _i >= len(_paths):
  151. _i = 0
  152. np.random.shuffle(paths)
  153. _p = _paths[_i]
  154. # read error
  155. try:
  156. _img_np = cv2.imread(_p)
  157. # limit h, w > 150
  158. if _img_np.shape[0] <= 150 or _img_np.shape[1] <= 150:
  159. _i += 1
  160. continue
  161. # limit pixels 89478485
  162. if _img_np.shape[0] * _img_np.shape[1] * _img_np.shape[2] >= 89478485:
  163. _i += 1
  164. continue
  165. _img_label_list = get_img_label(_img_np, size=(height, width), cls_num=cls_num)
  166. if not _img_label_list:
  167. _i += 1
  168. continue
  169. except:
  170. _i += 1
  171. continue
  172. _i += 1
  173. return _img_label_list, _i
  174. num = len(paths)
  175. i = 0
  176. all_cnt = 0
  177. while True:
  178. height, width = shape
  179. if is_test:
  180. X = np.zeros((batch_size, height, width, 1))
  181. Y = np.zeros((batch_size, cls_num))
  182. else:
  183. X = np.zeros((batch_size, height, width, 1))
  184. Y = np.zeros((batch_size, cls_num))
  185. img_np_list = []
  186. batch_list = []
  187. if batch_size % cls_num != 0:
  188. print("batch_size % cls_num != 0")
  189. raise
  190. for j in range(batch_size//cls_num):
  191. # 生成标注数据
  192. img_label_list, i = choose(paths, i)
  193. random.shuffle(img_label_list)
  194. if is_test:
  195. img_label_list = random.sample(img_label_list, 1)
  196. for c in range(cls_num):
  197. if c >= len(img_label_list):
  198. break
  199. img = img_label_list[c][0]
  200. img_np_list.append(img)
  201. # 模糊
  202. if_blur = random.choice([0, 0])
  203. if if_blur:
  204. # 高斯模糊
  205. sigmaX = random.randint(1, 2)
  206. sigmaY = random.randint(1, 2)
  207. img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
  208. # label
  209. label_list = [0]*cls_num
  210. label_list[img_label_list[c][1]] = 1
  211. label = np.array(label_list)
  212. if len(img.shape) < 3:
  213. img = np.expand_dims(img, axis=-1)
  214. img = np.expand_dims(img[:, :, 0], axis=-1)
  215. batch_list.append([img, label])
  216. # X[j+c] = img
  217. # Y[j+c] = label
  218. random.shuffle(batch_list)
  219. for j in range(0, batch_size, cls_num):
  220. for c in range(cls_num):
  221. img, label = batch_list[j+c]
  222. # print("label", label)
  223. # cv2.imshow("gen", img)
  224. # cv2.waitKey(0)
  225. # cv2.imwrite("data/3/"+str(all_cnt)+"_"+str(label)+".jpg", img)
  226. all_cnt += 1
  227. X[j+c] = img
  228. Y[j+c] = label
  229. if is_test:
  230. yield X, Y, img_np_list
  231. else:
  232. yield X, Y
  233. def get_image_from_pdf():
  234. paths = glob("C:/Users/Administrator/Desktop/test_pdf/*")
  235. save_dir = "D:/Project/image_direction_classification/data/1/"
  236. i = 0
  237. for path in paths:
  238. try:
  239. doc = fitz.open(path)
  240. output_image_dict = {}
  241. page_count = doc.page_count
  242. for page_no in range(page_count):
  243. try:
  244. page = doc.loadPage(page_no)
  245. output = save_dir + "pdf_" + str(i) + ".png"
  246. i += 1
  247. rotate = int(0)
  248. # 每个尺寸的缩放系数为1.3,这将为我们生成分辨率提高2.6的图像。
  249. # 此处若是不做设置,默认图片大小为:792X612, dpi=96
  250. # (1.33333333 --> 1056x816) (2 --> 1584x1224)
  251. # (1.183, 2.28 --> 1920x1080)
  252. zoom_x = 1.3
  253. zoom_y = 1.3
  254. mat = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
  255. pix = page.getPixmap(matrix=mat, alpha=False)
  256. pix.writePNG(output)
  257. except:
  258. continue
  259. except Exception as e:
  260. print("pdf2Image", traceback.print_exc())
  261. continue
  262. def get_best_predict_size2(image_np, threshold=640):
  263. h, w = image_np.shape[:2]
  264. scale = threshold / max(h, w)
  265. h = int(math.ceil(h * scale))
  266. w = int(math.ceil(w * scale))
  267. return h, w
  268. def count_black(image_np):
  269. try:
  270. if len(image_np.shape) == 3:
  271. lower_black = np.array([0, 0, 0])
  272. upper_black = np.array([10, 10, 10])
  273. else:
  274. lower_black = np.array([0])
  275. upper_black = np.array([10])
  276. mask = cv2.inRange(image_np, lower_black, upper_black)
  277. cnt = np.sum(mask != 0)
  278. # print("black count", cnt, image_np.shape[0]*image_np.shape[1])
  279. if cnt >= image_np.shape[0]*image_np.shape[1]*0.3:
  280. return True
  281. else:
  282. return False
  283. except:
  284. return False
  285. def remove_black_border(img_np):
  286. try:
  287. # 阈值
  288. threshold = 100
  289. # 转换为灰度图像
  290. gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
  291. # 获取图片尺寸
  292. h, w = gray.shape[:2]
  293. # 无法区分黑色区域超过一半的情况
  294. rowc = gray[:, int(1/2*w)]
  295. colc = gray[int(1/2*h), :]
  296. rowflag = np.argwhere(rowc > threshold)
  297. colflag = np.argwhere(colc > threshold)
  298. left, bottom, right, top = rowflag[0, 0], colflag[-1, 0], rowflag[-1, 0], colflag[0, 0]
  299. # cv2.imshow('remove_black_border', img_np[left:right, top:bottom, :])
  300. # cv2.waitKey()
  301. return img_np[left:right, top:bottom, :]
  302. except:
  303. return img_np
  304. if __name__ == '__main__':
  305. # get_img_label("data/0/7248_fe52d616989e19e6967e0461ef19b149.jpg", (640, 640))
  306. # get_image_from_pdf()
  307. # paths = glob("C:\\Users\\Administrator\\Desktop\\test_image\\*")
  308. # for path in paths:
  309. # get_text_region(cv2.imread(path))
  310. # path = "C:\\Users\\Administrator\\Desktop\\test_image\\error17.jpg"
  311. # get_img_label(cv2.imread(path), (192, 192))
  312. table_image_path = "data/0/*"
  313. pdf_image_path = "data/1/*"
  314. no_table_image_path = 'data/2/*'
  315. paths = glob(table_image_path) + glob(pdf_image_path) + glob(no_table_image_path)
  316. _gen = gen(paths, batch_size=32, shape=(192, 192))
  317. for g in _gen:
  318. print("ok")
  319. break