123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- import os
- from torch.utils.data import DataLoader
- from transformers import AutoTokenizer
- import torch
- from botr.nsp.model import ElectraNSPModel
- from botr.nsp.processing import NSPProcessor, collate_fn
# Directory (next to this module) that the tokenizer and model are loaded from.
model_path = os.path.abspath(os.path.dirname(__file__)) + '/model/'
# Maximum sequence length handed to NSPProcessor.load_examples.
max_seq_len = 32
# Prefer the GPU when one is available; prediction inputs are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = ElectraNSPModel.from_pretrained(model_path)
# The prediction code moves every input tensor to `device`, so the model has
# to live on the same device; otherwise the forward pass fails on CUDA hosts.
model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()
# Batch size used by the prediction DataLoader.
batch_size = 8
def nsp_predict(data_list, has_label=False):
    """Predict whether each second text continues the first one.

    Every pair is scored on its own: it is serialised as
    ``prompt<TAB>next_sentence<TAB>0``, turned into a dataset by
    ``NSPProcessor`` and run through the module-level ``model``.

    Args:
        data_list: Sequence of rows. Each row is ``(prompt, next_sentence)``
            or, when ``has_label`` is true, ``(prompt, next_sentence, label)``.
            Rows may be lists or tuples.
        has_label: Whether the rows already carry a trailing label. The label
            is only echoed in the printed trace; it is never fed to the model
            (the serialised example always uses the placeholder ``0``).

    Returns:
        A list with one ``bool`` per input row, ``True`` when the logit at
        index 1 is greater than the logit at index 0.
    """
    if has_label:
        rows = data_list
    else:
        # Pad unlabeled rows with a dummy label so unpacking below is uniform.
        rows = [list(row) + [False] for row in data_list]

    next_list = []
    for prompt, next_sentence, label in rows:
        # Use a separate name: rebinding `data_list` here would shadow the
        # argument while it is still being iterated.
        example_lines = [prompt + '\t' + next_sentence + '\t0']
        processor = NSPProcessor(example_lines)
        predict_dataset = processor.load_examples(max_seq_len, tokenizer, data_type='predict')
        predict_dataloader = DataLoader(predict_dataset, batch_size=batch_size, collate_fn=collate_fn)

        # Reset per pair so that a dataloader yielding nothing can neither
        # raise NameError nor leak the previous pair's result.
        next_flag = False
        with torch.no_grad():
            for batch in predict_dataloader:
                batch = tuple(t.to(device) for t in batch)
                # NOTE(review): the meaning of the batch positions is defined
                # by collate_fn, which is not visible here.
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                outputs = model(**inputs)
                logits = outputs.logits
                # Only the first row is inspected; one example was submitted.
                next_flag = bool(logits[0, 0] < logits[0, 1])

        print(prompt, next_sentence, label, next_flag)
        next_list.append(next_flag)
    return next_list
if __name__ == "__main__":
    # Manual smoke run: each row is [prompt, next_sentence, expected_label].
    # The label is only printed next to the prediction by nsp_predict.
    #
    # Sample source text (translated from Chinese): "Baotou Qingshan District
    # People's Government / Qingshan Road Sub-district Office /
    # Address: No. 1 Ziyou Road, Qingshan District, Baotou".
    labelled_pairs = [
        ["青山路办事处", "我们", False],
        ["包头市青山区人", "民政府", True],
        ["民政府", "包头市青山区人", False],
        ["地址:包头市青山区自由", "路1号", True],
        ["路1号", "地址:包头市青山区自由", False],
        ["采购人:", "地址", False],
        ["地址", "联系人", False],
        ["联系人", "电话", False],
        ["电话", "传真", False],
        ["传真", "电子邮件", False],
        ["中国铁塔股份有限公司乌鲁木", "齐市分公司", True],
        ["齐市分公司", "新疆乌鲁木齐市北京北路", False],
        ["新疆乌鲁木齐市北京北路", "1999号万象天地商务公园1-1", True],
        ["1999号万象天地商务公园1", "汤经理", False],
        ["汤经理", "17799606095", False],
        ["采购代理机", "构", True],
        ["构", "地址", False],
        ["地址", "新疆项目部", False],
        ["新疆项目部", "联系人", False],
        ["联系人", "电话", False],
        ["电话", "传真", False],
        ["传真", "电子邮箱", False],
        ["公诚管理咨询有限公司", "广州市天河区中山大道西89号A栋9", False],
        ["广州市天河区中山大道西89号A栋9", "层908-913房", True],
        ["层908-913房", "新疆乌鲁木齐市沙依巴克区喀纳斯湖", False],
        ["新疆乌鲁木齐市沙依巴克区喀纳斯湖", "北路455号新建软件园F1栋8层", True],
        ["北路455号新建软件园F1栋8层", "李朝杰、王小丫、杨磊、马季青、姚锐", False],
        ["李朝杰、王小丫、杨磊、马季青、姚锐", "锐、邹海燕、尉振军", True],
        ["锐、邹海燕、尉振军", "17699660609、18699181606、", False],
        ["17699660609、18699181606、", "13201381157、15292850990", True],
    ]
    nsp_predict(labelled_pairs, has_label=True)
|