import copy
import csv
import json
import logging
import os
import random
import re
import time

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)


def collate_fn(batch):
    """
    batch is a list of (input_ids, attention_mask, token_type_ids, label) tuples.
    Stacks each field across the batch and returns the four batch tensors.
    """
    all_input_ids, all_attention_mask, all_token_type_ids, all_labels = map(torch.stack, zip(*batch))
    # Optional dynamic truncation to the longest real sequence in the batch
    # (needs per-example lengths, which load_examples below does not produce):
    # max_len = max(all_lens).item()
    # all_input_ids = all_input_ids[:, :max_len]
    # all_attention_mask = all_attention_mask[:, :max_len]
    # all_token_type_ids = all_token_type_ids[:, :max_len]
    # all_labels = all_labels[:, :max_len]
    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels
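

# A minimal usage sketch (not part of the original pipeline; the demo name and
# tensor shapes are illustrative assumptions). It shows how collate_fn stacks a
# list of per-example tensor tuples into batch tensors, e.g. when passed as
# DataLoader(dataset, collate_fn=collate_fn).
def _collate_fn_demo():
    fake_batch = [
        (torch.zeros(8, dtype=torch.long),    # input_ids
         torch.ones(8, dtype=torch.long),     # input_mask / attention_mask
         torch.zeros(8, dtype=torch.long),    # segment_ids / token_type_ids
         torch.tensor(1, dtype=torch.long))   # label
        for _ in range(4)
    ]
    ids, mask, segments, labels = collate_fn(fake_batch)
    print(ids.shape, mask.shape, segments.shape, labels.shape)  # (4, 8) x 3 and (4,)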


class NSPProcessor:
    """Processor for the next-sentence-prediction (NSP) data set."""

    def __init__(self, data_path_or_list, limit=100000):
        self.data_path = None
        self.str_list = None
        self.limit = limit
        self.ratio = 0.99  # fraction of lines used for training; the remainder is held out for eval
        if isinstance(data_path_or_list, str):
            self.data_path = data_path_or_list
        elif isinstance(data_path_or_list, list):
            self.str_list = data_path_or_list
            self.data = self.str_list
        if self.data_path:
            logger.info("Creating features from dataset file at %s", self.data_path)
            with open(self.data_path, 'r') as f:
                lines = f.readlines()
            # random.shuffle(lines)
            self.data = lines
            # print('len(self.data)', len(self.data))

    def get_train_examples(self):
        """See base class."""
        return self.create_examples("train")

    def get_eval_examples(self):
        """See base class."""
        return self.create_examples("eval")

    def get_predict_examples(self):
        return self.create_examples("test")

    def create_examples(self, set_type):
        """Creates examples for the training and eval sets."""
        if set_type == 'train':
            random.shuffle(self.data)
            self.data = self.data[:self.limit]
            print('len(self.data)', len(self.data))
        if set_type in ['train', 'eval']:
            lines = [x[:-1] for x in self.data]  # strip the trailing newline
            if set_type == 'train':
                lines = lines[:int(len(lines) * self.ratio)]
            else:
                lines = lines[int(len(lines) * self.ratio):]
        else:
            lines = self.str_list
        examples = []
        for (i, line) in enumerate(lines):
            ss = line.split('\t')  # sentence_a \t sentence_b \t label
            examples.append([ss[0], ss[1], ss[2]])
        return examples

    def load_examples(self, max_seq_len, tokenizer, data_type='train'):
        sep_token = tokenizer.sep_token
        cls_token = tokenizer.cls_token
        pad_token = tokenizer.pad_token
        # print(sep_token, cls_token, pad_token)
        if data_type == 'train':
            examples = self.get_train_examples()
        elif data_type == 'eval':
            examples = self.get_eval_examples()
        else:
            examples = self.get_predict_examples()
        features = []
        # print('len(examples)', len(examples))
        if data_type == 'train':
            print('loading examples...')
        for (ex_index, example) in enumerate(examples):
            # if ex_index % 10000 == 0:
            #     logger.info("loading example %d of %d", ex_index, len(examples))
            a_tokens = tokenizer.tokenize(example[0])
            b_tokens = tokenizer.tokenize(example[1])
            # Account for [CLS] and [SEP] with "- 2".
            special_tokens_count = 2
            # Truncate or pad each sentence to half of the remaining budget:
            # keep the tail of sentence A and the head of sentence B.
            real_max_seq_len = int((max_seq_len - special_tokens_count) / 2)
            if len(a_tokens) >= real_max_seq_len:
                a_tokens = a_tokens[(len(a_tokens) - real_max_seq_len):]
            else:
                a_tokens += [pad_token] * (real_max_seq_len - len(a_tokens))
            if len(b_tokens) >= real_max_seq_len:
                b_tokens = b_tokens[:real_max_seq_len]
            else:
                b_tokens += [pad_token] * (real_max_seq_len - len(b_tokens))
            tokens = [cls_token] + a_tokens + [sep_token] + b_tokens + [sep_token]
            # Segment ids: 0 for [CLS], 1 for sentence A and its [SEP], 2 for
            # sentence B and its [SEP] (note: not the standard BERT 0/1 scheme).
            segment_ids = [0] + [1] * (real_max_seq_len + 1) + [2] * (real_max_seq_len + 1)
            # print('segment_ids', segment_ids)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # print('input_ids', input_ids)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if x != pad_token else 0 for x in tokens]
            label = int(example[2])
            # if label == 0:
            #     label = [0., 1.]
            # else:
            #     label = [1., 0.]
            _dict = {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'label': label
            }
            features.append(_dict)
        if data_type == 'train':
            print('loading examples finished!')
        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.get('input_ids') for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.get('input_mask') for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.get('segment_ids') for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.get('label') for f in features], dtype=torch.long)
        # all_lens = torch.tensor([f.get('input_len') for f in features], dtype=torch.long)
        # dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_lens, all_label_ids)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
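

# A minimal end-to-end usage sketch (not part of the original code). The
# 'bert-base-chinese' tokenizer from the transformers library and the file
# name 'nsp_src_data.txt' are assumptions for illustration only.
def _nsp_processor_demo():
    from torch.utils.data import DataLoader
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    processor = NSPProcessor('nsp_src_data.txt', limit=100000)
    train_dataset = processor.load_examples(max_seq_len=64, tokenizer=tokenizer, data_type='train')
    eval_dataset = processor.load_examples(max_seq_len=64, tokenizer=tokenizer, data_type='eval')
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    eval_loader = DataLoader(eval_dataset, batch_size=32, collate_fn=collate_fn)
    return train_loader, eval_loader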


def raw_data_process(path):
    # NOTE: unfinished stub; it only reads the file and does nothing with the result.
    with open(path, 'r') as f:
        data_list = f.readlines()
    # data_list = data_list[:100]
    # print(data_list[0])


def generate_train_data(raw_data_path):
    with open(raw_data_path, 'r') as f:
        raw_file = f.readlines()
    # Extract the text cells from each table
    text_list = []
    for raw_line in raw_file:
        try:
            # each raw line is expected to be a doubly-quoted repr of a table (nested list)
            table = eval(eval(raw_line))
            text_list += [re.sub(' ', '', y) for x in table for y in x]
        except Exception:
            continue
    text_list = list(set(text_list))
    # Filter: drop cells that are too short or contain no Chinese characters,
    # strip tabs, and keep at most 30 characters from the head or the tail.
    temp_text = []
    for t in text_list:
        if len(t) <= 1 or not re.search('[\u4e00-\u9fa5]', t):
            continue
        t = re.sub('\t', '', t)
        if random.choice([0, 1]):
            temp_text.append(t[:30])
        else:
            temp_text.append(t[-30:])
    text_list = temp_text
    print('len(text_list)', len(text_list))
    # Build sentence pairs
    j = 0
    start_time = time.time()
    sentence_pairs = []
    with open('nsp_src_data.txt', 'w') as f:
        f.write('')  # truncate the output file
    for text in text_list:
        if j % 100000 == 0:
            print('j', j, len(text_list), time.time() - start_time)
            start_time = time.time()
            # Periodically filter and flush the accumulated pairs to disk
            if sentence_pairs:
                temp_list = []
                for sen in sentence_pairs:
                    if len(sen[0]) > 1 and len(sen[1]) > 1:
                        temp_list.append(sen)
                    elif (len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0])) or (len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1])):
                        temp_list.append(sen)
                sentence_pairs = temp_list
                sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
                with open('nsp_src_data.txt', 'a') as f:
                    f.writelines(sentence_pairs)
                sentence_pairs = []
        j += 1
        # Positive samples: split the text at every position; the (head, tail)
        # order is a positive pair (label 1), the swapped order is negative (label 0).
        for i in range(len(text) - 1):
            if re.search('[\u4e00-\u9fa5]', text[i + 1]) \
                    and re.search('[\u4e00-\u9fa5]', text[i + 1:]) \
                    and re.search('[\u4e00-\u9fa5]', text[:i + 1]):
                sentence_pairs.append([text[:i + 1], text[i + 1:], 1])
                sentence_pairs.append([text[random.randint(0, i):i + 1], text[i + 1:random.randint(i + 1, len(text))], 1])
                sentence_pairs.append([text[i + 1:], text[:i + 1], 0])
                sentence_pairs.append([text[i + 1:random.randint(i + 1, len(text))], text[random.randint(0, i):i + 1], 0])
                # Up to three extra negatives pairing a fragment of this text
                # with a random fragment from another text
                max_k = random.randint(0, 3)
                k = 0
                while True:
                    if k >= max_k:
                        break
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([text[random.randint(0, i):i + 1], rand_t, 0])
                    rand_t = random.sample(text_list, 1)[0][:random.randint(1, 16)]
                    if re.search('[\u4e00-\u9fa5]', rand_t):
                        sentence_pairs.append([rand_t, text[i + 1:random.randint(i + 1, len(text))], 0])
                    k += 1
                rand_index = random.randint(1, 5)
                if len(text[:i + 1]) > rand_index and len(text[i + 1:]) > rand_index \
                        and re.search('[\u4e00-\u9fa5]', text[rand_index:i + 1]) \
                        and re.search('[\u4e00-\u9fa5]', text[i + 1:len(text) - rand_index]):
                    sentence_pairs.append([text[rand_index:i + 1], text[i + 1:len(text) - rand_index], 1])
                    sentence_pairs.append([text[i + 1:len(text) - rand_index], text[rand_index:i + 1], 0])
        # Negative samples (disabled): pair the whole text with another random text
        # for i in range(len(text)-1):
        #     t = random.sample(text_list, 1)[0]
        #     if t == text:
        #         continue
        #     if random.choice([0, 1]):
        #         sentence_pairs.append([text, t, 0])
        #     else:
        #         sentence_pairs.append([t, text, 0])
    # Final filter and flush of the remaining pairs
    temp_list = []
    for sen in sentence_pairs:
        if len(sen[0]) > 0 and len(sen[1]) > 0:
            temp_list.append(sen)
        elif (len(sen[0]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[0])) or (len(sen[1]) >= 1 and re.search('[\u4e00-\u9fa5]', sen[1])):
            temp_list.append(sen)
    sentence_pairs = temp_list
    sentence_pairs = [str(x[0]) + '\t' + str(x[1]) + '\t' + str(x[2]) + '\n' for x in sentence_pairs]
    with open('nsp_src_data.txt', 'a') as f:
        f.writelines(sentence_pairs)
    return
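

# A toy illustration (not part of the original pipeline) of the pairing scheme
# used above and of the output line format 'sentence_a\tsentence_b\tlabel\n'.
def _pairing_scheme_demo():
    text = '深度学习模型'
    pairs = []
    for i in range(len(text) - 1):
        pairs.append([text[:i + 1], text[i + 1:], 1])  # correct order -> positive
        pairs.append([text[i + 1:], text[:i + 1], 0])  # swapped order -> negative
    for a, b, label in pairs:
        print('{}\t{}\t{}'.format(a, b, label))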


def clean_train_data():
    with open('nsp_src_data.txt', 'r') as f:
        _list = f.readlines()
    # De-duplicate the lines (the json round-trip leaves each line unchanged)
    _list = [json.dumps(x) for x in _list]
    _list = list(set(_list))
    _list = [json.loads(x) for x in _list]
    # Keep only lines whose three tab-separated fields are distinct and non-empty
    new_list = []
    for line in _list:
        ss = line[:-1].split('\t')
        ss = list(set(ss))
        if '' in ss:
            ss.remove('')
        if len(ss) == 3:
            new_list.append(line)
    with open('nsp_src_data.txt', 'w') as f:
        f.writelines(new_list)
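

# A toy illustration (not part of the original code) of clean_train_data's
# keep/drop rule: duplicated or empty fields shrink the field set below 3.
def _clean_filter_demo():
    for line in ['句子A\t句子B\t1\n',   # kept: three distinct, non-empty fields
                 '句子A\t句子A\t1\n',   # dropped: duplicate fields collapse to 2
                 '句子A\t\t0\n']:       # dropped: the empty field is removed, 2 remain
        ss = list(set(line[:-1].split('\t')))
        if '' in ss:
            ss.remove('')
        print(repr(line), 'kept' if len(ss) == 3 else 'dropped')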


if __name__ == '__main__':
    # raw_data_process('datasets/product_ner/ZOL_PRODUCE_INFO.csv')
    generate_train_data(r'D:\Project\borderless-table-detect\torch_version\sentence_match\label_table_head_info.txt')
    # clean_train_data()
    # print('\t\tb'.split('\t'))