create_real_data_from_csv.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
import copy
import html
import io
import json
import os
import random
import time

import cv2
import imgkit
import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from PIL import Image
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

from create_labelme_data import create_lines_labelme, create_official_seal
  14. DESKTOP_PATH = r"C:\Users\Administrator\Desktop"
  15. def read_csv(path):
  16. df = pd.read_csv(path, encoding='gbk')
  17. df = df[:2000]
  18. _list = []
  19. for index, row in df.iterrows():
  20. if type(row['json_table']) == float:
  21. print(row['json_table'])
  22. continue
  23. _list.append(json.loads(row['json_table']))
  24. # print(_list)
  25. return _list
  26. def create_html_table(_list, html_dir, count=0):
  27. html_path_list = []
  28. i = count
  29. for table in _list:
  30. table_text = '<table id="table0" border="1px" cellspacing="0" cellpadding="0">' + "\n"
  31. for row in table:
  32. table_text += "<tr>" + "\n"
  33. for col in row:
  34. table_text += "<td>" + col[0] + "</td>" + "\n"
  35. table_text += "</tr>" + "\n"
  36. table_text += "</table>" + "\n"
  37. css_text = '<style type="text/css"> html{width=1000px;height=300px} ' \
  38. 'body{margin:5px} ' \
  39. 'table{border-collapse: collapse;text-align:left;} ' \
  40. '</style>'
  41. # 虚线框
  42. # css_text = '<style type="text/css"> ' \
  43. # 'html{width=1000px;height=300px} body{margin:5px} ' \
  44. # 'td{border-style: dashed;} ' \
  45. # 'table{border-collapse: collapse;text-align:right;} ' \
  46. # '</style>'
  47. # 粗虚线 上下或左右无边线
  48. # css_text = '<style type="text/css"> ' \
  49. # 'html{width=1000px;height=300px} body{margin:5px} ' \
  50. # 'td{border-style: dashed; border-left:0px;} ' \
  51. # 'table{border-collapse: collapse;} ' \
  52. # '</style>'
  53. # css_text = '<style> .divcss5{ width:500px}.divcss5{ height:500px}</style>'
  54. html_text = "<!DOCTYPE html>" + "\n" \
  55. + "<html>" + "\n" \
  56. + css_text + "\n" \
  57. + '<head><meta charset="UTF-8"></head>' + "\n" \
  58. + '<div class="divcss5">' + "\n" \
  59. + table_text + "\n" \
  60. + '</div>' + "\n" \
  61. + "</html>"
  62. html_path = html_dir + str(i) + ".html"
  63. with open(html_path, 'w') as f:
  64. f.write(html_text)
  65. html_path_list.append(html_path)
  66. i += 1
  67. return html_path_list
  68. def get_table_size(html_path):
  69. print(html_path)
  70. chrome_options = Options()
  71. chrome_options.add_argument('--headless')
  72. chrome_options.add_argument('--no-sandbox')
  73. driver_path = "./chromedriver.exe"
  74. driver = webdriver.Chrome(executable_path=driver_path,
  75. chrome_options=chrome_options)
  76. # driver = webdriver.Chrome()
  77. driver.maximize_window()
  78. driver.implicitly_wait(0)
  79. driver.get(html_path)
  80. # JS
  81. js = """
  82. _t = document.getElementById('table0')
  83. function myFunc(){
  84. var cells_size = new Array();
  85. for (i = 0; i<_t.rows.length; i++){
  86. var rows_size = new Array();
  87. for (j = 0; j < _t.rows[i].cells.length; j++){
  88. col = _t.rows[i].cells[j];
  89. col_width = window.getComputedStyle(col).width;
  90. col_height = window.getComputedStyle(col).height;
  91. rows_size[j] = [col_width, col_height];
  92. }
  93. cells_size[i] = rows_size
  94. }
  95. return cells_size;
  96. }
  97. return myFunc();
  98. """
  99. rows_list = driver.execute_script(js)
  100. # print(rows_list)
  101. # print(type(rows_list))
  102. js = """
  103. _t = document.getElementById('table0')
  104. function myFunc(){
  105. table_width = window.getComputedStyle(_t).width;
  106. table_height = window.getComputedStyle(_t).height;
  107. var table_size = [table_width, table_height];
  108. return table_size
  109. }
  110. return myFunc()
  111. """
  112. table_size = driver.execute_script(js)
  113. table_size = [int(float(table_size[0][:-2])), int(float(table_size[1][:-2]))]
  114. print("table_size", table_size)
  115. # get_table_lines(rows_list, table_size)
  116. # table = driver.find_element_by_id('table0')
  117. # print("rows_list", rows_list)
  118. return rows_list, table_size
  119. def html2image1(html_path, table_size):
  120. # 工具路径
  121. path_wkimg = r'D:\Software\html_to_pdf\wkhtmltopdf\bin\wkhtmltoimage.exe'
  122. cfg = imgkit.config(wkhtmltoimage=path_wkimg)
  123. options = {
  124. 'width': table_size[0],
  125. 'height': table_size[1],
  126. 'encoding': 'UTF-8',
  127. }
  128. # 1、将html文件转为图片
  129. image_path = html_path.split(".")[0]+".jpg"
  130. print("html2image", image_path)
  131. imgkit.from_file(html_path, image_path, config=cfg, options=options)
  132. # 2、从url获取html,再转为图片
  133. # imgkit.from_url('https://httpbin.org/ip', 'ip.jpg', config=cfg)
  134. # 3、将字符串转为图片
  135. # imgkit.from_string('Hello!','hello.jpg', config=cfg)
  136. return image_path
  137. def html2image(html_path, table_size):
  138. # 将html文件转为图片
  139. image_path = html_path.split(".")[0]+".jpg"
  140. print("html2image", image_path)
  141. chrome_options = Options()
  142. chrome_options.add_argument('--headless')
  143. chrome_options.add_argument('--no-sandbox')
  144. broswer = webdriver.Chrome(executable_path="./chromedriver.exe",
  145. chrome_options=chrome_options)
  146. broswer.get(html_path)
  147. # 截全图
  148. width = broswer.execute_script("return document.documentElement.scrollWidth")
  149. height = broswer.execute_script("return document.documentElement.scrollHeight")
  150. broswer.set_window_size(width, height)
  151. broswer.save_screenshot(image_path)
  152. table = broswer.find_element_by_id('table0')
  153. left = table.location['x']
  154. top = table.location['y']
  155. elementWidth = table.location['x'] + table.size['width']
  156. elementHeight = table.location['y'] + table.size['height']
  157. picture = Image.open(image_path)
  158. picture = picture.crop((left, top, elementWidth, elementHeight))
  159. picture = picture.convert("RGB")
  160. picture.save(image_path)
  161. return image_path
  162. def get_table_lines(rows_list):
  163. row_line_list = []
  164. col_line_list = []
  165. x = 0
  166. y = 0
  167. # 横线line
  168. width = 0
  169. height = 0
  170. i = 0
  171. for row in rows_list:
  172. if i == 0:
  173. for col in row:
  174. width += float(col[0][:-2]) + 1
  175. row_line_list.append([[x, y], [x+width, y]])
  176. height += float(row[0][1][:-2]) + 1
  177. else:
  178. row_line_list.append([[x, y+height], [x+width, y+height]])
  179. height += float(row[0][1][:-2]) + 1
  180. i += 1
  181. row_line_list.append([[x, y+height], [x+width, y+height]])
  182. # 竖线line
  183. width = 0
  184. height = 0
  185. i = 0
  186. for col_num in range(len(rows_list)):
  187. height += float(rows_list[col_num][0][1][:-2]) + 1
  188. # print("height", height)
  189. for row in rows_list:
  190. if i == 0:
  191. col_line_list.append([[x, y], [x, y+height]])
  192. for col in row:
  193. width += float(col[0][:-2]) + 1
  194. col_line_list.append([[x+width, y], [x+width, y+height]])
  195. break
  196. # print("row_line_list", row_line_list)
  197. # print("col_line_list", col_line_list)
  198. # draw_lines(row_line_list+col_line_list, table_size, )
  199. return row_line_list+col_line_list
  200. def draw_lines(line_list, table_size, image_path, expand=False):
  201. img = np.zeros((table_size[1], table_size[0]), np.uint8)
  202. img.fill(255)
  203. if expand:
  204. image_origin = cv2.imread(image_path)
  205. # print(image_origin.shape, img.shape)
  206. # expand_height = int((image_origin.shape[1] - img.shape[1]) / 2)
  207. # expand_width = int((image_origin.shape[0] - img.shape[0]) / 2)
  208. # img = cv2.copyMakeBorder(img, expand_height, expand_height, expand_width,
  209. # expand_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
  210. img = np.zeros((image_origin.shape[0], image_origin.shape[1]), np.uint8)
  211. print(image_origin.shape, img.shape)
  212. img.fill(255)
  213. for line in line_list:
  214. cv2.line(img, (int(line[0][0]), int(line[0][1])),
  215. (int(line[1][0]), int(line[1][1])), (0, 0, 255), 1)
  216. cv2.imwrite(image_path.split(".")[0] + ".png", img)
  217. # cv2.imshow("label", img)
  218. # cv2.waitKey(0)
  219. # image = cv2.imread(image_path)
  220. # cv2.imshow("image", image)
  221. # cv2.waitKey(0)
  222. return line_list
  223. def image_expand(image_path, line_list):
  224. # image_np = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
  225. # 随机选择边缘扩充px
  226. # expand_height = random.randint(0, 500)
  227. # expand_width = int(expand_height / 1.3333)
  228. image_origin = cv2.imread(image_path)
  229. expand_height = int((1123 - image_origin.shape[0]) / 2)
  230. expand_width = int((794 - image_origin.shape[1]) / 2)
  231. if expand_width < 0:
  232. expand_width = 0
  233. if expand_height < 0:
  234. expand_height = 0
  235. # print(expand_height, expand_width)
  236. # 图像边缘扩充
  237. image_np = cv2.imread(image_path)
  238. image_np = cv2.copyMakeBorder(image_np, expand_height, expand_height, expand_width,
  239. expand_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
  240. cv2.imwrite(image_path, image_np)
  241. # 线条坐标全部加上增加的宽高
  242. new_line_list = []
  243. for line in line_list:
  244. new_line_list.append([[line[0][0]+expand_width, line[0][1]+expand_height],
  245. [line[1][0]+expand_width, line[1][1]+expand_height]])
  246. return image_path, new_line_list
  247. def create_table_line_data():
  248. csv_path = "D:\\BIDI_DOC\\比地_文档\\websource_67000_table.csv"
  249. table_list = read_csv(csv_path)
  250. i = 500
  251. stop_i = 700
  252. table_list = table_list[:10]
  253. save_dir = "D:\\Project\\table-detect-master\\data_process\\create_data\\"
  254. html_path_list = create_html_table(table_list, save_dir, i)
  255. for html in html_path_list:
  256. print("Loop", i)
  257. rows_list, table_size = get_table_size(html)
  258. image_path = html2image(html, table_size)
  259. line_list = get_table_lines(rows_list)
  260. # 图片扩展
  261. # image_path, line_list = image_expand(image_path, line_list)
  262. # 添加公章
  263. image_np = cv2.imread(image_path)
  264. image_np = create_official_seal(image_np)
  265. cv2.imshow("image", image_np)
  266. cv2.waitKey(0)
  267. cv2.imwrite(image_path, image_np)
  268. with open(image_path, 'rb') as f:
  269. image_bytes = f.read()
  270. image_np = cv2.imread(image_path)
  271. labelme = create_lines_labelme(line_list, image_bytes, image_np.shape[1], image_np.shape[0])
  272. with open('../train/dataset-line/7/train_' + str(i) + '.json', 'w') as f:
  273. json.dump(labelme, f)
  274. draw_lines(line_list, table_size, image_path, False)
  275. i += 1
  276. if i > stop_i:
  277. break
  278. def html_table2label_table(html_path):
  279. with open(html_path, "r") as f:
  280. html_text = f.read()
  281. soup = BeautifulSoup(html_text)
  282. tr_list = []
  283. for tr in soup.findAll('tr'):
  284. # get td text
  285. td_list = []
  286. for td in tr.findAll('td'):
  287. td_list.append(td.getText())
  288. tr_list.append(td_list)
  289. # 一列处理成相同长度
  290. padding_tr_list = copy.deepcopy(tr_list)
  291. for col_num in range(len(tr_list[0])):
  292. one_col_multi_row = []
  293. # 拿到该列所有行
  294. for row_num in range(len(tr_list)):
  295. one_col_multi_row.append(tr_list[row_num][col_num])
  296. # padding
  297. max_col_len = len(sorted(one_col_multi_row, key=lambda x: len(x))[-1])
  298. print(one_col_multi_row)
  299. print(max_col_len)
  300. for i in range(len(one_col_multi_row)):
  301. text = one_col_multi_row[i]
  302. if len(text) >= max_col_len:
  303. padding_tr_list[i][col_num] = text
  304. continue
  305. if (max_col_len - len(text)) % 2 == 1:
  306. left_space_num = int((max_col_len - len(text)) / 2)
  307. right_space_num = left_space_num + 1
  308. new_text = " " * left_space_num + text + " " * right_space_num
  309. else:
  310. space_num = int((max_col_len - len(text)) / 2)
  311. new_text = " " * space_num + text + " " * space_num
  312. padding_tr_list[i][col_num] = new_text
  313. # print(padding_tr_list)
  314. # labeled html table
  315. prefix = """
  316. <!DOCTYPE html>
  317. <html>
  318. <style type="text/css"> html{width=1000px;height=300px} body{margin:5px} table{border-collapse: collapse;text-align:left;} </style>
  319. <head><meta charset="UTF-8"></head>
  320. <div class="divcss5">
  321. <table id="table0" border="0px" cellspacing="0" cellpadding="0">\n
  322. """
  323. label_table_text = prefix
  324. for tr in padding_tr_list:
  325. label_table_text += "<tr>" + "\n"
  326. for td in tr:
  327. label_table_text += "<td>" + "|" + td + "|" + "</td>"
  328. label_table_text += "</tr>" + "\n"
  329. label_table_text += "</table>" + "\n"
  330. label_table_text += "</div>" + "\n"
  331. label_table_text += "</html>" + "\n"
  332. with open(html_path+".html", "w") as f:
  333. f.write(label_table_text)
  334. def create_table_label_data():
  335. csv_path = "D:\\BIDI_DOC\\比地_文档\\websource_67000_table.csv"
  336. table_list = read_csv(csv_path)
  337. start_no = 500
  338. table_list = table_list[:1]
  339. save_dir = "D:\\Project\\table-detect-master\\data_process\\create_data\\"
  340. html_path_list = create_html_table(table_list, save_dir, start_no)
  341. for html in html_path_list:
  342. html_table2label_table(html)
  343. if __name__ == '__main__':
  344. create_table_label_data()