predict.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. from torch.utils.data import DataLoader
  3. from transformers import AutoTokenizer
  4. import torch
  5. from botr.nsp.model import ElectraNSPModel
  6. from botr.nsp.processing import NSPProcessor, collate_fn
  7. model_path = os.path.abspath(os.path.dirname(__file__)) + '/model/'
  8. max_seq_len = 32
  9. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  10. tokenizer = AutoTokenizer.from_pretrained(model_path)
  11. model = ElectraNSPModel.from_pretrained(model_path)
  12. model.eval()
  13. batch_size = 8
  14. def nsp_predict(data_list, has_label=False):
  15. if not has_label:
  16. data_list = [x + [False] for x in data_list]
  17. next_list = []
  18. for prompt, next_sentence, label in data_list:
  19. data_list = [prompt + '\t' + next_sentence + '\t0']
  20. processor = NSPProcessor(data_list)
  21. predict_dataset = processor.load_examples(max_seq_len, tokenizer, data_type='predict')
  22. predict_dataloader = DataLoader(predict_dataset, batch_size=batch_size, collate_fn=collate_fn)
  23. with torch.no_grad():
  24. for step, batch in enumerate(predict_dataloader):
  25. batch = tuple(t.to(device) for t in batch)
  26. inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
  27. outputs = model(**inputs)
  28. logits = outputs.logits
  29. if logits[0, 0] < logits[0, 1]:
  30. next_flag = True
  31. else:
  32. next_flag = False
  33. print(prompt, next_sentence, label, next_flag)
  34. next_list.append(next_flag)
  35. return next_list
  36. if __name__ == '__main__':
  37. # 包头市青山区人民政府 青山路办事处 地址:包头市青山区自由路1号
  38. _list = [
  39. ["青山路办事处", "我们", False],
  40. ['包头市青山区人', '民政府', True],
  41. ['民政府', '包头市青山区人', False],
  42. ['地址:包头市青山区自由', '路1号', True],
  43. ['路1号', '地址:包头市青山区自由', False],
  44. ['采购人:', '地址', False],
  45. ['地址', '联系人', False],
  46. ['联系人', '电话', False],
  47. ['电话', '传真', False],
  48. ['传真', '电子邮件', False],
  49. ['中国铁塔股份有限公司乌鲁木', '齐市分公司', True],
  50. ['齐市分公司', '新疆乌鲁木齐市北京北路', False],
  51. ['新疆乌鲁木齐市北京北路', '1999号万象天地商务公园1-1', True],
  52. ['1999号万象天地商务公园1', '汤经理', False],
  53. ['汤经理', '17799606095', False],
  54. ['采购代理机', '构', True],
  55. ['构', '地址', False],
  56. ['地址', '新疆项目部', False],
  57. ['新疆项目部', '联系人', False],
  58. ['联系人', '电话', False],
  59. ['电话', '传真', False],
  60. ['传真', '电子邮箱', False],
  61. ['公诚管理咨询有限公司', '广州市天河区中山大道西89号A栋9', False],
  62. ['广州市天河区中山大道西89号A栋9', '层908-913房', True],
  63. ['层908-913房', '新疆乌鲁木齐市沙依巴克区喀纳斯湖', False],
  64. ['新疆乌鲁木齐市沙依巴克区喀纳斯湖', '北路455号新建软件园F1栋8层', True],
  65. ['北路455号新建软件园F1栋8层', '李朝杰、王小丫、杨磊、马季青、姚锐', False],
  66. ['李朝杰、王小丫、杨磊、马季青、姚锐', '锐、邹海燕、尉振军', True],
  67. ['锐、邹海燕、尉振军', '17699660609、18699181606、', False],
  68. ['17699660609、18699181606、', '13201381157、15292850990', True],
  69. ]
  70. nsp_predict(_list, has_label=True)