pre_process.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Jun 21 10:53:51 2022
  5. pre_process.py
  6. @author: fangjiasheng
  7. """
  8. import json
  9. import base64
  10. import random
  11. import traceback
  12. from glob import glob
  13. import numpy as np
  14. import six
  15. import cv2
  16. from PIL import Image
  17. import fitz
# Raise Pillow's pixel-count limit so very large images can still be opened
# through PIL in this module.
Image.MAX_IMAGE_PIXELS = 2300000000
  19. def get_img_label(img_np, size, cls_num=4):
  20. height, width = size
  21. img_pil = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
  22. # 图片缩放
  23. img_pil = img_pil.resize((int(width), int(height)), Image.BICUBIC)
  24. # 生成旋转后的图片及其角度
  25. img_label_list = [[np.array(img_pil), 0]]
  26. # 图片旋转
  27. angle_first = int(360/cls_num)
  28. i = 1
  29. for angle in range(angle_first, 360, angle_first):
  30. img_label_list.append([np.array(img_pil.rotate(angle, expand=1)), i])
  31. i += 1
  32. # for _img, _label in img_label_list:
  33. # cv2.imshow("img", _img)
  34. # cv2.waitKey(0)
  35. return img_label_list
  36. def gen(paths, batch_size=2, shape=(640, 640), cls_num=4, is_test=False):
  37. num = len(paths)
  38. i = 0
  39. while True:
  40. height, width = shape
  41. if is_test:
  42. X = np.zeros((batch_size, height, width, 3))
  43. Y = np.zeros((batch_size, cls_num))
  44. else:
  45. X = np.zeros((batch_size * cls_num, height, width, 3))
  46. Y = np.zeros((batch_size * cls_num, cls_num))
  47. img_np_list = []
  48. for j in range(batch_size):
  49. if i >= num:
  50. i = 0
  51. np.random.shuffle(paths)
  52. p = paths[i]
  53. i += 1
  54. # limit pixels 89478485
  55. img_np = cv2.imread(p)
  56. if img_np.shape[0] * img_np.shape[1] * img_np.shape[2] >= 89478485:
  57. # print("image too large, limit 89478485 pixels", img_np.shape)
  58. new_i = random.randint(0, num-1)
  59. if i != new_i:
  60. p = paths[new_i]
  61. img_label_list = get_img_label(img_np, size=(height, width), cls_num=cls_num)
  62. random.shuffle(img_label_list)
  63. if is_test:
  64. img_label_list = random.sample(img_label_list, 1)
  65. for c in range(cls_num):
  66. if c >= len(img_label_list):
  67. break
  68. img = img_label_list[c][0]
  69. img_np_list.append(img)
  70. # 模糊
  71. if_blur = random.choice([0, 1])
  72. # print(if_blur, img_label_list[c][1])
  73. if if_blur:
  74. # 高斯模糊
  75. sigmaX = random.randint(1, 2)
  76. sigmaY = random.randint(1, 2)
  77. img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
  78. # cv2.imshow("gen", img)
  79. # cv2.waitKey(0)
  80. # print("gen image size", img.shape)
  81. # label
  82. label_list = [0]*cls_num
  83. label_list[img_label_list[c][1]] = 1
  84. label = np.array(label_list)
  85. # print(p, img_label_list[c][1])
  86. X[j+c] = img
  87. Y[j+c] = label
  88. # print("X.shape", X.shape)
  89. if is_test:
  90. yield X, Y, img_np_list
  91. else:
  92. yield X, Y
  93. def get_image_from_pdf():
  94. paths = glob("C:/Users/Administrator/Desktop/test_pdf/*")
  95. save_dir = "D:/Project/image_direction_classification/data/1/"
  96. i = 0
  97. for path in paths:
  98. try:
  99. doc = fitz.open(path)
  100. output_image_dict = {}
  101. page_count = doc.page_count
  102. for page_no in range(page_count):
  103. try:
  104. page = doc.loadPage(page_no)
  105. output = save_dir + "pdf_" + str(i) + ".png"
  106. i += 1
  107. rotate = int(0)
  108. # 每个尺寸的缩放系数为1.3,这将为我们生成分辨率提高2.6的图像。
  109. # 此处若是不做设置,默认图片大小为:792X612, dpi=96
  110. # (1.33333333 --> 1056x816) (2 --> 1584x1224)
  111. # (1.183, 2.28 --> 1920x1080)
  112. zoom_x = 1.3
  113. zoom_y = 1.3
  114. mat = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
  115. pix = page.getPixmap(matrix=mat, alpha=False)
  116. pix.writePNG(output)
  117. except:
  118. continue
  119. except Exception as e:
  120. print("pdf2Image", traceback.print_exc())
  121. continue
  122. if __name__ == '__main__':
  123. get_img_label("data/0/7248_fe52d616989e19e6967e0461ef19b149.jpg", (640, 640))
  124. # get_image_from_pdf()