"""Next-sentence prediction (NSP) inference helper.

Loads a tokenizer and an ``ElectraNSPModel`` from the ``model/`` directory
located next to this file, and exposes :func:`nsp_predict` to decide whether a
text fragment is continued by another fragment.
"""
import os

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from botr.nsp.model import ElectraNSPModel
from botr.nsp.processing import NSPProcessor, collate_fn

model_path = os.path.abspath(os.path.dirname(__file__)) + '/model/'
max_seq_len = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = ElectraNSPModel.from_pretrained(model_path)
# The batches below are moved to ``device``; the model has to live on the same
# device or the forward pass fails as soon as CUDA is available.
model.to(device)
model.eval()
batch_size = 8


def nsp_predict(data_list, has_label=False):
    """Predict, for each (prompt, next_sentence) pair, whether the second continues the first.

    Each pair is scored on its own: it is serialized as a single
    tab-separated line, converted to features by ``NSPProcessor`` and run
    through the module-level ``model``. For every pair the prompt, the
    candidate, the label and the prediction are printed.

    Args:
        data_list: Sequence of rows. With ``has_label=False`` each row is
            ``[prompt, next_sentence]``; with ``has_label=True`` each row is
            ``[prompt, next_sentence, label]``. The label is only printed next
            to the prediction; it is never fed to the model.
        has_label: Whether the rows already carry a third (label) element.

    Returns:
        A list of ``bool`` with one entry per row, in input order. An entry is
        ``True`` when the model's logit at class index 1 is strictly greater
        than the logit at class index 0. If the processor yields no batch for
        a pair, that entry is ``False``.
    """
    if has_label:
        rows = data_list
    else:
        # Pad unlabeled rows so the loop below can always unpack three values.
        # ``list(row)`` lets rows be tuples as well as lists.
        rows = [list(row) + [False] for row in data_list]

    next_list = []
    for prompt, next_sentence, label in rows:
        # The trailing '0' is a constant third column in the processor input
        # (presumably a placeholder label - confirm against NSPProcessor).
        example_lines = [prompt + '\t' + next_sentence + '\t0']
        processor = NSPProcessor(example_lines)
        predict_dataset = processor.load_examples(max_seq_len, tokenizer, data_type='predict')
        predict_dataloader = DataLoader(predict_dataset, batch_size=batch_size, collate_fn=collate_fn)

        # Reset per pair so a result from the previous pair can never leak
        # into this one when the dataloader produces no batch.
        next_flag = False
        with torch.no_grad():
            for batch in predict_dataloader:
                batch = tuple(t.to(device) for t in batch)
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                logits = model(**inputs).logits
                # Only row 0 is inspected: the dataset is built from a single line.
                next_flag = bool(logits[0, 0] < logits[0, 1])

        print(prompt, next_sentence, label, next_flag)
        next_list.append(next_flag)
    return next_list


if __name__ == '__main__':
    # Sample text (the first pairs below appear to be fragments of it):
    # "Baotou Qingshan District People's Government / Qingshan Road Office /
    # Address: No. 1 Ziyou Road, Qingshan District, Baotou"
    _list = [
        ["青山路办事处", "我们", False],
        ['包头市青山区人', '民政府', True],
        ['民政府', '包头市青山区人', False],
        ['地址:包头市青山区自由', '路1号', True],
        ['路1号', '地址:包头市青山区自由', False],
        ['采购人:', '地址', False],
        ['地址', '联系人', False],
        ['联系人', '电话', False],
        ['电话', '传真', False],
        ['传真', '电子邮件', False],
        ['中国铁塔股份有限公司乌鲁木', '齐市分公司', True],
        ['齐市分公司', '新疆乌鲁木齐市北京北路', False],
        ['新疆乌鲁木齐市北京北路', '1999号万象天地商务公园1-1', True],
        ['1999号万象天地商务公园1', '汤经理', False],
        ['汤经理', '17799606095', False],
        ['采购代理机', '构', True],
        ['构', '地址', False],
        ['地址', '新疆项目部', False],
        ['新疆项目部', '联系人', False],
        ['联系人', '电话', False],
        ['电话', '传真', False],
        ['传真', '电子邮箱', False],
        ['公诚管理咨询有限公司', '广州市天河区中山大道西89号A栋9', False],
        ['广州市天河区中山大道西89号A栋9', '层908-913房', True],
        ['层908-913房', '新疆乌鲁木齐市沙依巴克区喀纳斯湖', False],
        ['新疆乌鲁木齐市沙依巴克区喀纳斯湖', '北路455号新建软件园F1栋8层', True],
        ['北路455号新建软件园F1栋8层', '李朝杰、王小丫、杨磊、马季青、姚锐', False],
        ['李朝杰、王小丫、杨磊、马季青、姚锐', '锐、邹海燕、尉振军', True],
        ['锐、邹海燕、尉振军', '17699660609、18699181606、', False],
        ['17699660609、18699181606、', '13201381157、15292850990', True],
    ]
    nsp_predict(_list, has_label=True)