# pre_process.py — data pre-processing utilities for the table-head labeling model
import os
import random
import sys

import numpy as np
import psycopg2

sys.path.append(os.path.dirname(__file__) + "/../")
from common.Utils import embedding_word, embedding_word_forward
  8. def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
  9. with open(dict_path, 'r') as f:
  10. character_list = f.readlines()
  11. for i in range(len(character_list)):
  12. character_list[i] = character_list[i][:-1]
  13. index_list = []
  14. for character in sentence:
  15. if character == '':
  16. index_list.append(0)
  17. elif character in character_list:
  18. _index = character_list.index(character) + 1
  19. index_list.append(_index)
  20. else:
  21. index_list.append(0)
  22. return index_list
  23. def postgresql_util(sql, limit):
  24. conn = psycopg2.connect(dbname="table_head_label", user="postgres", password="postgres",
  25. host="192.168.2.103")
  26. cursor = conn.cursor()
  27. cursor.execute(sql)
  28. print(sql)
  29. rows = cursor.fetchmany(1000)
  30. cnt = 0
  31. all_rows = []
  32. while rows:
  33. if cnt >= limit:
  34. break
  35. all_rows += rows
  36. cnt += len(rows)
  37. rows = cursor.fetchmany(1000)
  38. return all_rows
def get_data_from_sql(dim=10, whole_table=False, padding=True):
    """Fetch labeled tables from PostgreSQL and expand them into training samples.

    :param dim: unused here — TODO confirm it can be dropped at the callers
    :param whole_table: True -> one sample per table (table_pre_process_2),
        results sorted by (row count, column count); False -> one sample per
        cell (table_pre_process)
    :param padding: forwarded to table_pre_process_2 when whole_table is True
    :return: (all_data_list, all_data_label_list)

    NOTE(review): eval() is applied to DB fields and local files — this is
    only safe for trusted, locally produced data.
    """
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where status = 0 and (update_user='test9' or update_user='test1' or update_user='test7' or update_user='test26')
    ;
    """
    # Alternative query kept for reference (time-windowed, status = 1):
    # sql = """
    # select table_text, pre_label, post_label, id
    # from label_table_head_info
    # where status = 1 and update_time >= '2022-01-17' and update_time <= '2022-01-22'
    # ;
    # """
    result_list = postgresql_util(sql, limit=1000000)
    # ids to exclude from the dataset
    with open(r"C:\Users\Administrator\Desktop\table_not_eval.txt", "r") as f:
        delete_id_list = eval(f.read())
    with open(r"C:\Users\Administrator\Desktop\table_delete.txt", "r") as f:
        delete_id_list += eval(f.read())
    all_data_list = []
    all_data_label_list = []
    i = 0
    # one DB row == one table
    for table in result_list:
        i += 1
        if i % 100 == 0:
            print("Loop", i)
        pre_label = eval(table[1])
        post_label = eval(table[2])
        _id = table[3]
        if _id in delete_id_list:
            print("pass", _id)
            continue
        # table_text needs special handling: it may be a doubly-quoted repr,
        # and backslashes are normalized before the final eval
        try:
            table_text = table[0]
            if table_text[0] == '"':
                table_text = eval(table_text)
            else:
                table_text = table_text
            table_text = table_text.replace('\\', '/')
            table_text = eval(table_text)
        except:
            print("无法识别table_text", _id)
            continue
        if whole_table:
            # prefer post_label; tables with fewer than 2 labeled rows are skipped
            if len(post_label) >= 2:
                data_list, data_label_list = table_pre_process_2(table_text, post_label,
                                                                 _id, padding=padding)
            elif len(pre_label) >= 2:
                data_list, data_label_list = table_pre_process_2(table_text, pre_label,
                                                                 _id, padding=padding)
            else:
                data_list, data_label_list = [], []
        else:
            # single-row tables are also skipped
            if len(post_label) >= 2:
                data_list, data_label_list = table_pre_process(table_text, post_label, _id)
            elif len(pre_label) >= 2:
                data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
            else:
                data_list, data_label_list = [], []
        all_data_list += data_list
        all_data_label_list += data_label_list
    # sort by table dimensions (rows, then columns) so similar shapes batch together
    if whole_table:
        _list = []
        for data, label in zip(all_data_list, all_data_label_list):
            _list.append([data, label])
        _list.sort(key=lambda x: (len(x[0]), len(x[0][0])))
        all_data_list[:], all_data_label_list[:] = zip(*_list)
    print("len(all_data_list)", len(all_data_list))
    return all_data_list, all_data_label_list
  112. def table_pre_process(text_list, label_list, _id, is_train=True):
  113. """
  114. 表格处理,每个单元格生成2条数据,横竖各1条
  115. :param text_list:
  116. :param label_list:
  117. :param _id:
  118. :param is_train:
  119. :return:
  120. """
  121. if is_train:
  122. if len(text_list) != len(label_list):
  123. print("文字单元格与标注单元格数量不匹配!", _id)
  124. print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
  125. return [], []
  126. data_list = []
  127. data_label_list = []
  128. for i in range(len(text_list)):
  129. row = text_list[i]
  130. if is_train:
  131. row_label = label_list[i]
  132. if i > 0:
  133. last_row = text_list[i-1]
  134. if is_train:
  135. last_row_label = label_list[i-1]
  136. else:
  137. last_row = []
  138. if is_train:
  139. last_row_label = []
  140. if i < len(text_list) - 1:
  141. next_row = text_list[i+1]
  142. if is_train:
  143. next_row_label = label_list[i+1]
  144. else:
  145. next_row = []
  146. if is_train:
  147. next_row_label = []
  148. for j in range(len(row)):
  149. col = row[j]
  150. if is_train:
  151. col_label = row_label[j]
  152. # 超出表格置为None, 0
  153. if j > 0:
  154. last_col = row[j-1]
  155. if is_train:
  156. last_col_label = row_label[j-1]
  157. else:
  158. last_col = col
  159. if is_train:
  160. last_col_label = col_label
  161. if j < len(row) - 1:
  162. next_col = row[j+1]
  163. if is_train:
  164. next_col_label = row_label[j+1]
  165. else:
  166. next_col = col
  167. if is_train:
  168. next_col_label = col_label
  169. if last_row:
  170. last_row_col = last_row[j]
  171. if is_train:
  172. last_row_col_label = last_row_label[j]
  173. else:
  174. last_row_col = col
  175. if is_train:
  176. last_row_col_label = col_label
  177. if next_row:
  178. next_row_col = next_row[j]
  179. if is_train:
  180. next_row_col_label = next_row_label[j]
  181. else:
  182. next_row_col = col
  183. if is_train:
  184. next_row_col_label = col_label
  185. # data_list.append([last_col, col, next_col])
  186. # if is_train:
  187. # data_label_list.append([int(last_col_label), int(col_label),
  188. # int(next_col_label)])
  189. #
  190. # data_list.append([last_row_col, col, next_row_col])
  191. # if is_train:
  192. # data_label_list.append([int(last_row_col_label), int(col_label),
  193. # int(next_row_col_label)])
  194. if is_train:
  195. dup_list = [str(x) for x in data_list]
  196. data = [last_col, col, next_col, last_row_col, col, next_row_col]
  197. if str(data) not in dup_list:
  198. data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
  199. data_label_list.append(int(col_label))
  200. else:
  201. data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
  202. if is_train:
  203. return data_list, data_label_list
  204. else:
  205. return data_list
  206. def table_pre_process_2(text_list, label_list, _id, is_train=True, padding=True):
  207. """
  208. 表格处理,整个表格为一个数组,且填充长宽维度
  209. :param text_list:
  210. :param label_list:
  211. :param _id:
  212. :param is_train:
  213. :return:
  214. """
  215. # 判断表格长宽是否合理
  216. row_len = len(text_list)
  217. best_row_len = get_best_padding_size(row_len, min_len=8)
  218. col_len = len(text_list[0])
  219. best_col_len = get_best_padding_size(col_len, min_len=8)
  220. if best_row_len is None:
  221. if is_train:
  222. return [], []
  223. else:
  224. return []
  225. if best_col_len is None:
  226. if is_train:
  227. return [], []
  228. else:
  229. return []
  230. if is_train:
  231. if len(text_list) != len(label_list):
  232. print("文字单元格与标注单元格数量不匹配!", _id)
  233. print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
  234. return [], []
  235. if padding:
  236. for i in range(row_len):
  237. col_len = len(text_list[i])
  238. text_list[i] += [None]*(best_col_len-col_len)
  239. if is_train:
  240. label_list[i] += ["0"]*(best_col_len-col_len)
  241. text_list += [[None]*best_col_len]*(best_row_len-row_len)
  242. if is_train:
  243. label_list += [["0"]*best_col_len]*(best_row_len-row_len)
  244. if is_train:
  245. for i in range(len(label_list)):
  246. for j in range(len(label_list[i])):
  247. label_list[i][j] = int(label_list[i][j])
  248. return [text_list], [label_list]
  249. else:
  250. return [text_list]
  251. def get_best_padding_size(axis_len, min_len=3, max_len=300):
  252. # sizes = [8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120,
  253. # 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224,
  254. # 232, 240, 248, 256, 264, 272, 280, 288, 296]
  255. # sizes = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57,
  256. # 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111,
  257. # 114, 117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156,
  258. # 159, 162, 165, 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201,
  259. # 204, 207, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246,
  260. # 249, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291,
  261. # 294, 297]
  262. sizes = []
  263. for i in range(1, max_len):
  264. if i * min_len <= max_len:
  265. sizes.append(i * min_len)
  266. if axis_len > sizes[-1]:
  267. return axis_len
  268. best_len = sizes[-1]
  269. for height in sizes:
  270. if axis_len <= height:
  271. best_len = height
  272. break
  273. # print("get_best_padding_size", axis_len, best_len)
  274. return best_len
  275. def get_data_from_file(file_type, model_id=1):
  276. if file_type == 'np':
  277. data_path = 'train_data/data_3.npy'
  278. data_label_path = 'train_data/data_label_3.npy'
  279. array1 = np.load(data_path)
  280. array2 = np.load(data_label_path)
  281. return array1, array2
  282. elif file_type == 'txt':
  283. if model_id == 1:
  284. data_path = 'train_data/data1.txt'
  285. data_label_path = 'train_data/data_label1.txt'
  286. elif model_id == 2:
  287. data_path = 'train_data/data2.txt'
  288. data_label_path = 'train_data/data_label2.txt'
  289. elif model_id == 3:
  290. data_path = 'train_data/data3.txt'
  291. data_label_path = 'train_data/data_label3.txt'
  292. with open(data_path, 'r') as f:
  293. data_list = f.readlines()
  294. with open(data_label_path, 'r') as f:
  295. data_label_list = f.readlines()
  296. return data_list, data_label_list
  297. else:
  298. print("file type error! only np and txt supported")
  299. raise Exception
  300. def processed_save_to_np():
  301. array1, array2 = get_data_from_sql()
  302. np.save('train_data/data_3.npy', array1)
  303. np.save('train_data/data_label_3.npy', array2)
  304. # with open('train_data/data.txt', 'w') as f:
  305. # for line in list1:
  306. # f.write(str(line) + "\n")
  307. # with open('train_data/data_label.txt', 'w') as f:
  308. # for line in list2:
  309. # f.write(str(line) + "\n")
  310. def processed_save_to_txt(whole_table=False, padding=True):
  311. list1, list2 = get_data_from_sql(whole_table=whole_table, padding=padding)
  312. # 打乱
  313. # if not whole_table or not padding:
  314. zip_list = list(zip(list1, list2))
  315. random.shuffle(zip_list)
  316. list1[:], list2[:] = zip(*zip_list)
  317. with open('train_data/data1.txt', 'w') as f:
  318. for line in list1:
  319. f.write(str(line) + "\n")
  320. with open('train_data/data_label1.txt', 'w') as f:
  321. for line in list2:
  322. f.write(str(line) + "\n")
  323. def data_balance():
  324. data_list, data_label_list = get_data_from_file('txt')
  325. all_cnt = len(data_label_list)
  326. cnt_0 = 0
  327. cnt_1 = 0
  328. for data in data_label_list:
  329. if eval(data[:-1])[1] == 1:
  330. cnt_1 += 1
  331. else:
  332. cnt_0 += 1
  333. print("all_cnt", all_cnt)
  334. print("label has 1", cnt_1)
  335. print("label all 0", cnt_0)
  336. def test_embedding():
  337. output_shape = (2, 1, 60)
  338. data = [[None], [None]]
  339. result = embedding_word(data, output_shape)
  340. print(result)
