import os
import random
import re

import cv2
import numpy as np
import pandas as pd
# import psycopg2
from bs4 import BeautifulSoup
from PIL import Image, ImageDraw, ImageFont

# project_path = "D:\\Project\\PaddleOCR-release-2.0\\"
project_path = "../../"
image_output_path = project_path + "train_data/bidi_data/mix_data4/"
train_data_path = image_output_path + "rec_gt_train.txt"
test_data_path = image_output_path + "rec_gt_test.txt"


def create_image(data_dir, file_name, text):
    # list1 = re.findall('[a-zA-Z\d]', text)
    # list2 = re.findall('[\u4e00-\u9fa5。,!?¥《》【】’“:;·、()]', text)
    # list3 = re.findall('[,.!?&@*+=~%()#<>-''|/:{}$;]', text)
    # english_len = len(list1)
    # chinese_len = len(list2)
    # character_len = len(list3)
    #
    # if english_len + chinese_len + character_len == 0:
    #     character_len = len(text)

    # Generate images sized to the chosen font.
    # font 10: latin width ~6, CJK width ~10, image height 16
    # font 20: latin width ~12, CJK width ~20, image height 32
    font_list = [7, 8, 9, 10, 11, 12, 15, 20, 25, 30, 35, 40]
    # Pick a font size at random.
    font_index = random.randint(0, len(font_list) - 1)
    font_size = font_list[font_index]

    # Estimate per-character pixel widths from the font size.
    # chinese_charc_len = font * 1
    # english_charc_len = int(font * 0.7)
    # number_charc_len = int(font * 0.3)
    # image_width = int(font * 1.6)
    # text_len = english_len * english_charc_len + chinese_len * chinese_charc_len \
    #            + character_len * number_charc_len

    # Load the font and measure the rendered text in pixels.
    # (font.getsize was removed in Pillow 10; this code assumes Pillow < 10.)
    font_type = random.choice(['msyh.ttc', 'msyhbd.ttc', 'msyhl.ttc'])
    font = ImageFont.truetype("tools/fonts/" + font_type, font_size)
    font_width = font.getsize(text)[0]
    font_height = font.getsize(text)[1]

    # Randomly add top/bottom/left/right margins.
    margin_h = 0
    margin_w = 0
    if font_height > 9:
        if random.choice([0, 1]):
            margin_h = random.randint(3, 6)
            margin_w = random.randint(3, 6)

    im = Image.new("RGB", (font_width + margin_w * 2, font_height + margin_h * 2), (255, 255, 255))
    dr = ImageDraw.Draw(im)
    dr.text((margin_w, margin_h), text, font=font, fill="#000000")

    # if random.choice([0, 1, 1]):
    #     # PIL -> CV2
    #     img = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)
    #     # Underline the text.
    #     img = cv2.line(img, (margin_w, font_height + margin_h),
    #                    (font_width + margin_w, font_height + margin_h),
    #                    (0, 0, 0), 1)
    #     # CV2 -> PIL
    #     im = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    #
    #     # Random rescale.
    #     resize_y = random.randint(1, 2)
    #     resize_x = random.randint(1, 2)
    #     img = cv2.resize(img, (img.shape[1] * resize_y, img.shape[0] * resize_x))
    #
    #     # Gaussian blur.
    #     sigmaX = random.randint(1, 3)
    #     sigmaY = random.randint(1, 3)
    #     img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
    #
    #     # resize_y = random.uniform(1, 3)
    #     # resize_x = random.uniform(1, 3)
    #     # img = im.resize((int(im.size[0] * resize_y), int(im.size[1] * resize_x)), Image.ANTIALIAS)
    #
    #     # Save.
    #     # cv2.imwrite(data_dir + file_name, img)

    # Image augmentation.
    im = my_image_aug(im)

    # im.show("img")
    im.save(data_dir + file_name)
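
# A minimal usage sketch (demo_create_image is not part of the original
# pipeline): renders one labelled sample. It assumes the script is run from
# the repository root so that tools/fonts/msyh*.ttc resolve, and it creates
# the train/ sub-directory that create_image itself does not create.
def demo_create_image():
    os.makedirs(image_output_path + "train/", exist_ok=True)
    create_image(image_output_path, "train/text_demo.jpg", "¥1,234.56")
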
def my_image_aug(image_pil):
    # Image augmentation.
    # PIL -> CV2
    img = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)

    # Skip augmentation for low-resolution images.
    if img.shape[0] > 20:
        # Random rescale.
        if random.choice([0, 1]):
            resize_y = random.randint(1, 3)
            resize_x = random.randint(1, 3)
            img = cv2.resize(img, (img.shape[1] * resize_y, img.shape[0] * resize_x))

        if_dilate = random.choice([0, 1, 1, 2, 2, 2])
        if if_dilate == 2:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
            iters = 1
            if img.shape[0] > 40:
                iters = 2
            dilate = cv2.dilate(gray, kernel, iterations=iters)
            img = cv2.cvtColor(dilate, cv2.COLOR_GRAY2BGR)
        elif if_dilate == 1:
            # Gaussian blur.
            sigmaX = random.randint(1, 2)
            sigmaY = random.randint(1, 2)
            img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)

    # CV2 -> PIL
    im = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return im


def create_orgs_image(df):
    df = df[:1000]
    label_file_train = project_path + "train_data\\bidi_data\\orgs_data\\rec_gt_train.txt"
    label_file_test = project_path + "train_data\\bidi_data\\orgs_data\\rec_gt_test.txt"
    image_output_path = project_path + "train_data\\bidi_data\\orgs_data\\"
    f1 = open(label_file_train, "w")
    f2 = open(label_file_test, "w")
    print(df.shape)
    for index, row in df.iterrows():
        text = row["name"]
        # text = "晋江滨江国家体育训练基地有限公司"
        im = Image.new("RGB", (len(text) * 10, 16), (255, 255, 255))
        dr = ImageDraw.Draw(im)
        font = ImageFont.truetype(os.path.join(os.getcwd(), "fonts", "msyh.ttc"), 10)
        dr.text((0, 0), text, font=font, fill="#000000")
        # im.show()

        if index / df.shape[0] <= 0.8:
            mode = "train"
            f = f1
        else:
            mode = "test"
            f = f2
        im.save(image_output_path + mode + "\\" + "text_" + str(index) + ".jpg")
        f.write(mode + "/text_" + str(index) + ".jpg" + "\t" + text + "\n")
    f1.close()
    f2.close()


def create_longSentence_image(df):
    # df = df[:3000]
    label_file_train = project_path + "train_data\\bidi_data\\longSentence_data\\rec_gt_train.txt"
    label_file_test = project_path + "train_data\\bidi_data\\longSentence_data\\rec_gt_test.txt"
    image_output_path = project_path + "train_data\\bidi_data\\longSentence_data\\"
    f1 = open(label_file_train, "w")
    f2 = open(label_file_test, "w")
    print(df.shape)
    for index, row in df.iterrows():
        text = row["text"]
        # text = "晋江滨江国家体育训练基地有限公司"
        im = Image.new("RGB", (len(text) * 10, 16), (255, 255, 255))
        dr = ImageDraw.Draw(im)
        font = ImageFont.truetype(os.path.join(os.getcwd(), "fonts", "msyh.ttc"), 10)
        dr.text((0, 0), text, font=font, fill="#000000")
        # im.show()

        if index <= int((df.shape[0] - 1) * 0.8):
            mode = "train"
            f = f1
        else:
            mode = "test"
            f = f2
        im.save(image_output_path + mode + "\\" + "text_" + str(index) + ".jpg")
        f.write(mode + "/text_" + str(index) + ".jpg" + "\t" + text + "\n")
    f1.close()
    f2.close()


# def readPostgreSQL():
#     conn_string = "host=192.168.2.101 port=5432 dbname=iepy " \
#                   "user=iepy_read password=iepy_read"
#     conn = psycopg2.connect(conn_string)
#
#     # Execute the SQL query.
#     sql = "select text from corpus_iedocument " \
#           "where jump_signal=0"
#     df = pd.read_sql(sql, conn)
#     return df


# Generate mixed data covering multiple scenarios.
def create_mix_txt():
    # Maximum string length (only used by the commented-out create_text_list path).
    max_length = 100
    # list1 = create_text_list(max_length)
    list1 = create_price(3000000)
    print("finish get list1", len(list1))
    # list2 = create_org_list()
    list2 = get_long_sentence_from_file(1000000)
    print("finish get list2", len(list2))
    # list2 = list2[0:100]
    with open("appendix_text.txt", "r") as f:
        list3 = f.readlines()
    list3 = list3[:2000000]
    print("finish get list3", len(list3))
    list4 = create_org_list()
    list4 = list4[:3000000]
    print("finish get list4", len(list4))

    # 95/5 train/test split, taken per source list.
    train_data = list1[0:int(len(list1) * 0.95)] + list2[0:int(len(list2) * 0.95)] + \
                 list3[0:int(len(list3) * 0.95)] + list4[0:int(len(list4) * 0.95)]
    test_data = list1[int(len(list1) * 0.95):] + list2[int(len(list2) * 0.95):] + \
                list3[int(len(list3) * 0.95):] + list4[int(len(list4) * 0.95):]
    print("len(train_data)", len(train_data))
    print("len(test_data)", len(test_data))

    # data_index keeps counting across train and test so image names stay unique.
    data_index = 0
    with open(train_data_path, "w") as f:
        for data in train_data:
            prefix = "train/text_" + str(data_index) + ".jpg" + "\t"
            data = prefix + data
            f.write(data)
            data_index += 1
    print("finish write train data")
    with open(test_data_path, "w") as f:
        for data in test_data:
            prefix = "test/text_" + str(data_index) + ".jpg" + "\t"
            data = prefix + data
            f.write(data)
            data_index += 1
    print("finish write test data")
    return
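
# Hedged sketch mirroring the commented-out block in __main__:
# render_images_from_labels is a hypothetical helper, not part of the original
# pipeline. After create_mix_txt() has written "path\ttext" label lines, the
# images themselves can be produced by feeding each line back through
# create_image. Assumes data_dir and its train/ or test/ sub-directory exist.
def render_images_from_labels(label_path, data_dir):
    with open(label_path, "r") as f:
        for line in f.readlines():
            file_name, text = line[:-1].split("\t", 1)
            create_image(data_dir, file_name, text)

# Example: render_images_from_labels(train_data_path, image_output_path)
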
# def create_text_list(max_length):
#     # Sentences from bidding/tender articles.
#     df1 = readPostgreSQL()
#     list1 = []
#     for index, row in df1.iterrows():
#         text = row["text"].split(",")
#         # print(len(text))
#
#         # Take at most max_sentence sentences per article.
#         max_sentence = 15
#         sentence_count = 0
#         while sentence_count < max_sentence:
#             if len(text) <= max_sentence:
#                 if sentence_count < len(text):
#                     sentence = text[sentence_count]
#                 else:
#                     break
#             else:
#                 r1 = random.randint(0, len(text) - 1)
#                 sentence = text[r1]
#             if len(sentence) > max_length:
#                 # Cap the length; randomly keep the head or the tail.
#                 r2 = random.randint(0, 1)
#                 if r2:
#                     sentence = sentence[:max_length]
#                 else:
#                     sentence = sentence[-max_length:]
#
#             # sentence = re.sub("\n", "", sentence)
#             if sentence != "":
#                 list1.append(sentence + "\n")
#             sentence_count += 1
#     print("len(list1)", len(list1))
#     return list1


def delete_image(data_dir, file_name):
    if os.path.exists(data_dir + file_name):
        os.remove(data_dir + file_name)


def create_org_list():
    # 10 million company names.
    with open("C:\\Users\\Administrator\\Desktop\\LEGAL_ENTERPRISE.txt", "r") as f:
        list2 = f.readlines()
    # list2 = list2[:100]
    # print("len(list2)", len(list2))
    return list2


def create_number_list(number):
    no_list = []
    for i in range(number):
        # Randomly choose how many decimal places to generate.
        decimal_place = random.choice([0, 1, 2, 3, 4, 5, 6])
        if decimal_place == 0:
            no = random.randint(0, 10000000)
        else:
            no = random.uniform(0, 10000)
            no = round(no, decimal_place)
        no_list.append(str(no) + "\n")
    # print(no_list)
    return no_list


def get_mix_data_from_file(number):
    with open("../../train_data/bidi_data/mix_data/rec_gt_train.txt") as f:
        _list = f.readlines()
    _list = _list[:number]
    new_list = []
    for line in _list:
        s = line.split("\t")[1]
        new_list.append(s)
    # print(new_list)
    return new_list


def get_long_sentence_from_file(number):
    with open("../../train_data/bidi_data/longSentence_data/rec_gt_train.txt") as f:
        list1 = f.readlines()
    with open("../../train_data/bidi_data/longSentence_data/rec_gt_test.txt") as f:
        list2 = f.readlines()
    _list = list1 + list2
    _list = _list[:number]
    new_list = []
    for line in _list:
        s = line.split("\t")[1]
        new_list.append(s)
    # print(new_list)
    return new_list


def get_data_from_appendix():
    df = pd.read_excel("dochtmlcon.xlsx")
    text_list = []
    for index, row in df.iterrows():
        html_text = row["dochtmlcon"]
        # Build a BeautifulSoup parse tree.
        soup = BeautifulSoup(html_text, "html.parser")
        # Grab the appendix rich-text blocks.
        appendix_text = soup.find_all('div', class_='richTextFetch')
        # print(str(appendix_text[0])[49:-6])
        # Strip the opening <div ...> tag and closing </div> (hard-coded offsets).
        appendix_text = str(appendix_text[0])[49:-6]
        ss = appendix_text.split("\n")
        for s in ss:
            text = re.sub(" ", "", s)
            text = re.sub("\t", "", text)
            if text == "":
                continue
            text_list.append(text + "\n")
    with open("appendix_text.txt", "w") as f:
        f.writelines(text_list)
    return


def get_data_from_paddle():
    path = "D:\\DataSet\\"
    with open(path + "char.txt", "r") as f:
        dictionary = f.readlines()
    with open(path + "data_train.txt") as f:
        train_list = f.readlines()
    with open(path + "data_test.txt") as f:
        test_list = f.readlines()
    data_list = train_list + test_list
    # data_list = data_list[-100:]

    text_list = []
    for data in data_list:
        ss = data[:-1].split(" ")
        image_path = "image/" + ss[0]
        # Map the char-index sequence back to text through the dictionary.
        text = ""
        for num in ss[1:]:
            char = dictionary[int(num)][:-1]
            text += char
        if text == "":
            print("no text!")
            continue
        text_list.append(image_path + "\t" + text + "\n")
    with open("paddle_data.txt", "w") as f:
        f.writelines(text_list)


def create_number_list2(number):
    no_list = []
    for i in range(number):
        c1 = random.choice([0, 1, 1])
        if c1:
            no = random.randint(0, 10)
        else:
            no = random.randint(10, 100)
        no_list.append(str(no) + "\n")
    # print(no_list)
    return no_list
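
# Hedged, illustrative sketch of the extraction step used in
# get_data_from_appendix: BeautifulSoup pulls the <div class="richTextFetch">
# block out of a document's HTML. The sample HTML string here is made up, and
# demo_appendix_extraction is not part of the original pipeline; get_text()
# stands in for the hard-coded str(...)[49:-6] slicing used above.
def demo_appendix_extraction():
    html_text = '<html><body><div class="richTextFetch">line one\nline two</div></body></html>'
    soup = BeautifulSoup(html_text, "html.parser")
    block = soup.find_all('div', class_='richTextFetch')[0]
    print(block.get_text().split("\n"))  # ['line one', 'line two']
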
def create_price(number):
    no_list = []
    for i in range(number):
        # Choose integer or decimal.
        c1 = random.choice([0, 1, 1])
        if c1:
            no = random.randint(10, 1000000000)
            no = str(no)
            # Insert a separator every 3 digits;
            # pick a half-width or full-width comma.
            c2 = random.choice([',', ',', ','])
            for j in range(len(no) - 3, 0, -3):
                no = no[:j] + c2 + no[j:]
        else:
            no = random.uniform(10, 1000000000)
            no = str(no)
            nos = no.split(".")
            no_1 = nos[0]
            no_2 = nos[1]
            # Insert a separator every 3 digits of the integer part;
            # pick a half-width or full-width comma.
            c2 = random.choice([',', ',', ','])
            for j in range(len(no_1) - 3, 0, -3):
                no_1 = no_1[:j] + c2 + no_1[j:]
            no = no_1 + "." + no_2
        # Optionally prefix a ¥ sign.
        c3 = random.choice(['', "¥"])
        no_list.append(c3 + str(no) + "\n")
    # print(no_list)
    return no_list


if __name__ == '__main__':
    # df = pd.read_csv('C:\\Users\\Administrator\\Desktop\\orgs.csv')
    # create_longSentence_image(df)

    # s = 100
    # image = cv2.imread("text_384178.jpg")
    # print(image.shape)

    # list1 = create_text_list(100)
    # for l in list1:
    #     print(l)
    # print(len(list1))

    create_mix_txt()

    # get_mix_data_from_file(2)
    # create_number_list(10)

    # with open("../../train_data/bidi_data/orgs_data/rec_gt_test.txt", "r") as f:
    #     _list = f.readlines()
    # for line in _list:
    #     _str = line.split("\t")[-1][:-1]
    #     print(_str, type(_str))
    #     create_image("../../train_data/bidi_data/mix_data2/", "", _str)

    # get_data_from_appendix()
    # get_data_from_paddle()

    # delete_image("../../train_data/bidi_data/mix_data/", "train/text_0.jpg")

    # Append PaddleOCR data to the label files. ###########################
    with open("paddle_data.txt", "r") as f:
        list1 = f.readlines()
    print(len(list1))
    list1 = list1[:1000000]

    with open(train_data_path, "r") as f:
        list2 = f.readlines()
    train_data_list = list2 + list1[0:int(len(list1) * 0.95)]
    with open(train_data_path, "w") as f:
        f.writelines(train_data_list)

    with open(test_data_path, "r") as f:
        list3 = f.readlines()
    test_data_list = list3 + list1[int(len(list1) * 0.95):]
    with open(test_data_path, "w") as f:
        f.writelines(test_data_list)
    #######################################################################

    # no_list = create_number_list3(2000000)
    # i = 23000000
    # train_list = []
    # for no in no_list:
    #     train_list.append("train/text_" + str(i) + ".jpg" + "\t" + no)
    #     i += 1
    # # print(train_list)
    #
    # with open(train_data_path, "r") as f:
    #     list3 = f.readlines()
    # _list = list3[:-2000000] + train_list
    # with open(train_data_path, "w") as f:
    #     f.writelines(_list)
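

# Worked example of the thousands-separator loop in create_price (a sketch,
# not part of the pipeline): for "1234567" the loop visits j = 4, then j = 1,
# producing "1,234,567".
def demo_thousands_separator(digits="1234567"):
    for j in range(len(digits) - 3, 0, -3):
        digits = digits[:j] + "," + digits[j:]
    return digits  # "1,234,567"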