Browse Source

中标联系人提取修复

znj 4 tuần trước
mục cha
commit
238ac1f82d

+ 15 - 4
BiddingKG/dl/channel/channel_bert.py

@@ -443,6 +443,17 @@ def channel_predict(title,text):
     # to torch data
     text = [text]
     text_max_len = 2000
+    # text = [tokenizer.encode_plus(
+    #     _t,
+    #     add_special_tokens=True,  # 添加特殊标记,如[CLS]和[SEP]
+    #     max_length=text_max_len,  # 设置最大长度
+    #     padding='max_length',  # 填充到最大长度
+    #     truncation=True,  # 截断超过最大长度的文本
+    #     return_attention_mask=True,  # 返回attention_mask
+    #     return_tensors='pt'  # 返回PyTorch张量
+    # ) for _t in text]
+    # text = [torch.LongTensor(np.array([_t['input_ids'].numpy()[0] for _t in text])).to(device),
+    #      torch.LongTensor(np.array([_t['attention_mask'].numpy()[0] for _t in text])).to(device)]
     text = [tokenizer.encode_plus(
         _t,
         add_special_tokens=True,  # 添加特殊标记,如[CLS]和[SEP]
@@ -450,10 +461,11 @@ def channel_predict(title,text):
         padding='max_length',  # 填充到最大长度
         truncation=True,  # 截断超过最大长度的文本
         return_attention_mask=True,  # 返回attention_mask
-        return_tensors='pt'  # 返回PyTorch张量
+        # return_tensors='pt'  # 返回PyTorch张量
+        return_tensors=None  #不返回PyTorch张量
     ) for _t in text]
-    text = [torch.LongTensor(np.array([_t['input_ids'].numpy()[0] for _t in text])).to(device),
-         torch.LongTensor(np.array([_t['attention_mask'].numpy()[0] for _t in text])).to(device)]
+    text = [torch.LongTensor(np.array([_t['input_ids'] for _t in text])).to(device),
+            torch.LongTensor(np.array([_t['attention_mask'] for _t in text])).to(device)]
     # predict
     with torch.no_grad():
         outputs = model(None, text)
@@ -581,7 +593,6 @@ def merge_channel(list_articles,channel_dic,original_docchannel):
             # print(text, '\n pred_res', pred)
             if pred is not None and original_docchannel: # 无original_docchannel时不进行对比校正
                 channel_dic = merge_rule(title,text,docchannel,pred,channel_dic,original_docchannel)
-
     elif doctype=='采招数据' and docchannel=="":
         pred = channel_predict(title, text)
         # print(text, '\n pred_res', pred)

+ 8 - 4
BiddingKG/dl/interface/getAttributes.py

@@ -1785,7 +1785,8 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_senten
             # print('loc_relation2',_company.entity_text,_relation.entity_text)
             _company.pointer_address = _relation
     # "联系人——联系电话" 链接规则补充
-    person_phone_EntityList = [ent for ent in pre_entity+ phone_entitys if ent.entity_type not in ['company','org','location']]
+    # person_phone_EntityList = [ent for ent in pre_entity+ phone_entitys if ent.entity_type not in ['company','org','location']]
+    person_phone_EntityList = [ent for ent in pre_entity+ phone_entitys if ent.entity_type not in ['location']]
     person_phone_EntityList = sorted(person_phone_EntityList, key=lambda x: (x.sentence_index, x.begin_index))
     t_match_list = []
     for ent_idx in range(len(person_phone_EntityList)):
@@ -1815,7 +1816,7 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_senten
                             else:
                                 break
                     else:
-                        if distance < 40:
+                        if distance < 30:
                             # value = (-1 / 2 * (distance ** 2)) / 10000
                             t_match_list.append(Match(entity, after_entity, value))
                             match_nums += 1
@@ -1823,8 +1824,10 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_senten
                                 byNotPerson_match_nums += 1
                             else:
                                 break
-                else:
+                elif after_entity.entity_type == "person":
                     person_nums += 1
+                elif after_entity.entity_type in ["company","org"]:
+                    break
             # 前向查找属性
             if ent_idx != 0 and (not match_nums or not byNotPerson_match_nums):
                 previous_entity = person_phone_EntityList[ent_idx - 1]
@@ -1832,12 +1835,13 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_senten
                     # if previous_entity.sentence_index == entity.sentence_index:
                     distance = (tokens_num_dict[entity.sentence_index] + entity.begin_index) - (
                             tokens_num_dict[previous_entity.sentence_index] + previous_entity.end_index)
-                    if distance < 40:
+                    if distance < 30:
                         # 前向 没有 /10000
                         value = (-1 / 2 * (distance ** 2))
                         t_match_list.append(Match(entity, previous_entity, value))
     # km算法分配求解(person-phone)
     t_match_list = [mat for mat in t_match_list if mat.main_role not in linked_connetPerson and mat.attribute not in linked_phone]
+    # print([(mat.main_role.entity_text,mat.attribute.entity_text) for mat in t_match_list])
     personphone_result = dispatch(t_match_list)
     personphone_result = sorted(personphone_result, key=lambda x: (x[0].sentence_index, x[0].begin_index))
     for match in personphone_result: