123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Tue Jun 21 10:53:51 2022
- pre_process.py
- @author: fangjiasheng
- """
- import copy
- import json
- import base64
- import math
- import os
- import random
- import sys
- import traceback
- from glob import glob
- import numpy as np
- import six
- import cv2
- from PIL import Image
- import fitz
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from idc.utils import pil_resize, pil_rotate
- Image.MAX_IMAGE_PIXELS = 2300000000
def get_img_label(img_np, size, cls_num=4):
    """Build one training sample per rotation class for a single image.

    The image is rescaled, a text-rich region of ``size`` is cropped from it
    by ``get_text_region`` and that crop is rotated once per class.

    Args:
        img_np: Image array; it is passed to ``get_text_region``, which
            converts it with ``cv2.COLOR_BGR2GRAY`` (so presumably a BGR
            image as returned by ``cv2.imread`` -- confirm against callers).
        size: ``(height, width)`` of every returned sample.
        cls_num: Number of rotation classes. Class ``k`` is the crop passed
            through ``pil_rotate`` with an angle of ``k * 360 // cls_num``.

    Returns:
        A list of ``[image, label]`` pairs, one for each label in
        ``0 .. cls_num - 1``, or an empty list when no text region was found.
    """
    height, width = size

    # Rescale so that the longer side becomes 1080 before looking for text.
    h, w = get_best_predict_size2(img_np, threshold=1080)
    img_np = pil_resize(img_np, h, w)

    # Locate a suitable text region; from here on img_np is the grayscale crop.
    result_list, img_np = get_text_region(img_np, size)
    if not result_list:
        return []

    # The crop can differ from the requested size, so resize it when needed.
    if img_np.shape[0] != height or img_np.shape[1] != width:
        img_np = pil_resize(img_np, height, width)

    # Label 0 is the unrotated crop.
    img_label_list = [[img_np, 0]]

    # Iterate over labels instead of stepping through angles: stepping by
    # int(360 / cls_num) produced more than cls_num labels whenever cls_num
    # did not divide 360 (and a zero step when cls_num > 360).
    for label in range(1, cls_num):
        angle = label * 360 // cls_num
        img_label_list.append([pil_rotate(img_np, angle), label])

    return img_label_list
def get_text_region(img_np, size):
    """Find candidate text regions and crop the densest one from the image.

    Processing steps:
      1. Strip the dark border with ``remove_black_border``.
      2. On a copy rescaled so its longer side is 640, binarise (inverse
         threshold at 125), dilate, erode and dilate again so that nearby
         strokes merge into blobs.
      3. Take the axis-aligned bounding box of every contour, drop
         degenerate boxes and boxes spanning at least 70% of both
         dimensions, and score each remaining box by its ratio of dark
         pixels.
      4. Crop a ``size`` window from the full-resolution grayscale image,
         anchored at the top-left corner of the best-scoring box.

    Args:
        img_np: Colour image array; converted with ``cv2.COLOR_BGR2GRAY``.
        size: ``(height, width)`` of the window to crop.

    Returns:
        ``(result_list, gray_crop)``. ``result_list`` holds
        ``[[x1, y1], [x2, y2], black_ratio]`` entries in full-resolution
        coordinates, sorted by ``black_ratio`` descending. ``gray_crop`` is
        the grayscale window; it is smaller than ``size`` when the anchor is
        close to the image edge. Returns ``([], None)`` when no box survives
        the filters.
    """
    img_np = remove_black_border(img_np)
    origin_h, origin_w = img_np.shape[:2]
    gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)

    # Detection runs on a rescaled grayscale copy.
    h, w = get_best_predict_size2(img_np, threshold=640)
    small = cv2.cvtColor(pil_resize(img_np, h, w), cv2.COLOR_BGR2GRAY)

    # Inverse binarisation: dark (text) pixels become 255.
    _, binary = cv2.threshold(small, 125, 255, cv2.THRESH_BINARY_INV)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    # Dilate to make the strokes stand out.
    dilation = cv2.dilate(binary, kernel, iterations=2)
    # Inverted copy (text is dark again); used below to measure ink density.
    # bitwise_not on uint8 equals 255 - x and avoids np.int0, which NumPy 2.0
    # removed.
    ink_mask = cv2.bitwise_not(dilation)

    # Erode once to drop fine details such as table lines (the original
    # author notes this removes the vertical lines), then dilate to merge
    # neighbouring text into blobs.
    erosion = cv2.erode(dilation, kernel, iterations=1)
    dilation = cv2.dilate(erosion, kernel, iterations=3)
    # Invert the colours before extracting contours.
    dilation = cv2.bitwise_not(dilation)

    # [-2] selects the contour list on both OpenCV 3 (3 return values) and
    # OpenCV 4 (2 return values).
    contours = cv2.findContours(dilation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    mask_h, mask_w = dilation.shape[:2]
    region_list = []
    for cont in contours:
        # Four corners of the minimum-area rectangle, truncated to integers.
        pts = cv2.boxPoints(cv2.minAreaRect(cont)).astype(np.intp)
        # True min/max over all corners. The previous if/elif scan skipped
        # the max update for any corner that lowered the min, so an extreme
        # present only at the first corner was lost. Clamp to the mask
        # because boxPoints may fall slightly outside the image, and a
        # negative index would wrap around when slicing.
        x_min = max(int(pts[:, 0].min()), 0)
        y_min = max(int(pts[:, 1].min()), 0)
        x_max = min(int(pts[:, 0].max()), mask_w)
        y_max = min(int(pts[:, 1].max()), mask_h)
        region_list.append([[x_min, y_min], [x_max, y_max]])

    # Largest boxes first; the later (stable) sort keeps this order for ties.
    region_list.sort(
        key=lambda b: abs((b[1][0] - b[0][0]) * (b[1][1] - b[0][1])),
        reverse=True,
    )

    # Factors that map detection coordinates back to the full-resolution image.
    h_scale = origin_h / h
    w_scale = origin_w / w

    result_list = []
    for p1, p2 in region_list:
        box_w = p2[0] - p1[0]
        box_h = p2[1] - p1[1]
        # Ignore degenerate (empty) boxes.
        if box_w <= 0 or box_h <= 0:
            continue
        # Ignore boxes that cover most of the page in both dimensions.
        if box_w >= 0.7 * w and box_h >= 0.7 * h:
            continue
        # Share of dark pixels inside the box; used only for ranking.
        black_ratio = count_black(ink_mask[p1[1]:p2[1], p1[0]:p2[0]]) / (box_w * box_h)
        result_list.append([
            [int(p1[0] * w_scale), int(p1[1] * h_scale)],
            [int(p2[0] * w_scale), int(p2[1] * h_scale)],
            black_ratio,
        ])

    if not result_list:
        return [], None

    result_list.sort(key=lambda r: r[2], reverse=True)

    # Crop a window of the requested size anchored at the best box.
    width_start, height_start = result_list[0][0]
    height, width = size
    gray = gray[height_start:height_start + height, width_start:width_start + width]
    return result_list, gray
def gen(paths, batch_size=2, shape=(640, 640), cls_num=4, is_test=False):
    """Endlessly yield batches of rotated text crops with one-hot labels.

    Args:
        paths: Image file paths. The sequence is copied, so the caller's
            list is not reordered by the internal shuffling.
        batch_size: Samples per batch; must be a multiple of ``cls_num``.
        shape: ``(height, width)`` of every sample.
        cls_num: Number of rotation classes.
        is_test: When true, only one randomly chosen rotation is kept per
            source image and the images are yielded as a third element.

    Yields:
        ``(X, Y)``, or ``(X, Y, img_np_list)`` when ``is_test`` is true.
        ``X`` has shape ``(n, height, width, 1)`` and ``Y`` has shape
        ``(n, cls_num)``, where ``n`` is the number of samples built for the
        batch (``batch_size`` in training mode). ``img_np_list`` holds the
        images in the same order as the rows of ``X``.

    Raises:
        ValueError: ``batch_size`` is not a multiple of ``cls_num`` or
            ``paths`` is empty.
        RuntimeError: no path yields a usable image.
    """
    if batch_size % cls_num != 0:
        # Previously a bare ``raise`` with no active exception.
        raise ValueError("batch_size % cls_num != 0")
    paths = list(paths)
    if not paths:
        raise ValueError("paths is empty")
    height, width = shape

    def choose(_paths, _i):
        """Return ``(img_label_list, next_index)`` for the next usable image."""
        failures = 0
        while True:
            # 2 * len consecutive failures always include one complete pass
            # between reshuffles, so every path has been tried and rejected.
            # Without this guard the loop would spin forever.
            if failures >= 2 * len(_paths):
                raise RuntimeError("no usable image found in paths")
            if _i >= len(_paths):
                _i = 0
                np.random.shuffle(_paths)
            _p = _paths[_i]
            _i += 1
            try:
                _img_np = cv2.imread(_p)
                # cv2.imread returns None for unreadable files.
                if _img_np is None:
                    failures += 1
                    continue
                # Require height and width > 150.
                if _img_np.shape[0] <= 150 or _img_np.shape[1] <= 150:
                    failures += 1
                    continue
                # Skip images with 89478485 or more values.
                if _img_np.shape[0] * _img_np.shape[1] * _img_np.shape[2] >= 89478485:
                    failures += 1
                    continue
                _img_label_list = get_img_label(_img_np, size=(height, width), cls_num=cls_num)
            except Exception:
                # Best effort: skip images that fail to load or process.
                # ``Exception`` (not a bare except) keeps Ctrl-C working.
                failures += 1
                continue
            if not _img_label_list:
                failures += 1
                continue
            return _img_label_list, _i

    i = 0
    while True:
        X = np.zeros((batch_size, height, width, 1))
        Y = np.zeros((batch_size, cls_num))
        batch_list = []

        for _ in range(batch_size // cls_num):
            # Generate the labelled rotations for the next image.
            img_label_list, i = choose(paths, i)
            random.shuffle(img_label_list)
            if is_test:
                img_label_list = random.sample(img_label_list, 1)

            for raw_img, cls_idx in img_label_list[:cls_num]:
                img = raw_img

                # Blur augmentation; currently disabled because both choices
                # are 0.
                if_blur = random.choice([0, 0])
                if if_blur:
                    # Gaussian blur.
                    sigmaX = random.randint(1, 2)
                    sigmaY = random.randint(1, 2)
                    img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)

                # One-hot label.
                label_list = [0] * cls_num
                label_list[cls_idx] = 1
                label = np.array(label_list)

                # Ensure a single trailing channel.
                if len(img.shape) < 3:
                    img = np.expand_dims(img, axis=-1)
                img = np.expand_dims(img[:, :, 0], axis=-1)

                # Keep the source image next to its sample so that
                # img_np_list stays aligned with X after shuffling.
                batch_list.append([img, label, raw_img])

        random.shuffle(batch_list)

        # Fill by the number of samples actually built. Indexing up to
        # batch_size raised IndexError in test mode, which keeps only one
        # rotation per image.
        for row, (img, label, _) in enumerate(batch_list):
            X[row] = img
            Y[row] = label
        filled = len(batch_list)

        if is_test:
            yield X[:filled], Y[:filled], [entry[2] for entry in batch_list]
        else:
            yield X[:filled], Y[:filled]
def get_image_from_pdf(pdf_glob="C:/Users/Administrator/Desktop/test_pdf/*",
                       save_dir="D:/Project/image_direction_classification/data/1/",
                       zoom=1.3):
    """Render every page of the matching PDF files to numbered PNG images.

    Pages are written to ``save_dir`` as ``pdf_<n>.png`` with a counter that
    runs across all files. Files and pages that fail are reported on the
    console and skipped.

    Args:
        pdf_glob: Glob pattern selecting the input files.
        save_dir: Output directory prefix; it is concatenated with the file
            name as-is, so it must end with a path separator.
        zoom: Scale factor applied to both axes when rendering. According to
            the original author's notes the default render is 792x612 at
            96 dpi, 1.33333333 gives 1056x816 and 2 gives 1584x1224.

    NOTE(review): ``loadPage``, ``getPixmap``, ``preRotate`` and ``writePNG``
    are the legacy camelCase PyMuPDF names; newer releases appear to expose
    ``load_page``, ``get_pixmap``, ``prerotate`` and ``Pixmap.save`` instead
    -- confirm against the installed version.
    """
    i = 0
    for path in glob(pdf_glob):
        try:
            doc = fitz.open(path)
        except Exception:
            # format_exc() returns the text; print_exc() returns None, which
            # is what used to be printed here.
            print("pdf2Image", traceback.format_exc())
            continue

        try:
            for page_no in range(doc.page_count):
                try:
                    page = doc.loadPage(page_no)
                    output = save_dir + "pdf_" + str(i) + ".png"
                    i += 1
                    # No rotation; scale both axes by ``zoom``.
                    mat = fitz.Matrix(zoom, zoom).preRotate(0)
                    pix = page.getPixmap(matrix=mat, alpha=False)
                    pix.writePNG(output)
                except Exception:
                    # Report the page failure instead of hiding it, then go on.
                    traceback.print_exc()
                    continue
        except Exception:
            print("pdf2Image", traceback.format_exc())
        finally:
            # Release the file handle even when rendering fails.
            doc.close()
def get_best_predict_size2(image_np, threshold=640):
    """Return the ``(height, width)`` that scales the longer side to ``threshold``.

    Both dimensions are multiplied by ``threshold / max(height, width)`` and
    rounded up, so the aspect ratio is kept up to rounding.

    Args:
        image_np: Array whose first two ``shape`` entries are height and width.
        threshold: Target length for the longer side.

    Returns:
        Tuple ``(height, width)`` of integers.
    """
    src_h, src_w = image_np.shape[:2]
    ratio = threshold / max(src_h, src_w)
    new_h, new_w = (int(math.ceil(side * ratio)) for side in (src_h, src_w))
    return new_h, new_w
def count_black(image_np):
    """Count the pixels whose value lies in ``[0, 10]`` on every channel.

    Args:
        image_np: Image array, either 2-D (single channel) or 3-D with three
            channels.

    Returns:
        The number of matching pixels as an ``int``. ``0`` is returned for
        ``None`` or empty input and whenever the image cannot be processed.
    """
    try:
        if image_np is None or image_np.size == 0:
            return 0
        if len(image_np.shape) == 3:
            lower_black = np.array([0, 0, 0])
            upper_black = np.array([10, 10, 10])
        else:
            lower_black = np.array([0])
            upper_black = np.array([10])
        # inRange marks pixels inside the bounds with a non-zero value.
        mask = cv2.inRange(image_np, lower_black, upper_black)
        return int(np.sum(mask != 0))
    except Exception:
        # Deliberate best effort: an unusable region simply counts as having
        # no black pixels. ``Exception`` (rather than a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        return 0
def remove_black_border(img_np):
    """Crop away the dark border surrounding an image.

    Only the centre column and the centre row are inspected: the crop runs
    from the first to the last pixel brighter than the threshold (100) on
    each of them. As the original author notes, this cannot handle images
    where the dark area covers more than half of the picture.

    Args:
        img_np: Three-dimensional colour image array.

    Returns:
        The cropped view of ``img_np``, or ``img_np`` unchanged when the
        centre lines contain no bright pixel or the image cannot be
        processed.
    """
    threshold = 100
    try:
        # NOTE(review): the rest of this file converts with COLOR_BGR2GRAY;
        # RGB2GRAY is kept here so the crop stays identical to what was
        # produced before -- confirm which channel order is intended.
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        h, w = gray.shape[:2]

        center_col = gray[:, w // 2]
        center_row = gray[h // 2, :]
        bright_rows = np.argwhere(center_col > threshold)
        bright_cols = np.argwhere(center_row > threshold)
        if bright_rows.size == 0 or bright_cols.size == 0:
            return img_np

        top, bottom = bright_rows[0, 0], bright_rows[-1, 0]
        left, right = bright_cols[0, 0], bright_cols[-1, 0]
        # +1 because slice ends are exclusive; without it the last bright
        # row and column were cut off.
        return img_np[top:bottom + 1, left:right + 1, :]
    except Exception:
        # Best effort: fall back to the untouched image. ``Exception``
        # (rather than a bare except) lets KeyboardInterrupt propagate.
        return img_np
if __name__ == '__main__':
    # Smoke test: build the generator over the three data folders and pull a
    # single batch from it.
    # data/0: table images, data/1: pages rendered from PDFs,
    # data/2: images without tables.
    paths = []
    for pattern in ("data/0/*", "data/1/*", 'data/2/*'):
        paths.extend(glob(pattern))

    _gen = gen(paths, batch_size=32, shape=(192, 192))
    for _batch in _gen:
        print("ok")
        break
|