processing.py

import copy
import csv
import json
import logging
import os
import random
import re
import time

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)


def collate_fn(batch):
    """
    batch is a list of (input_ids, attention_mask, token_type_ids, label) tensor tuples.
    Returns the four stacked batch tensors. Sequences are already padded to a fixed
    length by NSPProcessor.load_examples, so no dynamic re-padding is done here.
    """
    all_input_ids, all_attention_mask, all_token_type_ids, all_labels = map(torch.stack, zip(*batch))
    # Dynamic truncation to the longest sequence in the batch (disabled; would need per-example lengths):
    # max_len = max(all_lens).item()
    # all_input_ids = all_input_ids[:, :max_len]
    # all_attention_mask = all_attention_mask[:, :max_len]
    # all_token_type_ids = all_token_type_ids[:, :max_len]
    # all_labels = all_labels[:, :max_len]
    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels
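
# Usage sketch (an assumption, not part of the original pipeline): collate_fn expects each
# dataset item to be a tuple of four tensors (input ids, mask, segment ids, label), which is
# what the TensorDataset built by NSPProcessor.load_examples below yields. A hypothetical
# training loop would plug it into a DataLoader like this:
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
#     for input_ids, attention_mask, token_type_ids, labels in loader:
#         ...  # forward pass with the batch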


class NSPProcessor:
    """Processor for the Chinese NSP (sentence-pair) data set."""

    def __init__(self, data_path_or_list, limit=100000):
        self.data_path = None
        self.str_list = None
        self.limit = limit
        self.ratio = 0.99  # train/eval split ratio
        if isinstance(data_path_or_list, str):
            self.data_path = data_path_or_list
        elif isinstance(data_path_or_list, list):
            self.str_list = data_path_or_list
            self.data = self.str_list
        if self.data_path:
            logger.info("Creating features from dataset file at %s", self.data_path)
            with open(self.data_path, 'r') as f:
                lines = f.readlines()
            # random.shuffle(lines)
            self.data = lines
            # print('len(self.data)', len(self.data))

    def get_train_examples(self):
        """See base class."""
        return self.create_examples("train")

    def get_eval_examples(self):
        """See base class."""
        return self.create_examples("eval")

    def get_predict_examples(self):
        return self.create_examples("test")

    def create_examples(self, set_type):
        """Creates examples for the training and dev sets."""
        if set_type == 'train':
            random.shuffle(self.data)
            self.data = self.data[:self.limit]
            print('len(self.data)', len(self.data))
        if set_type in ['train', 'eval']:
            lines = [x[:-1] for x in self.data]
            if set_type == 'train':
                lines = lines[:int(len(lines) * self.ratio)]
            else:
                lines = lines[int(len(lines) * self.ratio):]
        else:
            lines = self.str_list
        examples = []
        for (i, line) in enumerate(lines):
            ss = line.split('\t')
            examples.append([ss[0], ss[1], ss[2]])
        return examples

    def load_examples(self, max_seq_len, tokenizer, data_type='train'):
        sep_token = tokenizer.sep_token
        cls_token = tokenizer.cls_token
        pad_token = tokenizer.pad_token
        # print(sep_token, cls_token, pad_token)
        if data_type == 'train':
            examples = self.get_train_examples()
        elif data_type == 'eval':
            examples = self.get_eval_examples()
        else:
            examples = self.get_predict_examples()
        features = []
        # print('len(examples)', len(examples))
        if data_type == 'train':
            print('loading examples...')
        for (ex_index, example) in enumerate(examples):
            # if ex_index % 10000 == 0:
            #     logger.info("loading example %d of %d", ex_index, len(examples))
            a_tokens = tokenizer.tokenize(example[0])
            b_tokens = tokenizer.tokenize(example[1])
            # Account for [CLS] and the two [SEP] tokens with "- 3".
            special_tokens_count = 3
            # Truncate or pad each sentence to half of the remaining budget.
            real_max_seq_len = int((max_seq_len - special_tokens_count) / 2)
            if len(a_tokens) >= real_max_seq_len:
                a_tokens = a_tokens[(len(a_tokens) - real_max_seq_len):]  # keep the tail of sentence A
            else:
                a_tokens += [pad_token] * (real_max_seq_len - len(a_tokens))
            if len(b_tokens) >= real_max_seq_len:
                b_tokens = b_tokens[:real_max_seq_len]  # keep the head of sentence B
            else:
                b_tokens += [pad_token] * (real_max_seq_len - len(b_tokens))
            tokens = [cls_token] + a_tokens + [sep_token] + b_tokens + [sep_token]
            # Segment ids use three values (0 for [CLS], 1 for sentence A, 2 for sentence B),
            # so the downstream model must support a token type vocabulary of at least 3.
            segment_ids = [0] + [1] * (real_max_seq_len + 1) + [2] * (real_max_seq_len + 1)
            # print('segment_ids', segment_ids)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # print('input_ids', input_ids)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if x != pad_token else 0 for x in tokens]
            label = int(example[2])
            # if label == 0:
            #     label = [0., 1.]
            # else:
            #     label = [1., 0.]
            _dict = {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'label': label
            }
            features.append(_dict)
        if data_type == 'train':
            print('loading examples finished!')
        # Convert to Tensors and build the dataset
        all_input_ids = torch.tensor([f.get('input_ids') for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.get('input_mask') for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.get('segment_ids') for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.get('label') for f in features], dtype=torch.long)
        # all_lens = torch.tensor([f.get('input_len') for f in features], dtype=torch.long)
        # dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_lens, all_label_ids)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
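
# Usage sketch (assumes the `transformers` package for the tokenizer; any tokenizer exposing
# .tokenize, .convert_tokens_to_ids and the cls/sep/pad special-token attributes would work):
#
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
#     processor = NSPProcessor('nsp_src_data.txt')
#     train_dataset = processor.load_examples(max_seq_len=64, tokenizer=tokenizer, data_type='train')
#     eval_dataset = processor.load_examples(max_seq_len=64, tokenizer=tokenizer, data_type='eval')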


def raw_data_process(path):
    with open(path, 'r') as f:
        data_list = f.readlines()
    # data_list = data_list[:100]
    # print(data_list[0])


def generate_train_data(raw_data_path):
    with open(raw_data_path, 'r') as f:
        raw_file = f.readlines()
    # Extract the cell text from the tables; each line is a doubly string-encoded table literal.
    text_list = []
    for raw_line in raw_file:
        try:
            table = eval(eval(raw_line))
            text_list += [re.sub(' ', '', y) for x in table for y in x]
        except Exception:
            continue
    text_list = list(set(text_list))
    # Filter: keep strings longer than one character that contain Chinese characters,
    # truncated to at most 30 characters from a randomly chosen end.
    temp_text = []
    for t in text_list:
        if len(t) <= 1 or not re.search('[\u4e00-\u9fa5]', t):
            continue
        t = re.sub('\t', '', t)
        if random.choice([0, 1]):
            temp_text.append(t[:30])
        else:
            temp_text.append(t[-30:])
    text_list = temp_text
    print('len(text_list)', len(text_list))
    # Build sentence pairs
    j = 0
    start_time = time.time()
    sentence_pairs = []
    with open('nsp_src_data.txt', 'w') as f:
        f.write('')  # truncate any previous output
    for text in text_list:
        if j % 100000 == 0:
            print('j', j, len(text_list), time.time() - start_time)
            start_time = time.time()
            # Periodically filter and flush the accumulated pairs to disk.
            if sentence_pairs:
                temp_list = []
                for sen in sentence_pairs:
                    if len(sen[0]) > 1 and len(sen[1]) > 1:
                        temp_list.append(sen)
                    elif len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0]) or len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1]):
                        temp_list.append(sen)
                sentence_pairs = temp_list
                sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
                with open('nsp_src_data.txt', 'a') as f:
                    f.writelines(sentence_pairs)
                sentence_pairs = []
        j += 1
        # Positive samples (plus paired swapped/random negatives): split the text at every
        # position that leaves Chinese characters on both sides.
        for i in range(len(text) - 1):
            if re.search('[\u4e00-\u9fa5]', text[i + 1]) \
                    and re.search('[\u4e00-\u9fa5]', text[i + 1:]) \
                    and re.search('[\u4e00-\u9fa5]', text[:i + 1]):
                sentence_pairs.append([text[:i + 1], text[i + 1:], 1])
                sentence_pairs.append([text[random.randint(0, i):i + 1], text[i + 1:random.randint(i + 1, len(text))], 1])
                sentence_pairs.append([text[i + 1:], text[:i + 1], 0])
                sentence_pairs.append([text[i + 1:random.randint(i + 1, len(text))], text[random.randint(0, i):i + 1], 0])
                # A few negative pairs built from randomly sampled texts.
                max_k = random.randint(0, 3)
                k = 0
                while k < max_k:
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([text[random.randint(0, i):i + 1], rand_t, 0])
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([rand_t, text[i + 1:random.randint(i + 1, len(text))], 0])
                    k += 1
                rand_index = random.randint(1, 5)
                if len(text[:i + 1]) > rand_index and len(text[i + 1:]) > rand_index \
                        and re.search('[\u4e00-\u9fa5]', text[rand_index:i + 1]) \
                        and re.search('[\u4e00-\u9fa5]', text[i + 1:len(text) - rand_index]):
                    sentence_pairs.append([text[rand_index:i + 1], text[i + 1:len(text) - rand_index], 1])
                    sentence_pairs.append([text[i + 1:len(text) - rand_index], text[rand_index:i + 1], 0])
        # Negative samples (disabled)
        # for i in range(len(text) - 1):
        #     t = random.sample(text_list, 1)[0]
        #     if t == text:
        #         continue
        #     if random.choice([0, 1]):
        #         sentence_pairs.append([text, t, 0])
        #     else:
        #         sentence_pairs.append([t, text, 0])
    # Final filter and flush of the remaining pairs.
    temp_list = []
    for sen in sentence_pairs:
        if len(sen[0]) > 0 and len(sen[1]) > 0:
            temp_list.append(sen)
        elif len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0]) or len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1]):
            temp_list.append(sen)
    sentence_pairs = temp_list
    sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
    with open('nsp_src_data.txt', 'a') as f:
        f.writelines(sentence_pairs)
    return
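
# Output format sketch: each line written to nsp_src_data.txt is "sentence_a\tsentence_b\tlabel\n",
# with label 1 for an adjacent (next-sentence) split and 0 for a swapped or randomly matched pair.
# A hypothetical quick check of the generated file:
#
#     with open('nsp_src_data.txt', 'r') as f:
#         for line in f.readlines()[:5]:
#             a, b, label = line.rstrip('\n').split('\t')
#             print(a, '|', b, '|', label)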


def clean_train_data():
    with open('nsp_src_data.txt', 'r') as f:
        _list = f.readlines()
    # Deduplicate the lines, then keep only lines whose three tab-separated fields
    # are distinct and non-empty.
    _list = [json.dumps(x) for x in _list]
    _list = list(set(_list))
    _list = [json.loads(x) for x in _list]
    new_list = []
    for l in _list:
        ss = l[:-1].split('\t')
        ss = list(set(ss))
        if '' in ss:
            ss.remove('')
        if len(ss) == 3:
            new_list.append(l)
    with open('nsp_src_data.txt', 'w') as f:
        f.writelines(new_list)


if __name__ == '__main__':
    # raw_data_process('datasets/product_ner/ZOL_PRODUCE_INFO.csv')
    generate_train_data(r'D:\Project\borderless-table-detect\torch_version\sentence_match\label_table_head_info.txt')
    # clean_train_data()
    # print('\t\tb'.split('\t'))