# pre_process.py
import random

import numpy as np
import psycopg2

from BiddingKG.dl.common.Utils import embedding_word
  5. def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
  6. with open(dict_path, 'r') as f:
  7. character_list = f.readlines()
  8. for i in range(len(character_list)):
  9. character_list[i] = character_list[i][:-1]
  10. index_list = []
  11. for character in sentence:
  12. if character == '':
  13. index_list.append(0)
  14. elif character in character_list:
  15. _index = character_list.index(character) + 1
  16. index_list.append(_index)
  17. else:
  18. index_list.append(0)
  19. return index_list
  20. def postgresql_util(sql, limit):
  21. conn = psycopg2.connect(dbname="table_head_label", user="postgres", password="postgres",
  22. host="192.168.2.103")
  23. cursor = conn.cursor()
  24. cursor.execute(sql)
  25. print(sql)
  26. rows = cursor.fetchmany(1000)
  27. cnt = 0
  28. all_rows = []
  29. while rows:
  30. if cnt >= limit:
  31. break
  32. all_rows += rows
  33. cnt += len(rows)
  34. rows = cursor.fetchmany(1000)
  35. return all_rows
  36. def get_data_from_sql(dim=10):
  37. # sql = """
  38. # select table_text, pre_label, post_label, id
  39. # from label_table_head_info
  40. # where update_user <> 'test27' and update_user <> 'test20' and table_box_cnt >= 4 and table_box_cnt <= 200
  41. # ;
  42. # """
  43. sql = """
  44. select table_text, pre_label, post_label, id
  45. from label_table_head_info
  46. where status = 1 and update_time >= '2022-01-17'
  47. ;
  48. """
  49. result_list = postgresql_util(sql, limit=1000000)
  50. all_data_list = []
  51. all_data_label_list = []
  52. i = 0
  53. # 一行就是一篇表格
  54. for table in result_list:
  55. i += 1
  56. if i % 100 == 0:
  57. print("Loop", i)
  58. pre_label = eval(table[1])
  59. post_label = eval(table[2])
  60. _id = table[3]
  61. # table_text需要特殊处理
  62. try:
  63. table_text = table[0]
  64. if table_text[0] == '"':
  65. table_text = eval(table_text)
  66. else:
  67. table_text = table_text
  68. table_text = table_text.replace('\\', '/')
  69. table_text = eval(table_text)
  70. except:
  71. print("无法识别table_text", _id)
  72. continue
  73. # 只有一行的也不要
  74. if len(post_label) >= 2:
  75. data_list, data_label_list = table_pre_process(table_text, post_label, _id)
  76. elif len(pre_label) >= 2:
  77. data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
  78. else:
  79. data_list, data_label_list = [], []
  80. all_data_list += data_list
  81. all_data_label_list += data_label_list
  82. print("len(all_data_list)", len(all_data_list))
  83. return all_data_list, all_data_label_list
  84. def table_pre_process(text_list, label_list, _id, is_train=True):
  85. """
  86. 表格处理,每个单元格生成2条数据,横竖各1条
  87. :param text_list:
  88. :param label_list:
  89. :param _id:
  90. :param is_train:
  91. :return:
  92. """
  93. if is_train:
  94. if len(text_list) != len(label_list):
  95. print("文字单元格与标注单元格数量不匹配!", _id)
  96. print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
  97. return [], []
  98. data_list = []
  99. data_label_list = []
  100. for i in range(len(text_list)):
  101. row = text_list[i]
  102. if is_train:
  103. row_label = label_list[i]
  104. if i > 0:
  105. last_row = text_list[i-1]
  106. if is_train:
  107. last_row_label = label_list[i-1]
  108. else:
  109. last_row = []
  110. if is_train:
  111. last_row_label = []
  112. if i < len(text_list) - 1:
  113. next_row = text_list[i+1]
  114. if is_train:
  115. next_row_label = label_list[i+1]
  116. else:
  117. next_row = []
  118. if is_train:
  119. next_row_label = []
  120. for j in range(len(row)):
  121. col = row[j]
  122. if is_train:
  123. col_label = row_label[j]
  124. # 超出表格置为None, 0
  125. if j > 0:
  126. last_col = row[j-1]
  127. if is_train:
  128. last_col_label = row_label[j-1]
  129. else:
  130. last_col = col
  131. if is_train:
  132. last_col_label = col_label
  133. if j < len(row) - 1:
  134. next_col = row[j+1]
  135. if is_train:
  136. next_col_label = row_label[j+1]
  137. else:
  138. next_col = col
  139. if is_train:
  140. next_col_label = col_label
  141. if last_row:
  142. last_row_col = last_row[j]
  143. if is_train:
  144. last_row_col_label = last_row_label[j]
  145. else:
  146. last_row_col = col
  147. if is_train:
  148. last_row_col_label = col_label
  149. if next_row:
  150. next_row_col = next_row[j]
  151. if is_train:
  152. next_row_col_label = next_row_label[j]
  153. else:
  154. next_row_col = col
  155. if is_train:
  156. next_row_col_label = col_label
  157. # data_list.append([last_col, col, next_col])
  158. # if is_train:
  159. # data_label_list.append([int(last_col_label), int(col_label),
  160. # int(next_col_label)])
  161. #
  162. # data_list.append([last_row_col, col, next_row_col])
  163. # if is_train:
  164. # data_label_list.append([int(last_row_col_label), int(col_label),
  165. # int(next_row_col_label)])
  166. if is_train:
  167. dup_list = [str(x) for x in data_list]
  168. data = [last_col, col, next_col, last_row_col, col, next_row_col]
  169. if str(data) not in dup_list:
  170. data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
  171. data_label_list.append(int(col_label))
  172. else:
  173. data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
  174. if is_train:
  175. return data_list, data_label_list
  176. else:
  177. return data_list
  178. def get_data_from_file(file_type):
  179. if file_type == 'np':
  180. data_path = 'train_data/data_3.npy'
  181. data_label_path = 'train_data/data_label_3.npy'
  182. array1 = np.load(data_path)
  183. array2 = np.load(data_label_path)
  184. return array1, array2
  185. elif file_type == 'txt':
  186. data_path = 'train_data/data3.txt'
  187. data_label_path = 'train_data/data_label3.txt'
  188. with open(data_path, 'r') as f:
  189. data_list = f.readlines()
  190. with open(data_label_path, 'r') as f:
  191. data_label_list = f.readlines()
  192. # for i in range(len(data_list)):
  193. # data_list[i] = eval(data_list[i][:-1])
  194. # data_label_list[i] = eval(data_label_list[i][:-1])
  195. return data_list, data_label_list
  196. else:
  197. print("file type error! only np and txt supported")
  198. raise Exception
  199. def processed_save_to_np():
  200. array1, array2 = get_data_from_sql()
  201. np.save('train_data/data_3.npy', array1)
  202. np.save('train_data/data_label_3.npy', array2)
  203. # with open('train_data/data.txt', 'w') as f:
  204. # for line in list1:
  205. # f.write(str(line) + "\n")
  206. # with open('train_data/data_label.txt', 'w') as f:
  207. # for line in list2:
  208. # f.write(str(line) + "\n")
  209. def processed_save_to_txt():
  210. list1, list2 = get_data_from_sql()
  211. # 打乱
  212. zip_list = list(zip(list1, list2))
  213. random.shuffle(zip_list)
  214. list1[:], list2[:] = zip(*zip_list)
  215. with open('train_data/data3.txt', 'w') as f:
  216. for line in list1:
  217. f.write(str(line) + "\n")
  218. with open('train_data/data_label3.txt', 'w') as f:
  219. for line in list2:
  220. f.write(str(line) + "\n")
  221. def data_balance():
  222. data_list, data_label_list = get_data_from_file('txt')
  223. all_cnt = len(data_label_list)
  224. cnt_0 = 0
  225. cnt_1 = 0
  226. for data in data_label_list:
  227. if eval(data[:-1])[1] == 1:
  228. cnt_1 += 1
  229. else:
  230. cnt_0 += 1
  231. print("all_cnt", all_cnt)
  232. print("label has 1", cnt_1)
  233. print("label all 0", cnt_0)
  234. def test_embedding():
  235. output_shape = (2, 1, 60)
  236. data = [[None], [None]]
  237. result = embedding_word(data, output_shape)
  238. print(result)
  239. def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
  240. data_num = len(data_list)
  241. # 定义Embedding输出
  242. output_shape = (6, 10, 60)
  243. # batch循环取数据
  244. i = 0
  245. if is_train:
  246. while True:
  247. new_data_list = []
  248. new_data_label_list = []
  249. for j in range(batch_size):
  250. if i >= data_num:
  251. i = 0
  252. # 中文字符映射为Embedding
  253. data = eval(data_list[i][:-1])
  254. data_label = eval(data_label_list[i][:-1])
  255. data = embedding_word(data, output_shape)
  256. if data.shape == output_shape:
  257. new_data_list.append(data)
  258. new_data_label_list.append(data_label)
  259. i += 1
  260. new_data_list = np.array(new_data_list)
  261. new_data_label_list = np.array(new_data_label_list)
  262. X = new_data_list
  263. Y = new_data_label_list
  264. # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
  265. X = np.transpose(X, (1, 0, 2, 3))
  266. if (X[0] == X[1]).all():
  267. X[0] = np.zeros_like(X[1], dtype='float32')
  268. if (X[2] == X[1]).all():
  269. X[2] = np.zeros_like(X[1], dtype='float32')
  270. if (X[3] == X[1]).all():
  271. X[3] = np.zeros_like(X[1], dtype='float32')
  272. if (X[5] == X[1]).all():
  273. X[5] = np.zeros_like(X[1], dtype='float32')
  274. yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
  275. 'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
  276. {'output': Y}
  277. else:
  278. while True:
  279. new_data_list = []
  280. for j in range(batch_size):
  281. if i >= data_num:
  282. i = 0
  283. # 中文字符映射为Embedding
  284. data = data_list[i]
  285. data = embedding_word(data, output_shape)
  286. if data.shape == output_shape:
  287. new_data_list.append(data)
  288. i += 1
  289. new_data_list = np.array(new_data_list)
  290. X = new_data_list
  291. X = np.transpose(X, (1, 0, 2, 3))
  292. yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
  293. 'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
# Script entry point: rebuild the txt training data from the database.
if __name__ == '__main__':
    processed_save_to_txt()
    # data_balance()
    # test_embedding()