create_real_data_from_csv.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  import json
  import os
  import random
  import time
  from itertools import accumulate

  import cv2
  import imgkit
  import numpy as np
  import pandas as pd
  from PIL import Image
  from selenium import webdriver
  from selenium.webdriver.chrome.options import Options

  from create_labelme_data import create_lines_labelme, create_official_seal
  # Local Windows desktop folder; not used by the active code in this script.
  DESKTOP_PATH = "C:\\Users\\Administrator\\Desktop"
  def read_csv(path, limit=2000, encoding='gbk'):
      """Load the table structures stored in the ``json_table`` column of a CSV.

      Args:
          path: Path of the CSV file.
          limit: Only the first ``limit`` rows of the file are considered;
              ``None`` reads every row. Defaults to the previously hard-coded 2000.
          encoding: Text encoding of the CSV file (GBK by default).

      Returns:
          One ``json.loads``-decoded value per non-empty ``json_table`` cell,
          in file order.
      """
      df = pd.read_csv(path, encoding=encoding)
      tables = []
      # Only one column is needed, so iterate it directly instead of iterrows().
      for value in df['json_table'].iloc[:limit]:
          # pandas stores empty cells as float NaN: report and skip them.
          if isinstance(value, float):
              print(value)
              continue
          tables.append(json.loads(value))
      return tables
  def create_html_table(_list, html_dir, count=0):
      """Render each table as a standalone HTML page and write it to disk.

      Args:
          _list: Tables to render. A table is a list of rows, a row is a list
              of cells, and the text shown for a cell is ``cell[0]``.
          html_dir: Output prefix; each file is written to
              ``html_dir + "<n>.html"``, so a directory should be passed with
              a trailing path separator.
          count: Number used in the first file name; later tables get
              consecutive numbers.

      Returns:
          The written HTML file paths, in the same order as ``_list``.
      """
      # NOTE(review): ``width=1000px;height=300px`` uses '=' instead of ':',
      # so browsers discard those declarations. Kept unchanged so the rendered
      # layout (and therefore the generated dataset) does not shift.
      # Variants tried earlier: dashed cell borders (``td{border-style: dashed;}``),
      # optionally without left borders (``border-left:0px``).
      css_text = '<style type="text/css"> html{width=1000px;height=300px} ' \
                 'body{margin:5px} ' \
                 'table{border-collapse: collapse;text-align:left;} ' \
                 '</style>'
      html_path_list = []
      for i, table in enumerate(_list, start=count):
          html_text = "".join([
              "<!DOCTYPE html>\n",
              "<html>\n",
              css_text + "\n",
              '<head><meta charset="UTF-8"></head>\n',
              '<div class="divcss5">\n',
              _table_to_html(table) + "\n",
              "</div>\n",
              "</html>",
          ])
          html_path = html_dir + str(i) + ".html"
          # Write UTF-8 explicitly so the bytes match <meta charset="UTF-8">;
          # the platform default (e.g. GBK on Chinese-locale Windows) would
          # make the browser render CJK cell text as mojibake.
          with open(html_path, 'w', encoding='utf-8') as f:
              f.write(html_text)
          html_path_list.append(html_path)
      return html_path_list


  def _table_to_html(table):
      """Return the ``<table id="table0">`` markup for one table, HTML-escaping cell text."""
      from html import escape  # local import keeps the name ``html`` out of module globals

      parts = ['<table id="table0" border="1px" cellspacing="0" cellpadding="0">\n']
      for row in table:
          parts.append("<tr>\n")
          # Escape so text such as "a<b" or "R&D" cannot break the table markup.
          parts.extend("<td>" + escape(str(cell[0]), quote=False) + "</td>\n" for cell in row)
          parts.append("</tr>\n")
      parts.append("</table>\n")
      return "".join(parts)
  def get_table_size(html_path):
      """Measure the cells and overall size of ``#table0`` in a headless Chrome.

      Args:
          html_path: Location handed to ``driver.get`` (the HTML files produced
              by ``create_html_table``).

      Returns:
          ``(rows_list, table_size)``: ``rows_list[r][c]`` is the computed
          ``[width, height]`` of cell ``c`` in row ``r`` as CSS strings such as
          ``"42.5px"``; ``table_size`` is the table's ``[width, height]`` in
          pixels, truncated to int.
      """
      print(html_path)
      chrome_options = Options()
      chrome_options.add_argument('--headless')
      chrome_options.add_argument('--no-sandbox')
      # ``options=`` supersedes the deprecated ``chrome_options=`` keyword,
      # which recent Selenium releases reject.
      driver = webdriver.Chrome(options=chrome_options)
      try:
          driver.maximize_window()
          driver.implicitly_wait(0)
          driver.get(html_path)
          # One round-trip; ``const``/``let`` keep helper variables out of the
          # page's global scope.
          rows_list, table_size = driver.execute_script("""
              const t = document.getElementById('table0');
              const cellsSize = [];
              for (let i = 0; i < t.rows.length; i++) {
                  const rowSize = [];
                  for (let j = 0; j < t.rows[i].cells.length; j++) {
                      const style = window.getComputedStyle(t.rows[i].cells[j]);
                      rowSize.push([style.width, style.height]);
                  }
                  cellsSize.push(rowSize);
              }
              const tableStyle = window.getComputedStyle(t);
              return [cellsSize, [tableStyle.width, tableStyle.height]];
          """)
      finally:
          # Without quit() every call leaves a headless Chrome process running.
          driver.quit()
      # Drop the trailing "px" unit and truncate to whole pixels.
      table_size = [int(float(table_size[0][:-2])), int(float(table_size[1][:-2]))]
      print("table_size", table_size)
      return rows_list, table_size
  def html2image1(html_path, table_size,
                  path_wkimg=r'D:\Software\html_to_pdf\wkhtmltopdf\bin\wkhtmltoimage.exe'):
      """Convert an HTML file to a JPEG with wkhtmltoimage (alternative to ``html2image``).

      Args:
          html_path: HTML file to render.
          table_size: ``[width, height]`` passed to wkhtmltoimage as the output
              width/height.
          path_wkimg: Location of the ``wkhtmltoimage`` executable.

      Returns:
          Path of the written image: ``html_path`` with its extension replaced
          by ``.jpg``.
      """
      cfg = imgkit.config(wkhtmltoimage=path_wkimg)
      options = {
          'width': table_size[0],
          'height': table_size[1],
          'encoding': 'UTF-8',
      }
      # splitext drops only the final extension; ``split(".")[0]`` truncated
      # paths whose directories contain a dot.
      image_path = os.path.splitext(html_path)[0] + ".jpg"
      print("html2image", image_path)
      # imgkit also offers from_url(...) and from_string(...) for other sources.
      imgkit.from_file(html_path, image_path, config=cfg, options=options)
      return image_path
  def html2image(html_path, table_size):
      """Screenshot ``#table0`` from an HTML file in headless Chrome and save it as JPEG.

      Args:
          html_path: Location handed to ``browser.get``.
          table_size: Unused; kept so the signature matches ``html2image1``.

      Returns:
          Path of the written image: ``html_path`` with its extension replaced
          by ``.jpg``. The image is cropped to the table element's bounding box.
      """
      # splitext drops only the final extension; ``split(".")[0]`` truncated
      # paths whose directories contain a dot.
      image_path = os.path.splitext(html_path)[0] + ".jpg"
      print("html2image", image_path)
      chrome_options = Options()
      chrome_options.add_argument('--headless')
      chrome_options.add_argument('--no-sandbox')
      browser = webdriver.Chrome(options=chrome_options)
      try:
          browser.get(html_path)
          # Resize the window to the whole document so the screenshot covers it all.
          width = browser.execute_script("return document.documentElement.scrollWidth")
          height = browser.execute_script("return document.documentElement.scrollHeight")
          browser.set_window_size(width, height)
          browser.save_screenshot(image_path)
          # find_element("id", ...) works on Selenium 3 and 4
          # (find_element_by_id was removed in Selenium 4.3).
          table = browser.find_element("id", "table0")
          location = table.location
          size = table.size
      finally:
          # Without quit() every call leaves a headless Chrome process running.
          browser.quit()
      box = (location['x'], location['y'],
             location['x'] + size['width'], location['y'] + size['height'])
      # The screenshot is PNG data; crop it and re-save it as a real JPEG.
      with Image.open(image_path) as picture:
          cropped = picture.crop(box).convert("RGB")
      cropped.save(image_path)
      return image_path
  def get_table_lines(rows_list, border_width=1):
      """Derive the grid lines of a table from its measured cell sizes.

      Args:
          rows_list: ``rows_list[r][c]`` is the ``[width, height]`` of a cell as
              CSS pixel strings such as ``"42.5px"`` (the format returned by
              ``get_table_size``).
          border_width: Pixels added to every cell dimension for the border
              between cells (1 matches the ``border="1px"`` tables rendered by
              this script).

      Returns:
          Horizontal lines (top to bottom) followed by vertical lines (left to
          right), each ``[[x1, y1], [x2, y2]]`` with the table's top-left corner
          at ``(0, 0)``. An empty ``rows_list`` yields ``[]``.

      Column widths are taken from the first row and row heights from the
      first cell of each row, so tables with rowspan/colspan are not handled.
      """
      if not rows_list:
          # Previously this produced a single zero-length line at the origin.
          return []

      def to_px(css_value):
          # Strip the trailing "px" unit.
          return float(css_value[:-2])

      xs = [0, *accumulate(to_px(cell[0]) + border_width for cell in rows_list[0])]
      ys = [0, *accumulate(to_px(row[0][1]) + border_width for row in rows_list)]
      table_width, table_height = xs[-1], ys[-1]
      horizontal = [[[0, y], [table_width, y]] for y in ys]
      vertical = [[[x, 0], [x, table_height]] for x in xs]
      return horizontal + vertical
  def draw_lines(line_list, table_size, image_path, expand=False):
      """Render line segments as a black-on-white label mask saved next to an image.

      Args:
          line_list: Segments ``[[x1, y1], [x2, y2]]``; coordinates are
              truncated to int before drawing.
          table_size: ``[width, height]`` of the canvas when ``expand`` is False.
          image_path: Source image path; the mask is written beside it with the
              extension replaced by ``.png``.
          expand: If True, size the canvas to the image at ``image_path``
              instead of ``table_size``.

      Returns:
          ``line_list``, unchanged.

      Raises:
          FileNotFoundError: if ``expand`` is True and ``image_path`` cannot be read.
      """
      if expand:
          image_origin = cv2.imread(image_path)
          if image_origin is None:
              raise FileNotFoundError(f"cannot read image: {image_path}")
          height, width = image_origin.shape[0], image_origin.shape[1]
          print(image_origin.shape, (height, width))
      else:
          width, height = table_size[0], table_size[1]
      img = np.full((height, width), 255, np.uint8)
      for line in line_list:
          start, end = line[0], line[1]
          # The canvas is single-channel, so only the first colour component is
          # used; the former (0, 0, 255) tuple also drew black (0) lines.
          cv2.line(img, (int(start[0]), int(start[1])),
                   (int(end[0]), int(end[1])), 0, 1)
      # splitext drops only the final extension; ``split(".")[0]`` truncated
      # paths whose directories contain a dot.
      cv2.imwrite(os.path.splitext(image_path)[0] + ".png", img)
      return line_list
  def image_expand(image_path, line_list, target_height=1123, target_width=794):
      """Pad an image with white borders towards a target size and shift lines to match.

      The image at ``image_path`` is overwritten with the padded version.

      Args:
          image_path: Image to pad (read and written with OpenCV).
          line_list: Segments ``[[x1, y1], [x2, y2]]`` in unpadded coordinates.
          target_height: Height to pad towards.
          target_width: Width to pad towards. The defaults (1123 x 794) look
              like an A4 page at 96 DPI — TODO confirm. Padding is split evenly
              per side and truncated, so odd differences end one pixel short.
              Dimensions already at or above the target get no padding.

      Returns:
          ``(image_path, new_line_list)`` where every coordinate is offset by
          the left/top padding.

      Raises:
          FileNotFoundError: if OpenCV cannot read ``image_path``.
      """
      # Read once (the image used to be decoded twice).
      image_np = cv2.imread(image_path)
      if image_np is None:
          raise FileNotFoundError(f"cannot read image: {image_path}")
      expand_height = max(0, int((target_height - image_np.shape[0]) / 2))
      expand_width = max(0, int((target_width - image_np.shape[1]) / 2))
      image_np = cv2.copyMakeBorder(image_np, expand_height, expand_height,
                                    expand_width, expand_width,
                                    cv2.BORDER_CONSTANT, value=(255, 255, 255))
      cv2.imwrite(image_path, image_np)
      # Shift every endpoint by the padding added on the left/top.
      new_line_list = [[[line[0][0] + expand_width, line[0][1] + expand_height],
                        [line[1][0] + expand_width, line[1][1] + expand_height]]
                       for line in line_list]
      return image_path, new_line_list
  if __name__ == '__main__':
      # Generate line-detection training samples for tables start_i..stop_i
      # (inclusive) of the CSV.
      csv_path = "D:\\BIDI_DOC\\比地_文档\\websource_67000_table.csv"
      start_i = 500
      stop_i = 700
      save_dir = "D:\\Project\\table-detect-master\\data_process\\create_data\\"
      label_dir = '../train/dataset-line/6/'

      table_list = read_csv(csv_path)
      # Render only the tables that are processed below; previously every table
      # after start_i was written to HTML although the loop stopped at stop_i.
      table_list = table_list[start_i:stop_i + 1]
      html_path_list = create_html_table(table_list, save_dir, start_i)
      os.makedirs(label_dir, exist_ok=True)

      for i, html_path in enumerate(html_path_list, start=start_i):
          print("Loop", i)
          rows_list, table_size = get_table_size(html_path)
          image_path = html2image(html_path, table_size)
          line_list = get_table_lines(rows_list)
          # Optional page-size padding:
          # image_path, line_list = image_expand(image_path, line_list)
          # Stamp an official seal onto the rendered table image.
          image_np = create_official_seal(cv2.imread(image_path))
          cv2.imwrite(image_path, image_np)
          with open(image_path, 'rb') as f:
              image_bytes = f.read()
          # Width/height come straight from the written array; re-reading the
          # file only to get its shape is unnecessary.
          labelme = create_lines_labelme(line_list, image_bytes,
                                         image_np.shape[1], image_np.shape[0])
          with open(label_dir + 'train_' + str(i) + '.json', 'w') as f:
              json.dump(labelme, f)
          draw_lines(line_list, table_size, image_path, False)
  271. # break
  272. # img_path = DESKTOP_PATH + '/1.jpg'
  273. # get_table_lines(img_path)