def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
    """Infinite batch generator for the per-cell (6-sentence) model.

    In training mode, data_list / data_label_list are lines read from the txt
    training files: each line is the repr of a sample, and the trailing
    newline is stripped with [:-1] before eval. In inference mode data_list
    holds in-memory samples and no labels are used.

    Yields Keras-style dicts: (inputs, {'output': Y}) when is_train, inputs
    only otherwise. Samples whose embedding does not match the fixed
    output_shape are silently dropped, so a batch may be smaller than
    batch_size.

    NOTE(review): eval() on file content — safe only for trusted local
    training files.
    """
    data_num = len(data_list)
    # fixed Embedding output shape: (6 sentences, 20 chars, 60-dim embedding)
    output_shape = (6, 20, 60)
    # cursor into the dataset; wraps to 0 at the end (runs forever)
    i = 0
    if is_train:
        while True:
            new_data_list = []
            new_data_label_list = []
            for j in range(batch_size):
                if i >= data_num:
                    i = 0
                # map Chinese characters to Embedding vectors
                data = eval(data_list[i][:-1])
                data_label = eval(data_label_list[i][:-1])
                data = embedding_word(data, output_shape)
                if data.shape == output_shape:
                    new_data_list.append(data)
                    new_data_label_list.append(data_label)
                i += 1
            new_data_list = np.array(new_data_list)
            new_data_label_list = np.array(new_data_label_list)
            X = new_data_list
            Y = new_data_label_list
            # (batch, 6 sentences, chars, embedding) -> (6, batch, chars, embedding)
            X = np.transpose(X, (1, 0, 2, 3))
            # Neighbor slots (0,2,3,5) that equal the center cell (slot 1) for
            # the WHOLE batch are zeroed — table_pre_process substitutes the
            # center cell when a neighbor is missing, so an all-equal slot
            # means "no neighbor". Slot 4 is by construction the center cell
            # again and is intentionally left as-is.
            if (X[0] == X[1]).all():
                X[0] = np.zeros_like(X[1], dtype='float32')
            if (X[2] == X[1]).all():
                X[2] = np.zeros_like(X[1], dtype='float32')
            if (X[3] == X[1]).all():
                X[3] = np.zeros_like(X[1], dtype='float32')
            if (X[5] == X[1]).all():
                X[5] = np.zeros_like(X[1], dtype='float32')
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
                {'output': Y}
    else:
        while True:
            new_data_list = []
            for j in range(batch_size):
                if i >= data_num:
                    i = 0
                # inference input is already a Python object, no eval needed
                data = data_list[i]
                data = embedding_word(data, output_shape)
                if data.shape == output_shape:
                    new_data_list.append(data)
                i += 1
            new_data_list = np.array(new_data_list)
            X = new_data_list
            X = np.transpose(X, (1, 0, 2, 3))
            # no zeroing of duplicate neighbor slots on the inference path
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
def my_data_loader_2(table_list, table_label_list, batch_size, is_train=True):
    """Infinite batch generator for the whole-table model.

    table_list entries are either repr lines from the txt files (eval'd after
    stripping the trailing newline) or already-parsed lists. With
    batch_size > 1, consecutive tables are grouped by identical
    (rows, cols) shape: a shape change ends the batch early. Batches whose
    embedded shape does not match output_shape are skipped entirely.

    Yields ({"input_1": X, "input_2": zeros}, {"output": Y}) when is_train,
    inputs only otherwise. input_2 is an all-zero tensor sized from the
    table dims — presumably a padding/mask placeholder for the model; its
    exact role is defined by the consuming model (TODO confirm).
    """
    # pad_len=0 makes get_best_padding_size return the length unchanged
    pad_len = 0
    table_num = len(table_list)
    if is_train and batch_size == 1:
        table_list, table_label_list = get_random(table_list, table_label_list)
    # per-cell Embedding shape: (20 chars, 60-dim embedding)
    output_shape = (20, 60)
    # dataset cursor; wraps (and reshuffles in training) at the end
    i = 0
    last_shape = None
    while True:
        new_table_list = []
        new_table_label_list = []
        for j in range(batch_size):
            if i >= table_num:
                i = 0
                if is_train:
                    table_list, table_label_list = get_random(table_list, table_label_list,
                                                              seed=random.randint(1, 40))
            # lines from file need eval; in-memory tables are used directly
            if type(table_list[i]) != list:
                table = eval(table_list[i][:-1])
            else:
                table = table_list[i]
            if batch_size > 1:
                if last_shape is None:
                    # first table just records the batch shape; note i is NOT
                    # advanced, so the same table is processed next iteration
                    last_shape = (len(table), len(table[0]))
                    continue
                if (len(table), len(table[0])) != last_shape:
                    # shape changed: close this (possibly short) batch
                    last_shape = (len(table), len(table[0]))
                    break
            if is_train:
                table_label = eval(table_label_list[i][:-1])
            # map each row's characters to Embedding vectors
            for k in range(len(table)):
                table[k] = embedding_word_forward(table[k], (len(table[k]),
                                                             output_shape[0],
                                                             output_shape[1]))
            new_table_list.append(table)
            if is_train:
                new_table_label_list.append(table_label)
            i += 1
        new_table_list = np.array(new_table_list)
        X = new_table_list
        if X.shape[-2:] != output_shape:
            # ragged/mismatched batch: drop it and continue with the next one
            continue
        # padding sizes (with pad_len=0 these equal the raw dims)
        pad_height = get_best_padding_size(X.shape[1], pad_len)
        pad_width = get_best_padding_size(X.shape[2], pad_len)
        input_2 = np.zeros([1, X.shape[1], X.shape[2], pad_height, pad_width])
        if is_train:
            new_table_label_list = np.array(new_table_label_list)
            Y = new_table_label_list
            yield {"input_1": X, "input_2": input_2}, \
                {"output": Y}
        else:
            yield {"input_1": X, "input_2": input_2}
  457. def check_train_data():
  458. data_list, label_list = get_data_from_file('txt', model_id=2)
  459. for data in data_list:
  460. data = eval(data)
  461. if len(data) % 8 != 0:
  462. print(len(data))
  463. print(len(data[0]))
  464. for row in data:
  465. if len(row) % 8 != 0:
  466. print(len(data))
  467. print(len(row))
  468. def get_random(text_list, label_list, seed=42):
  469. random.seed(seed)
  470. zip_list = list(zip(text_list, label_list))
  471. random.shuffle(zip_list)
  472. text_list[:], label_list[:] = zip(*zip_list)
  473. return text_list, label_list
  474. if __name__ == '__main__':
  475. processed_save_to_txt(whole_table=False, padding=False)
  476. # data_balance()
  477. # test_embedding()
  478. # check_train_data()
  479. # _list = []
  480. # for i in range(1, 100):
  481. # _list.append(i*3)
  482. # print(_list)
  483. # print(get_best_padding_size(9, 5))