import json
import logging
import random
import re
import time

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)


def collate_fn(batch):
    """
    batch is a list of (input_ids, attention_mask, token_type_ids, labels) tuples.
    All examples are already padded to the same fixed length, so this simply
    stacks each field into a batch tensor.
    """
    all_input_ids, all_attention_mask, all_token_type_ids, all_labels = map(torch.stack, zip(*batch))
    # max_len = max(all_lens).item()
    # all_input_ids = all_input_ids[:, :max_len]
    # all_attention_mask = all_attention_mask[:, :max_len]
    # all_token_type_ids = all_token_type_ids[:, :max_len]
    # all_labels = all_labels[:, :max_len]
    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels


class NSPProcessor:
    """Processor for the NSP (next-sentence-prediction) sentence-pair data set."""

    def __init__(self, data_path_or_list, limit=100000):
        self.data_path = None
        self.str_list = None
        self.limit = limit
        self.ratio = 0.99
        if isinstance(data_path_or_list, str):
            self.data_path = data_path_or_list
        elif isinstance(data_path_or_list, list):
            self.str_list = data_path_or_list
        self.data = self.str_list
        if self.data_path:
            logger.info("Creating features from dataset file at %s", self.data_path)
            with open(self.data_path, 'r') as f:
                lines = f.readlines()
            # random.shuffle(lines)
            self.data = lines
        # print('len(self.data)', len(self.data))

    def get_train_examples(self):
        """See base class."""
        return self.create_examples("train")

    def get_eval_examples(self):
        """See base class."""
        return self.create_examples("eval")

    def get_predict_examples(self):
        return self.create_examples("test")

    def create_examples(self, set_type):
        """Creates examples for the training and dev sets."""
        if set_type == 'train':
            random.shuffle(self.data)
            self.data = self.data[:self.limit]
            print('len(self.data)', len(self.data))
        if set_type in ['train', 'eval']:
            lines = [x[:-1] for x in self.data]
            if set_type == 'train':
                lines = lines[:int(len(lines) * self.ratio)]
            else:
                lines = lines[int(len(lines) * self.ratio):]
        else:
            lines = self.str_list
        examples = []
        for line in lines:
            ss = line.split('\t')
            if len(ss) < 3:
                # Skip malformed lines that do not contain "text_a\ttext_b\tlabel".
                continue
            examples.append([ss[0], ss[1], ss[2]])
        return examples

    def load_examples(self, max_seq_len, tokenizer, data_type='train'):
        sep_token = tokenizer.sep_token
        cls_token = tokenizer.cls_token
        pad_token = tokenizer.pad_token
        # print(sep_token, cls_token, pad_token)
        if data_type == 'train':
            examples = self.get_train_examples()
        elif data_type == 'eval':
            examples = self.get_eval_examples()
        else:
            examples = self.get_predict_examples()
        features = []
        # print('len(examples)', len(examples))
        if data_type == 'train':
            print('loading examples...')
        for (ex_index, example) in enumerate(examples):
            # if ex_index % 10000 == 0:
            #     logger.info("loading example %d of %d", ex_index, len(examples))
            a_tokens = tokenizer.tokenize(example[0])
            b_tokens = tokenizer.tokenize(example[1])
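            # Each pair is forced to one fixed length: sentence A keeps its tail,
            # sentence B keeps its head, and both are padded with pad_token up to
            # real_max_seq_len, so every feature row can be stacked by collate_fn
            # without dynamic padding.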
            # Account for [CLS] and the two [SEP] tokens (the original value of 2
            # let sequences overshoot max_seq_len by one token).
            special_tokens_count = 3
            # Truncate or pad each sentence to a fixed per-sentence budget.
            real_max_seq_len = (max_seq_len - special_tokens_count) // 2
            if len(a_tokens) >= real_max_seq_len:
                # Keep the tail of sentence A.
                a_tokens = a_tokens[(len(a_tokens) - real_max_seq_len):]
            else:
                a_tokens += [pad_token] * (real_max_seq_len - len(a_tokens))
            if len(b_tokens) >= real_max_seq_len:
                # Keep the head of sentence B.
                b_tokens = b_tokens[:real_max_seq_len]
            else:
                b_tokens += [pad_token] * (real_max_seq_len - len(b_tokens))
            tokens = [cls_token] + a_tokens + [sep_token] + b_tokens + [sep_token]
            # Segment ids: 0 for [CLS], 1 for the sentence-A span (incl. its [SEP]),
            # 2 for the sentence-B span -- note this differs from the standard
            # BERT 0/1 token_type scheme.
            segment_ids = [0] + [1] * (real_max_seq_len + 1) + [2] * (real_max_seq_len + 1)
            # print('segment_ids', segment_ids)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # print('input_ids', input_ids)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if x != pad_token else 0 for x in tokens]
            label = int(example[2])
            # if label == 0:
            #     label = [0., 1.]
            # else:
            #     label = [1., 0.]
            _dict = {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'label': label
            }
            features.append(_dict)
        if data_type == 'train':
            print('loading examples finished')
        # Convert to Tensors and build the dataset.
        all_input_ids = torch.tensor([f.get('input_ids') for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.get('input_mask') for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.get('segment_ids') for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.get('label') for f in features], dtype=torch.long)
        # all_lens = torch.tensor([f.get('input_len') for f in features], dtype=torch.long)
        # dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_lens, all_label_ids)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset


def raw_data_process(path):
    # NOTE: stub -- the actual processing below was left commented out.
    with open(path, 'r') as f:
        data_list = f.readlines()
    # data_list = data_list[:100]
    # print(data_list[0])


def generate_train_data(raw_data_path):
    with open(raw_data_path, 'r') as f:
        raw_file = f.readlines()
    # Extract the text cells from each table.
    text_list = []
    for raw_line in raw_file:
        try:
            # Each raw line stores a doubly string-escaped table literal.
            table = eval(eval(raw_line))
            text_list += [re.sub(' ', '', y) for x in table for y in x]
        except Exception:
            continue
    text_list = list(set(text_list))
    # Filter: drop cells without Chinese characters and clip to 30 characters.
    temp_text = []
    for t in text_list:
        if len(t) <= 1 or not re.search('[\u4e00-\u9fa5]', t):
            continue
        t = re.sub('\t', '', t)
        if random.choice([0, 1]):
            temp_text.append(t[:30])
        else:
            temp_text.append(t[-30:])
    text_list = temp_text
    print('len(text_list)', len(text_list))
    # Build sentence pairs.
    j = 0
    start_time = time.time()
    sentence_pairs = []
    with open('nsp_src_data.txt', 'w') as f:
        f.write('')
    for text in text_list:
        if j % 100000 == 0:
            print('j', j, len(text_list), time.time() - start_time)
            start_time = time.time()
            # Periodically filter and flush accumulated pairs to disk.
            if sentence_pairs:
                temp_list = []
                for sen in sentence_pairs:
                    if len(sen[0]) > 1 and len(sen[1]) > 1:
                        temp_list.append(sen)
                    elif (len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0])) \
                            or (len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1])):
                        temp_list.append(sen)
                sentence_pairs = temp_list
                sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
                with open('nsp_src_data.txt', 'a') as f:
                    f.writelines(sentence_pairs)
                sentence_pairs = []
        j += 1
        # Positive samples: split the text at every position and pair prefix with suffix.
        for i in range(len(text) - 1):
            if re.search('[\u4e00-\u9fa5]', text[i + 1]) \
                    and re.search('[\u4e00-\u9fa5]', text[i + 1:]) \
                    and re.search('[\u4e00-\u9fa5]', text[:i + 1]):
                sentence_pairs.append([text[:i + 1], text[i + 1:], 1])
                sentence_pairs.append([text[random.randint(0, i):i + 1], text[i + 1:random.randint(i + 1, len(text))], 1])
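                # Hard negatives follow: the same two fragments in swapped order,
                # and fragments paired with random snippets from other texts,
                # are all labelled 0.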
                sentence_pairs.append([text[i + 1:], text[:i + 1], 0])
                sentence_pairs.append([text[i + 1:random.randint(i + 1, len(text))], text[random.randint(0, i):i + 1], 0])
                max_k = random.randint(0, 3)
                k = 0
                while True:
                    if k >= max_k:
                        break
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([text[random.randint(0, i):i + 1], rand_t, 0])
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([rand_t, text[i + 1:random.randint(i + 1, len(text))], 0])
                    k += 1
                # Trim a few characters off both ends: still adjacent => positive,
                # swapped order => negative.
                rand_index = random.randint(1, 5)
                if len(text[:i + 1]) > rand_index and len(text[i + 1:]) > rand_index \
                        and re.search('[\u4e00-\u9fa5]', text[rand_index:i + 1]) \
                        and re.search('[\u4e00-\u9fa5]', text[i + 1:len(text) - rand_index]):
                    sentence_pairs.append([text[rand_index:i + 1], text[i + 1:len(text) - rand_index], 1])
                    sentence_pairs.append([text[i + 1:len(text) - rand_index], text[rand_index:i + 1], 0])
        # Negative samples (cross-text pairs, currently disabled):
        # for i in range(len(text)-1):
        #     t = random.sample(text_list, 1)[0]
        #     if t == text:
        #         continue
        #     if random.choice([0, 1]):
        #         sentence_pairs.append([text, t, 0])
        #     else:
        #         sentence_pairs.append([t, text, 0])
    # Final filter and flush of any remaining pairs.
    temp_list = []
    for sen in sentence_pairs:
        if len(sen[0]) > 0 and len(sen[1]) > 0:
            temp_list.append(sen)
        elif (len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0])) \
                or (len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1])):
            temp_list.append(sen)
    sentence_pairs = temp_list
    sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
    with open('nsp_src_data.txt', 'a') as f:
        f.writelines(sentence_pairs)
    return


def clean_train_data():
    with open('nsp_src_data.txt', 'r') as f:
        _list = f.readlines()
    # Deduplicate lines.
    _list = [json.dumps(x) for x in _list]
    _list = list(set(_list))
    _list = [json.loads(x) for x in _list]
    new_list = []
    for line in _list:
        ss = line[:-1].split('\t')
        ss = list(set(ss))
        if '' in ss:
            ss.remove('')
        # Keep only lines with three distinct non-empty fields (text_a, text_b, label).
        if len(ss) == 3:
            new_list.append(line)
    with open('nsp_src_data.txt', 'w') as f:
        f.writelines(new_list)


if __name__ == '__main__':
    # raw_data_process('datasets/product_ner/ZOL_PRODUCE_INFO.csv')
    generate_train_data(r'D:\Project\borderless-table-detect\torch_version\sentence_match\label_table_head_info.txt')
    # clean_train_data()
    # print('\t\tb'.split('\t'))
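

# Minimal usage sketch (illustrative, not part of the original pipeline): build a
# DataLoader over the features produced by NSPProcessor. It assumes the
# `transformers` package and a BERT-style pretrained tokenizer are available;
# the function name, checkpoint name, and hyperparameters are placeholders.
def _example_build_dataloader(data_path='nsp_src_data.txt', max_seq_len=64):
    from torch.utils.data import DataLoader
    from transformers import BertTokenizer  # assumed dependency, not imported above

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # assumed checkpoint
    processor = NSPProcessor(data_path, limit=10000)
    train_dataset = processor.load_examples(max_seq_len, tokenizer, data_type='train')
    # collate_fn simply stacks the fixed-length feature tensors into batch tensors.
    return DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)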