#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 21 10:53:51 2022
pre_process.py
@author: fangjiasheng

Pre-processing helpers that turn document images into training samples for an
image-direction (rotation) classifier: locate a text-dense crop, rotate it into
``cls_num`` orientations and batch the results.
"""
import base64
import copy
import json
import math
import os
import random
import sys
import traceback
from glob import glob

import cv2
import fitz
import numpy as np
import six
from PIL import Image

# Make the project root importable before pulling in the project-local helpers.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from idc.utils import pil_resize, pil_rotate

Image.MAX_IMAGE_PIXELS = 2300000000


def get_img_label(img_np, size, cls_num=4):
    """Build the rotated training samples for one image.

    The image is resized so that its longer side is 1080, a text region of
    ``size`` is cropped from it via ``get_text_region`` and the crop is then
    rotated into ``cls_num`` orientations.

    Args:
        img_np: Source image array (a 3-channel image is required by
            ``get_text_region``; presumably BGR as read by ``cv2.imread``).
        size: ``(height, width)`` of every returned sample.
        cls_num: Number of direction classes; class ``k`` is rotated by
            ``k * (360 // cls_num)`` degrees via ``pil_rotate``.

    Returns:
        A list of ``[image, label]`` pairs with labels ``0 .. cls_num - 1``,
        or ``[]`` when no usable text region was found.
    """
    height, width = size
    h, w = get_best_predict_size2(img_np, threshold=1080)
    img_np = pil_resize(img_np, h, w)

    # Find a suitable text region and crop it.
    result_list, img_np = get_text_region(img_np, size)
    if not result_list:
        return []
    # The crop can be smaller than requested when the region touches a border.
    if img_np.shape[0] != height or img_np.shape[1] != width:
        img_np = pil_resize(img_np, height, width)

    # Label 0 is the unrotated crop; labels 1..cls_num-1 are its rotations.
    # Iterating over labels (rather than over angles below 360) guarantees
    # exactly cls_num samples even when cls_num does not divide 360, so a
    # label can never be >= cls_num.
    img_label_list = [[img_np, 0]]
    angle_first = 360 // cls_num
    for label in range(1, cls_num):
        img_rotate = pil_rotate(img_np, label * angle_first)
        img_label_list.append([img_rotate, label])
    return img_label_list


def get_text_region(img_np, size):
    """Locate candidate text regions and crop the densest one.

    Regions are detected on a copy resized so that its longer side is 640
    (binarise, dilate/erode, find contours) and their coordinates are scaled
    back to the border-stripped input image.

    Args:
        img_np: 3-channel image array (converted with ``cv2.COLOR_BGR2GRAY``).
        size: ``(height, width)`` of the crop to take.

    Returns:
        ``(result_list, gray_crop)`` where ``result_list`` holds
        ``[top_left, bottom_right, black_ratio]`` entries (points are
        ``[x, y]`` in input-image pixels) sorted by ``black_ratio``
        descending, and ``gray_crop`` is the grayscale crop that starts at
        the top-left corner of the best region. The crop may be smaller than
        ``size`` near the image border. Returns ``([], None)`` when no region
        qualifies.
    """
    img_np = remove_black_border(img_np)
    origin_h, origin_w = img_np.shape[:2]
    gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)

    h, w = get_best_predict_size2(img_np, threshold=640)
    img_np = pil_resize(img_np, h, w)

    # 1. Convert to grayscale.
    img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    # 2. Binarise (inverted: dark pixels become 255).
    _, binary = cv2.threshold(img_np, 125, 255, cv2.THRESH_BINARY_INV)
    # 3. Kernel shared by the dilate / erode steps.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    # 4. Dilate so the contours stand out.
    dilation = cv2.dilate(binary, kernel, iterations=2)
    # Inverted copy (dark pixels are 0 again), used for the black-pixel ratio.
    dilation1 = cv2.bitwise_not(dilation)
    # 5. Erode once to drop thin details such as table lines, then dilate
    #    again to merge neighbouring glyphs into blobs.
    erosion = cv2.erode(dilation, kernel, iterations=1)
    dilation = cv2.dilate(erosion, kernel, iterations=3)
    # Invert the colours.
    dilation = cv2.bitwise_not(dilation)

    # Find the contours. Indexing with [-2] works with both the 2-tuple and
    # the 3-tuple return shapes used by different OpenCV major versions.
    contours = cv2.findContours(dilation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    region_list = []
    for cont in contours:
        # Four corner points of the minimum-area rectangle.
        corners = cv2.boxPoints(cv2.minAreaRect(cont)).astype(np.intp)
        xs = corners[:, 0]
        ys = corners[:, 1]
        # Axis-aligned bounding box, clamped to the image so it can be used
        # directly as a slice.
        x_min = min(max(int(xs.min()), 0), w)
        x_max = min(max(int(xs.max()), 0), w)
        y_min = min(max(int(ys.min()), 0), h)
        y_max = min(max(int(ys.max()), 0), h)
        region_list.append([[x_min, y_min], [x_max, y_max]])

    # Largest boxes first.
    region_list.sort(key=lambda x: abs((x[1][0] - x[0][0]) * (x[1][1] - x[0][1])),
                     reverse=True)

    # Filter the text regions.
    result_list = []
    h_scale = origin_h / h
    w_scale = origin_w / w
    for box in region_list:
        p1, p2 = box
        # Ignore degenerate (empty) boxes.
        if p1[0] >= p2[0] or p1[1] >= p2[1]:
            continue
        # Ignore boxes that cover most of the page.
        if abs(p2[0] - p1[0]) >= 0.7 * w and abs(p2[1] - p1[1]) >= 0.7 * h:
            continue
        # Fraction of dark pixels inside the box; used to rank the regions.
        area = abs(p2[0] - p1[0]) * abs(p2[1] - p1[1])
        cnt_black = count_black(dilation1[p1[1]:p2[1], p1[0]:p2[0]]) / area

        # Scale the coordinates back to the input resolution.
        p1[1] = int(p1[1] * h_scale)
        p1[0] = int(p1[0] * w_scale)
        p2[1] = int(p2[1] * h_scale)
        p2[0] = int(p2[0] * w_scale)
        result_list.append([p1, p2, cnt_black])

    if not result_list:
        return [], None
    result_list.sort(key=lambda x: x[2], reverse=True)

    # Crop a window of the requested size anchored at the best region.
    height_start = result_list[0][0][1]
    width_start = result_list[0][0][0]
    height, width = size
    gray = gray[height_start:height_start + height, width_start:width_start + width]
    return result_list, gray


def gen(paths, batch_size=2, shape=(640, 640), cls_num=4, is_test=False):
    """Endlessly yield batches of rotated text crops with one-hot labels.

    Args:
        paths: Image file paths. They are copied, so the caller's sequence is
            never reordered.
        batch_size: Number of rows per batch; must be a multiple of
            ``cls_num``.
        shape: ``(height, width)`` of every sample.
        cls_num: Number of direction classes.
        is_test: When true, only one randomly chosen orientation per source
            image is kept and the list of kept images is yielded as well. In
            that mode only the first ``batch_size // cls_num`` rows of
            ``X`` / ``Y`` are filled; the remaining rows stay zero.

    Yields:
        ``(X, Y)`` or, in test mode, ``(X, Y, img_np_list)``. ``X`` has shape
        ``(batch_size, height, width, 1)`` and ``Y`` has shape
        ``(batch_size, cls_num)``.

    Raises:
        ValueError: If ``batch_size`` is not a multiple of ``cls_num`` or
            ``paths`` is empty.
        RuntimeError: If a full pass over ``paths`` yields no usable image.
    """
    if batch_size % cls_num != 0:
        print("batch_size % cls_num != 0")
        raise ValueError("batch_size % cls_num != 0")
    paths = list(paths)
    if not paths:
        raise ValueError("paths is empty")
    height, width = shape

    def choose(_paths, _i):
        """Return ``(samples, next_index)`` for the next usable image."""
        failures = 0
        # Stop after one fruitless pass instead of spinning forever.
        while failures < len(_paths):
            if _i >= len(_paths):
                _i = 0
                np.random.shuffle(_paths)
            _p = _paths[_i]
            _i += 1
            try:
                _img_np = cv2.imread(_p)
                # cv2.imread signals an unreadable file by returning None.
                if _img_np is None:
                    failures += 1
                    continue
                # Require height and width > 150.
                if _img_np.shape[0] <= 150 or _img_np.shape[1] <= 150:
                    failures += 1
                    continue
                # Skip images with 89478485 or more values (h * w * channels).
                if _img_np.shape[0] * _img_np.shape[1] * _img_np.shape[2] >= 89478485:
                    failures += 1
                    continue
                _img_label_list = get_img_label(_img_np, size=(height, width),
                                                cls_num=cls_num)
            except Exception:
                # Best effort: a broken image must not stop the generator.
                print("skip image", _p)
                failures += 1
                continue
            if not _img_label_list:
                failures += 1
                continue
            return _img_label_list, _i
        raise RuntimeError("no usable image found in paths")

    i = 0
    while True:
        X = np.zeros((batch_size, height, width, 1))
        Y = np.zeros((batch_size, cls_num))
        img_np_list = []
        batch_list = []

        for _ in range(batch_size // cls_num):
            # Generate the labelled samples for the next image.
            img_label_list, i = choose(paths, i)
            random.shuffle(img_label_list)
            if is_test:
                img_label_list = random.sample(img_label_list, 1)

            for img, cls_index in img_label_list[:cls_num]:
                img_np_list.append(img)

                # Blur augmentation. Both choices are 0, so it is currently
                # switched off; the draw is kept so the sequence of random
                # numbers consumed per sample stays the same.
                if_blur = random.choice([0, 0])
                if if_blur:
                    # Gaussian blur. sigmaY must be passed by keyword because
                    # the fourth positional argument is `dst`.
                    sigmaX = random.randint(1, 2)
                    sigmaY = random.randint(1, 2)
                    img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY=sigmaY)

                # One-hot label.
                label_list = [0] * cls_num
                label_list[cls_index] = 1
                label = np.array(label_list)

                # Normalise to a single-channel (h, w, 1) array.
                if len(img.shape) < 3:
                    img = np.expand_dims(img, axis=-1)
                img = np.expand_dims(img[:, :, 0], axis=-1)
                batch_list.append([img, label])

        random.shuffle(batch_list)
        # Fill only the rows that were produced; test mode produces fewer
        # than batch_size samples.
        for k, (img, label) in enumerate(batch_list):
            X[k] = img
            Y[k] = label

        if is_test:
            yield X, Y, img_np_list
        else:
            yield X, Y


def get_image_from_pdf(pdf_glob="C:/Users/Administrator/Desktop/test_pdf/*",
                       save_dir="D:/Project/image_direction_classification/data/1/",
                       zoom=1.3):
    """Render every page of the matching PDFs to ``pdf_<n>.png`` files.

    Args:
        pdf_glob: Glob pattern selecting the PDF files.
        save_dir: Directory that receives the PNG files.
        zoom: Scale factor applied on both axes when rendering a page.

    Pages that fail to render are reported and skipped; the output counter
    still advances for them, so the file numbering can have gaps.
    """
    i = 0
    for path in glob(pdf_glob):
        try:
            doc = fitz.open(path)
        except Exception:
            print("pdf2Image", traceback.format_exc())
            continue
        try:
            for page_no in range(doc.page_count):
                output = os.path.join(save_dir, "pdf_" + str(i) + ".png")
                i += 1
                try:
                    # NOTE(review): loadPage / preRotate / getPixmap / writePNG
                    # are the legacy camelCase PyMuPDF names; confirm they
                    # exist in the pinned PyMuPDF version.
                    page = doc.loadPage(page_no)
                    # Original author's reference values: the default page
                    # image is 792x612 at 96 dpi; zoom 1.33333333 gives
                    # 1056x816 and zoom 2 gives 1584x1224.
                    mat = fitz.Matrix(zoom, zoom).preRotate(0)
                    pix = page.getPixmap(matrix=mat, alpha=False)
                    pix.writePNG(output)
                except Exception:
                    # Best effort per page, but do not hide the failure.
                    print("pdf2Image page", page_no, traceback.format_exc())
                    continue
        except Exception:
            print("pdf2Image", traceback.format_exc())
            continue
        finally:
            doc.close()


def get_best_predict_size2(image_np, threshold=640):
    """Return ``(h, w)`` scaled so that the longer side equals ``threshold``.

    The aspect ratio is kept and both sides are rounded up. Smaller images
    are scaled up as well.
    """
    h, w = image_np.shape[:2]
    scale = threshold / max(h, w)
    h = int(math.ceil(h * scale))
    w = int(math.ceil(w * scale))
    return h, w


def count_black(image_np):
    """Count the pixels whose value lies in ``[0, 10]`` on every channel.

    Returns 0 when the count cannot be computed (for example for an empty
    slice).
    """
    try:
        if len(image_np.shape) == 3:
            lower_black = np.array([0, 0, 0])
            upper_black = np.array([10, 10, 10])
        else:
            lower_black = np.array([0])
            upper_black = np.array([10])
        mask = cv2.inRange(image_np, lower_black, upper_black)
        return int(np.count_nonzero(mask))
    except Exception:
        return 0


def remove_black_border(img_np):
    """Crop away dark borders, judged along the centre row and centre column.

    Only the middle column and the middle row are inspected, so an image
    whose dark area reaches past the centre cannot be handled correctly.
    Returns the input unchanged when cropping is not possible.
    """
    try:
        # Brightness above which a pixel counts as content.
        threshold = 100
        # Grayscale conversion; BGR order matches get_text_region.
        gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        h, w = gray.shape[:2]
        # Row indices that are bright in the centre column, and column
        # indices that are bright in the centre row.
        bright_rows = np.argwhere(gray[:, w // 2] > threshold)
        bright_cols = np.argwhere(gray[h // 2, :] > threshold)
        if bright_rows.size == 0 or bright_cols.size == 0:
            return img_np
        top, bottom = bright_rows[0, 0], bright_rows[-1, 0]
        left, right = bright_cols[0, 0], bright_cols[-1, 0]
        # +1 so the last bright row / column is kept (slice ends are
        # exclusive).
        return img_np[top:bottom + 1, left:right + 1, :]
    except Exception:
        return img_np


if __name__ == '__main__':
    # Smoke test: build one batch from the local data folders.
    table_image_path = "data/0/*"
    pdf_image_path = "data/1/*"
    no_table_image_path = 'data/2/*'
    paths = glob(table_image_path) + glob(pdf_image_path) + glob(no_table_image_path)

    _gen = gen(paths, batch_size=32, shape=(192, 192))
    for g in _gen:
        print("ok")
        break