# pre_process.py — table-head training-data pre-processing.
  1. import psycopg2
  2. import numpy as np
  3. from BiddingKG.dl.common.Utils import embedding_word
  4. def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
  5. with open(dict_path, 'r') as f:
  6. character_list = f.readlines()
  7. for i in range(len(character_list)):
  8. character_list[i] = character_list[i][:-1]
  9. index_list = []
  10. for character in sentence:
  11. if character == '':
  12. index_list.append(0)
  13. elif character in character_list:
  14. _index = character_list.index(character) + 1
  15. index_list.append(_index)
  16. else:
  17. index_list.append(0)
  18. return index_list
  19. def postgresql_util(sql, limit):
  20. conn = psycopg2.connect(dbname="table_head_label", user="postgres", password="postgres",
  21. host="192.168.2.103")
  22. cursor = conn.cursor()
  23. cursor.execute(sql)
  24. print(sql)
  25. rows = cursor.fetchmany(1000)
  26. cnt = 0
  27. all_rows = []
  28. while rows:
  29. if cnt >= limit:
  30. break
  31. all_rows += rows
  32. cnt += len(rows)
  33. rows = cursor.fetchmany(1000)
  34. return all_rows
  35. def get_data_from_sql(dim=10):
  36. sql = """
  37. select table_text, pre_label, post_label, id
  38. from label_table_head_info
  39. where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
  40. ;
  41. """
  42. # sql = """
  43. # select table_text, pre_label, post_label, id
  44. # from label_table_head_info
  45. # where id = 843
  46. # """
  47. result_list = postgresql_util(sql, limit=1000000)
  48. all_data_list = []
  49. all_data_label_list = []
  50. i = 0
  51. # 一行就是一篇表格
  52. for table in result_list:
  53. i += 1
  54. if i % 100 == 0:
  55. print("Loop", i)
  56. pre_label = eval(table[1])
  57. post_label = eval(table[2])
  58. _id = table[3]
  59. # table_text需要特殊处理
  60. try:
  61. table_text = table[0]
  62. if table_text[0] == '"':
  63. table_text = eval(table_text)
  64. else:
  65. table_text = table_text
  66. table_text = table_text.replace('\\', '/')
  67. table_text = eval(table_text)
  68. except:
  69. print("无法识别table_text", _id)
  70. continue
  71. # 只有一行的也不要
  72. if len(post_label) >= 2:
  73. data_list, data_label_list = table_pre_process(table_text, post_label, _id)
  74. elif len(pre_label) >= 2:
  75. data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
  76. else:
  77. data_list, data_label_list = [], []
  78. all_data_list += data_list
  79. all_data_label_list += data_label_list
  80. print("len(all_data_list)", len(all_data_list))
  81. #
  82. # new_data_list = []
  83. # for data in data_list:
  84. # # 中文字符映射为index
  85. # # data[0] = get_sentence_index_list(data[0])
  86. # # data[1] = get_sentence_index_list(data[1])
  87. # # 维度不够,填充掩码0
  88. # # if len(data[0]) < dim:
  89. # # data[0] = data[0] + [0]*(dim-len(data[0]))
  90. # # elif len(data[0]) > dim:
  91. # # data[0] = data[0][:dim]
  92. # # if len(data[1]) < dim:
  93. # # data[1] = data[1] + [0]*(dim-len(data[1]))
  94. # # elif len(data[1]) > dim:
  95. # # data[1] = data[1][:dim]
  96. #
  97. # # 中文字符映射为Embedding
  98. # data = embedding_word(data, input_shape)
  99. # new_data_list.append(data)
  100. #
  101. # new_data_list = np.array(new_data_list)
  102. # data_label_list = np.array(data_label_list)
  103. # if np.array(new_data_list).shape[1:] == input_shape:
  104. # all_data_list.append(new_data_list)
  105. # all_data_label_list.append(data_label_list)
  106. # # 防止concat太慢
  107. # split_len = 1000
  108. # _len = int(len(all_data_list) / split_len)
  109. # all_data_list_1 = []
  110. # all_data_list_2 = []
  111. # for i in range(_len):
  112. # if i == _len - 1:
  113. # array1 = np.concatenate(all_data_list[i*split_len:])
  114. # array2 = np.concatenate(all_data_label_list[i*split_len:])
  115. # else:
  116. # array1 = np.concatenate(all_data_list[i*split_len:i*split_len+split_len])
  117. # array2 = np.concatenate(all_data_label_list[i*split_len:i*split_len+split_len])
  118. # all_data_list_1.append(array1)
  119. # all_data_list_2.append(array2)
  120. # all_data_list = np.concatenate(all_data_list_1)
  121. # all_data_label_list = np.concatenate(all_data_list_2)
  122. return all_data_list, all_data_label_list
  123. def table_pre_process(text_list, label_list, _id, is_train=True):
  124. """
  125. 表格处理,每个单元格生成2条数据,横竖各1条
  126. :param text_list:
  127. :param label_list:
  128. :param _id:
  129. :param is_train:
  130. :return:
  131. """
  132. if is_train:
  133. if len(text_list) != len(label_list):
  134. print("文字单元格与标注单元格数量不匹配!", _id)
  135. print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
  136. return [], []
  137. data_list = []
  138. data_label_list = []
  139. for i in range(len(text_list)):
  140. row = text_list[i]
  141. if is_train:
  142. row_label = label_list[i]
  143. if i > 0:
  144. last_row = text_list[i-1]
  145. if is_train:
  146. last_row_label = label_list[i-1]
  147. else:
  148. last_row = []
  149. if is_train:
  150. last_row_label = []
  151. if i < len(text_list) - 1:
  152. next_row = text_list[i+1]
  153. if is_train:
  154. next_row_label = label_list[i+1]
  155. else:
  156. next_row = []
  157. if is_train:
  158. next_row_label = []
  159. for j in range(len(row)):
  160. col = row[j]
  161. if is_train:
  162. col_label = row_label[j]
  163. # 超出表格置为None, 0
  164. if j > 0:
  165. last_col = row[j-1]
  166. if is_train:
  167. last_col_label = row_label[j-1]
  168. else:
  169. last_col = None
  170. if is_train:
  171. last_col_label = 0
  172. if j < len(row) - 1:
  173. next_col = row[j+1]
  174. if is_train:
  175. next_col_label = row_label[j+1]
  176. else:
  177. next_col = None
  178. if is_train:
  179. next_col_label = 0
  180. if last_row:
  181. last_row_col = last_row[j]
  182. if is_train:
  183. last_row_col_label = last_row_label[j]
  184. else:
  185. last_row_col = None
  186. if is_train:
  187. last_row_col_label = 0
  188. if next_row:
  189. next_row_col = next_row[j]
  190. if is_train:
  191. next_row_col_label = next_row_label[j]
  192. else:
  193. next_row_col = None
  194. if is_train:
  195. next_row_col_label = 0
  196. # 三元组有一对不相等就作为数据
  197. # if col != next_col or col != last_col:
  198. data_list.append([last_col, col, next_col])
  199. if is_train:
  200. data_label_list.append([int(last_col_label), int(col_label),
  201. int(next_col_label)])
  202. # if col != next_row_col or col != last_row_col:
  203. data_list.append([last_row_col, col, next_row_col])
  204. if is_train:
  205. data_label_list.append([int(last_row_col_label), int(col_label),
  206. int(next_row_col_label)])
  207. if is_train:
  208. return data_list, data_label_list
  209. else:
  210. return data_list
  211. def get_data_from_file(file_type):
  212. if file_type == 'np':
  213. data_path = 'train_data/data_3.npy'
  214. data_label_path = 'train_data/data_label_3.npy'
  215. array1 = np.load(data_path)
  216. array2 = np.load(data_label_path)
  217. return array1, array2
  218. elif file_type == 'txt':
  219. data_path = 'train_data/data.txt'
  220. data_label_path = 'train_data/data_label.txt'
  221. with open(data_path, 'r') as f:
  222. data_list = f.readlines()
  223. with open(data_label_path, 'r') as f:
  224. data_label_list = f.readlines()
  225. # for i in range(len(data_list)):
  226. # data_list[i] = eval(data_list[i][:-1])
  227. # data_label_list[i] = eval(data_label_list[i][:-1])
  228. return data_list, data_label_list
  229. else:
  230. print("file type error! only np and txt supported")
  231. raise Exception
  232. def processed_save_to_np():
  233. array1, array2 = get_data_from_sql()
  234. np.save('train_data/data_3.npy', array1)
  235. np.save('train_data/data_label_3.npy', array2)
  236. # with open('train_data/data.txt', 'w') as f:
  237. # for line in list1:
  238. # f.write(str(line) + "\n")
  239. # with open('train_data/data_label.txt', 'w') as f:
  240. # for line in list2:
  241. # f.write(str(line) + "\n")
  242. def processed_save_to_txt():
  243. list1, list2 = get_data_from_sql()
  244. with open('train_data/data.txt', 'w') as f:
  245. for line in list1:
  246. f.write(str(line) + "\n")
  247. with open('train_data/data_label.txt', 'w') as f:
  248. for line in list2:
  249. f.write(str(line) + "\n")
  250. def data_balance():
  251. array1, array2 = get_data_from_file()
  252. data_list = array2.tolist()
  253. all_cnt = len(data_list)
  254. cnt_0 = 0
  255. cnt_1 = 0
  256. for data in data_list:
  257. if data[0] == 1 or data[1] == 1:
  258. cnt_1 += 1
  259. else:
  260. cnt_0 += 1
  261. print("all_cnt", all_cnt)
  262. print("label has 1", cnt_1)
  263. print("label all 0", cnt_0)
  264. def test_embedding():
  265. output_shape = (2, 1, 60)
  266. data = [[None], [None]]
  267. result = embedding_word(data, output_shape)
  268. print(result)
  269. def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
  270. data_num = len(data_list)
  271. # 定义Embedding输出
  272. output_shape = (3, 10, 60)
  273. # batch循环取数据
  274. i = 0
  275. if is_train:
  276. while True:
  277. new_data_list = []
  278. new_data_label_list = []
  279. for j in range(batch_size):
  280. if i >= data_num:
  281. i = 0
  282. # 中文字符映射为Embedding
  283. data = eval(data_list[i][:-1])
  284. data_label = eval(data_label_list[i][:-1])
  285. data = embedding_word(data, output_shape)
  286. if data.shape == output_shape:
  287. new_data_list.append(data)
  288. new_data_label_list.append(data_label)
  289. i += 1
  290. new_data_list = np.array(new_data_list)
  291. new_data_label_list = np.array(new_data_label_list)
  292. X = new_data_list
  293. Y = new_data_label_list
  294. # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
  295. X = np.transpose(X, (1, 0, 2, 3))
  296. # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
  297. yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
  298. else:
  299. while True:
  300. new_data_list = []
  301. for j in range(batch_size):
  302. if i >= data_num:
  303. i = 0
  304. # 中文字符映射为Embedding
  305. data = data_list[i]
  306. data = embedding_word(data, output_shape)
  307. if data.shape == output_shape:
  308. new_data_list.append(data)
  309. i += 1
  310. new_data_list = np.array(new_data_list)
  311. X = new_data_list
  312. X = np.transpose(X, (1, 0, 2, 3))
  313. yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}
if __name__ == '__main__':
    # Regenerate the text training dumps from the label database.
    processed_save_to_txt()
    # data_balance()
    # test_embedding()