pre_process.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Jun 21 10:53:51 2022
  5. pre_process.py
  6. @author: fangjiasheng
  7. """
  8. import copy
  9. import json
  10. import base64
  11. import math
  12. import os
  13. import random
  14. import sys
  15. import traceback
  16. from glob import glob
  17. import numpy as np
  18. import six
  19. import cv2
  20. from PIL import Image
  21. import fitz
  22. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  23. from idc.utils import pil_resize, pil_rotate
  24. Image.MAX_IMAGE_PIXELS = 2300000000
  25. def get_img_label(img_np, size, cls_num=4):
  26. height, width = size
  27. h, w = get_best_predict_size2(img_np, threshold=1080)
  28. img_np = pil_resize(img_np, h, w)
  29. # cv2.namedWindow("origin", 0)
  30. # cv2.resizeWindow("origin", 1000, 800)
  31. # cv2.imshow("origin", img_np)
  32. # 获取合适的文字区域
  33. result_list, img_np = get_text_region(img_np, size)
  34. # print(len(result_list), img_np.shape)
  35. if not result_list:
  36. return []
  37. if img_np.shape[0] != height or img_np.shape[1] != width:
  38. img_np = pil_resize(img_np, height, width)
  39. # 生成旋转后的图片及其角度
  40. img_label_list = [[img_np, 0]]
  41. # 图片旋转
  42. angle_first = int(360/cls_num)
  43. i = 1
  44. for angle in range(angle_first, 360, angle_first):
  45. img_rotate = pil_rotate(img_np, angle)
  46. img_label_list.append([img_rotate, i])
  47. i += 1
  48. # for _img, _label in img_label_list:
  49. # cv2.imshow("img", _img)
  50. # cv2.waitKey(0)
  51. return img_label_list
  52. def get_text_region2(img_np, size):
  53. img_np = remove_black_border(img_np)
  54. origin_h, origin_w = img_np.shape[:2]
  55. gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  56. h, w = get_best_predict_size2(img_np, threshold=640)
  57. img_np = pil_resize(img_np, h, w)
  58. # 1. 转化成灰度图
  59. img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  60. # 2. 二值化
  61. ret, binary = cv2.threshold(img_np, 125, 255, cv2.THRESH_BINARY_INV)
  62. # 3. 膨胀和腐蚀操作的核函数
  63. kernal = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  64. # 4. 膨胀一次,让轮廓突出
  65. dilation = cv2.dilate(binary, kernal, iterations=2)
  66. dilation1 = copy.deepcopy(dilation)
  67. dilation1 = np.add(np.int0(np.full(dilation1.shape, 255)), -1 * np.int0(dilation))
  68. dilation1 = np.uint8(dilation1)
  69. # cv2.imshow("dilation1", dilation1)
  70. # 5. 腐蚀一次,去掉细节,如表格线等。注意这里去掉的是竖直的线
  71. erosion = cv2.erode(dilation, kernal, iterations=1)
  72. dilation = cv2.dilate(erosion, kernal, iterations=3)
  73. # 颜色反转
  74. dilation = np.add(np.int0(np.full(dilation.shape, 255)), -1 * np.int0(dilation))
  75. dilation = np.uint8(dilation)
  76. # 1. 查找轮廓
  77. contours, hierarchy = cv2.findContours(dilation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
  78. region_list = []
  79. for cont in contours:
  80. rect = cv2.minAreaRect(cont)
  81. # box是四个点的坐标
  82. box = cv2.boxPoints(rect)
  83. box = np.int0(box)
  84. box = box.tolist()
  85. new_box = [[10000, 10000], [0, 0]]
  86. for p in box:
  87. if p[0] < new_box[0][0]:
  88. new_box[0][0] = p[0]
  89. elif p[0] > new_box[1][0]:
  90. new_box[1][0] = p[0]
  91. if p[1] < new_box[0][1]:
  92. new_box[0][1] = p[1]
  93. elif p[1] > new_box[1][1]:
  94. new_box[1][1] = p[1]
  95. # box.sort(key=lambda x: (x[0], x[1]))
  96. # if box[0][0] > box[3][0]:
  97. # temp = box[0][0]
  98. # box[0][0] = box[3][0]
  99. # box[3][0] = temp
  100. # if box[0][1] > box[3][1]:
  101. # temp = box[0][1]
  102. # box[0][1] = box[3][1]
  103. # box[3][1] = temp
  104. region_list.append(new_box)
  105. dilation = np.expand_dims(dilation, axis=-1)
  106. dilation = np.concatenate([dilation, dilation, dilation], axis=-1)
  107. # for box in region_list:
  108. # # cv2.drawContours(dilation, [box], 0, (0, 255, 0), 2)
  109. # cv2.rectangle(dilation, (box[0][0], box[0][1]), (box[3][0], box[3][1]), (0, 255, 0), 2)
  110. region_list.sort(key=lambda x: abs((x[1][0] - x[0][0])*(x[1][1] - x[0][1])), reverse=True)
  111. # print("len(region_list)", len(region_list))
  112. # 筛选文字区域
  113. result_list = []
  114. h_scale = origin_h / h
  115. w_scale = origin_w / w
  116. for box in region_list:
  117. # if i >= 20:
  118. # break
  119. p1 = box[0]
  120. p2 = box[1]
  121. # print(p1, p2, abs((p2[1]-p1[1])*(p2[1]-p1[1])), h*w)
  122. # 旋转的box忽略
  123. if p1[0] >= p2[0] or p1[1] >= p2[1]:
  124. # print(box)
  125. # print(p1[0], ">", p2[0], p1[1], ">", p2[1])
  126. continue
  127. # 太大的box忽略
  128. if abs(p2[0] - p1[0]) >= 0.7 * w and abs(p2[1] - p1[1]) >= 0.7 * h:
  129. # print("too large", abs(p2[0] - p1[0]), abs(p2[1] - p1[1]), 0.7 * w, 0.7 * h)
  130. continue
  131. # 黑色点不够的忽略
  132. cnt_black = count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]) / (abs(p2[0] - p1[0])*abs(p2[1] - p1[1]))
  133. # if cnt_black < abs(p2[0] - p1[0])*abs(p2[1] - p1[1])*0.1:
  134. # print("black not enough")
  135. # continue
  136. # if not count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]):
  137. # # print("black not enough")
  138. # continue
  139. p1[1] = int(p1[1] * h_scale)
  140. p1[0] = int(p1[0] * w_scale)
  141. p2[1] = int(p2[1] * h_scale)
  142. p2[0] = int(p2[0] * w_scale)
  143. result_list.append([p1, p2, cnt_black])
  144. # cv2.imshow("result", dilation)
  145. # cv2.waitKey(0)
  146. if not result_list:
  147. return [], None
  148. result_list.sort(key=lambda x: x[2], reverse=True)
  149. # for r in result_list:
  150. # print(r)
  151. # 裁剪
  152. # if top_n > 1:
  153. # result = random.sample(result_list, 1)[0]
  154. # height_start = result[0][1]
  155. # width_start = result[0][0]
  156. # else:
  157. height_start = result_list[0][0][1]
  158. width_start = result_list[0][0][0]
  159. height, width = size
  160. gray = gray[height_start:height_start+height, width_start:width_start+width]
  161. # cv2.imshow("gray", gray)
  162. # cv2.waitKey(0)
  163. return result_list, gray
  164. def get_text_region3(img_np, size):
  165. img_np = remove_black_border(img_np)
  166. origin_h, origin_w = img_np.shape[:2]
  167. gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  168. h, w = get_best_predict_size2(img_np, threshold=640)
  169. img_np = pil_resize(img_np, h, w)
  170. # 1. 转化成灰度图
  171. img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  172. result_list = []
  173. return result_list, gray
  174. def get_text_region(img_np, size=(640, 640)):
  175. origin_h, origin_w = img_np.shape[:2]
  176. # 1. crop
  177. crop_h, crop_w = 2000, 2000
  178. if origin_h > crop_h:
  179. index = int((origin_h - crop_h) / 2)
  180. img_np = img_np[index:index+crop_h, :]
  181. if origin_w > crop_w:
  182. index = int((origin_w - crop_w) / 2)
  183. img_np = img_np[:, index:index+crop_w]
  184. # 2. resize
  185. # h, w = get_best_predict_size2(img_np, threshold=640)
  186. img_np = pil_resize(img_np, size[0], size[1])
  187. # 3. gray
  188. img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
  189. return img_np
  190. def gen(paths, batch_size=2, shape=(640, 640), cls_num=4, is_test=False):
  191. def choose(_paths, _i):
  192. while True:
  193. if _i >= len(_paths):
  194. _i = 0
  195. np.random.shuffle(paths)
  196. _p = _paths[_i]
  197. # read error
  198. try:
  199. _img_np = cv2.imread(_p)
  200. # limit h, w > 150
  201. if _img_np.shape[0] <= 150 or _img_np.shape[1] <= 150:
  202. _i += 1
  203. continue
  204. # limit pixels 89478485
  205. if _img_np.shape[0] * _img_np.shape[1] * _img_np.shape[2] >= 89478485:
  206. _i += 1
  207. continue
  208. _img_label_list = get_img_label(_img_np, size=(height, width), cls_num=cls_num)
  209. if not _img_label_list:
  210. _i += 1
  211. continue
  212. except:
  213. _i += 1
  214. continue
  215. _i += 1
  216. return _img_label_list, _i
  217. num = len(paths)
  218. i = 0
  219. all_cnt = 0
  220. while True:
  221. height, width = shape
  222. if is_test:
  223. X = np.zeros((batch_size, height, width, 1))
  224. Y = np.zeros((batch_size, cls_num))
  225. else:
  226. X = np.zeros((batch_size, height, width, 1))
  227. Y = np.zeros((batch_size, cls_num))
  228. img_np_list = []
  229. batch_list = []
  230. if batch_size % cls_num != 0:
  231. print("batch_size % cls_num != 0")
  232. raise
  233. for j in range(batch_size//cls_num):
  234. # 生成标注数据
  235. img_label_list, i = choose(paths, i)
  236. random.shuffle(img_label_list)
  237. if is_test:
  238. img_label_list = random.sample(img_label_list, 1)
  239. for c in range(cls_num):
  240. if c >= len(img_label_list):
  241. break
  242. img = img_label_list[c][0]
  243. img_np_list.append(img)
  244. # 模糊
  245. if_blur = random.choice([0, 0])
  246. if if_blur:
  247. # 高斯模糊
  248. sigmaX = random.randint(1, 2)
  249. sigmaY = random.randint(1, 2)
  250. img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
  251. # label
  252. label_list = [0]*cls_num
  253. label_list[img_label_list[c][1]] = 1
  254. label = np.array(label_list)
  255. if len(img.shape) < 3:
  256. img = np.expand_dims(img, axis=-1)
  257. img = np.expand_dims(img[:, :, 0], axis=-1)
  258. batch_list.append([img, label])
  259. # X[j+c] = img
  260. # Y[j+c] = label
  261. random.shuffle(batch_list)
  262. for j in range(0, batch_size, cls_num):
  263. for c in range(cls_num):
  264. img, label = batch_list[j+c]
  265. # print("label", label)
  266. # cv2.imshow("gen", img)
  267. # cv2.waitKey(0)
  268. # cv2.imwrite("data/3/"+str(all_cnt)+"_"+str(label)+".jpg", img)
  269. all_cnt += 1
  270. X[j+c] = img
  271. Y[j+c] = label
  272. if is_test:
  273. yield X, Y, img_np_list
  274. else:
  275. yield X, Y
  276. def get_image_from_pdf():
  277. paths = glob("C:/Users/Administrator/Desktop/test_pdf/*")
  278. save_dir = "D:/Project/image_direction_classification/data/1/"
  279. i = 0
  280. for path in paths:
  281. try:
  282. doc = fitz.open(path)
  283. output_image_dict = {}
  284. page_count = doc.page_count
  285. for page_no in range(page_count):
  286. try:
  287. page = doc.loadPage(page_no)
  288. output = save_dir + "pdf_" + str(i) + ".png"
  289. i += 1
  290. rotate = int(0)
  291. # 每个尺寸的缩放系数为1.3,这将为我们生成分辨率提高2.6的图像。
  292. # 此处若是不做设置,默认图片大小为:792X612, dpi=96
  293. # (1.33333333 --> 1056x816) (2 --> 1584x1224)
  294. # (1.183, 2.28 --> 1920x1080)
  295. zoom_x = 1.3
  296. zoom_y = 1.3
  297. mat = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
  298. pix = page.getPixmap(matrix=mat, alpha=False)
  299. pix.writePNG(output)
  300. except:
  301. continue
  302. except Exception as e:
  303. print("pdf2Image", traceback.print_exc())
  304. continue
  305. def get_best_predict_size2(image_np, threshold=640):
  306. h, w = image_np.shape[:2]
  307. scale = threshold / max(h, w)
  308. h = int(math.ceil(h * scale))
  309. w = int(math.ceil(w * scale))
  310. return h, w
  311. def count_black(image_np):
  312. try:
  313. if len(image_np.shape) == 3:
  314. lower_black = np.array([0, 0, 0])
  315. upper_black = np.array([10, 10, 10])
  316. else:
  317. lower_black = np.array([0])
  318. upper_black = np.array([10])
  319. mask = cv2.inRange(image_np, lower_black, upper_black)
  320. cnt = np.sum(mask != 0)
  321. return cnt
  322. # print("black count", cnt, image_np.shape[0]*image_np.shape[1])
  323. # if cnt >= image_np.shape[0]*image_np.shape[1]*0.3:
  324. # return True
  325. # else:
  326. # return False
  327. except:
  328. return 0
  329. def remove_black_border(img_np):
  330. try:
  331. # 阈值
  332. threshold = 100
  333. # 转换为灰度图像
  334. gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
  335. # 获取图片尺寸
  336. h, w = gray.shape[:2]
  337. # 无法区分黑色区域超过一半的情况
  338. rowc = gray[:, int(1/2*w)]
  339. colc = gray[int(1/2*h), :]
  340. rowflag = np.argwhere(rowc > threshold)
  341. colflag = np.argwhere(colc > threshold)
  342. left, bottom, right, top = rowflag[0, 0], colflag[-1, 0], rowflag[-1, 0], colflag[0, 0]
  343. # cv2.imshow('remove_black_border', img_np[left:right, top:bottom, :])
  344. # cv2.waitKey()
  345. return img_np[left:right, top:bottom, :]
  346. except:
  347. return img_np
  348. if __name__ == '__main__':
  349. # get_img_label("data/0/7248_fe52d616989e19e6967e0461ef19b149.jpg", (640, 640))
  350. # get_image_from_pdf()
  351. # paths = glob("C:\\Users\\Administrator\\Desktop\\test_image\\*")
  352. # for path in paths:
  353. # get_text_region(cv2.imread(path))
  354. # path = "C:\\Users\\Administrator\\Desktop\\test_image\\error17.jpg"
  355. # get_img_label(cv2.imread(path), (192, 192))
  356. table_image_path = "data/0/*"
  357. pdf_image_path = "data/1/*"
  358. no_table_image_path = 'data/2/*'
  359. paths = glob(table_image_path) + glob(pdf_image_path) + glob(no_table_image_path)
  360. _gen = gen(paths, batch_size=32, shape=(192, 192))
  361. for g in _gen:
  362. print("ok")
  363. break