Sfoglia il codice sorgente

属性链接模型优化

znj 3 anni fa
parent
commit
ca622ef984

+ 40 - 5
BiddingKG/dl/interface/getAttributes.py

@@ -983,6 +983,7 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
     tenderee_phone = set()
     agency_contact = set()
     agency_phone = set()
+    winter_contact = set()
     for _person in person_list:
         if _person.label == 1:
             tenderee_contact.add(_person.entity_text)
@@ -1238,11 +1239,15 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                                 break
                 # print(3,combo[0].entity_text,combo[1].entity_text)
     # 更新 PackDict
+    not_sure_linked = []
     for link_p in list(linked_company):
         for k in PackDict.keys():
             for i in range(len(PackDict[k]["roleList"])):
                 if PackDict[k]["roleList"][i].role_name == "tenderee":
-                    if PackDict[k]["roleList"][i].entity_text == link_p.entity_text or link_p.label == 0:
+                    if PackDict[k]["roleList"][i].entity_text != link_p.entity_text and link_p.label == 0:
+                        not_sure_linked.append(link_p)
+                        continue
+                    if PackDict[k]["roleList"][i].entity_text == link_p.entity_text:
                         for per in link_p.pointer_person:
                             person_phone = [phone for phone in per.person_phone] if per.person_phone else []
                             if not person_phone:
@@ -1253,7 +1258,10 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                                 if per.entity_text not in agency_contact and _p.entity_text not in agency_phone:
                                     PackDict[k]["roleList"][i].linklist.append((per.entity_text, _p.entity_text))
                 elif PackDict[k]["roleList"][i].role_name == "agency":
-                    if PackDict[k]["roleList"][i].entity_text == link_p.entity_text or link_p.label == 1:
+                    if PackDict[k]["roleList"][i].entity_text != link_p.entity_text and link_p.label == 1:
+                        not_sure_linked.append(link_p)
+                        continue
+                    if PackDict[k]["roleList"][i].entity_text == link_p.entity_text:
                         for per in link_p.pointer_person:
                             person_phone = [phone for phone in per.person_phone] if per.person_phone else []
                             if not person_phone:
@@ -1270,12 +1278,39 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                             if not person_phone:
                                 if per.entity_text not in tenderee_contact and per.entity_text not in agency_contact:
                                     PackDict[k]["roleList"][i].linklist.append((per.entity_text, ""))
+                                    winter_contact.add(per.entity_text)
                                     continue
                             for _p in person_phone:
                                 if per.entity_text not in tenderee_contact and _p.entity_text not in tenderee_phone and \
                                         per.entity_text not in agency_contact and _p.entity_text not in agency_phone:
                                     PackDict[k]["roleList"][i].linklist.append((per.entity_text, _p.entity_text))
-
+                                    winter_contact.add(per.entity_text)
+    # 更新org/company实体label为0,1的链接
+    for link_p in not_sure_linked:
+        for k in PackDict.keys():
+            for i in range(len(PackDict[k]["roleList"])):
+                if PackDict[k]["roleList"][i].role_name == "tenderee":
+                    if link_p.label == 0:
+                        for per in link_p.pointer_person:
+                            person_phone = [phone for phone in per.person_phone] if per.person_phone else []
+                            if not person_phone:
+                                if per.entity_text not in agency_contact and per.entity_text not in winter_contact:
+                                    PackDict[k]["roleList"][i].linklist.append((per.entity_text, ""))
+                                    continue
+                            for _p in person_phone:
+                                if per.entity_text not in agency_contact and _p.entity_text not in agency_phone and per.entity_text not in winter_contact:
+                                    PackDict[k]["roleList"][i].linklist.append((per.entity_text, _p.entity_text))
+                elif PackDict[k]["roleList"][i].role_name == "agency":
+                    if link_p.label == 1:
+                        for per in link_p.pointer_person:
+                            person_phone = [phone for phone in per.person_phone] if per.person_phone else []
+                            if not person_phone:
+                                if per.entity_text not in tenderee_contact and per.entity_text not in winter_contact:
+                                    PackDict[k]["roleList"][i].linklist.append((per.entity_text, ""))
+                                    continue
+                            for _p in person_phone:
+                                if per.entity_text not in tenderee_contact and _p.entity_text not in tenderee_phone and per.entity_text not in winter_contact:
+                                    PackDict[k]["roleList"][i].linklist.append((per.entity_text, _p.entity_text))
 
     re_split = re.compile("[^\u4e00-\u9fa5、](十一|十二|十三|十四|十五|一|二|三|四|五|六|七|八|九|十)、")
     split_list = [0] * 16
@@ -1528,7 +1563,7 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                 if PackDict[k]["roleList"][i].role_name == "tenderee":
                     if not PackDict[k]["roleList"][i].linklist:
                         if PackDict[k]["roleList"][i].entity_text == entity.entity_text or entity.label == 0:
-                            if person_ not in agency_contact and len(set(phone_)&set(agency_phone))==0:
+                            if person_ not in agency_contact and len(set(phone_)&set(agency_phone))==0 and person_ not in winter_contact:
                                 if not phone_:
                                     PackDict[k]["roleList"][i].linklist.append((person_, ""))
                                 for p in phone_:
@@ -1536,7 +1571,7 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                                 is_update = True
                 elif PackDict[k]["roleList"][i].role_name == "agency":
                     if not PackDict[k]["roleList"][i].linklist:
-                        if PackDict[k]["roleList"][i].entity_text == entity.entity_text or entity.label == 1:
+                        if PackDict[k]["roleList"][i].entity_text == entity.entity_text or entity.label == 1 and person_ not in winter_contact:
                             if person_ not in tenderee_contact and len(set(phone_)&set(tenderee_phone))==0:
                                 if not phone_:
                                     PackDict[k]["roleList"][i].linklist.append((person_, ""))

BIN
BiddingKG/dl/relation_extraction/add_words_dict.pkl


+ 113 - 27
BiddingKG/dl/relation_extraction/model.py

@@ -1,13 +1,13 @@
 #! -*- coding:utf-8 -*-
 import os,sys
-parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0,parentdir)
+# parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# sys.path.insert(0,parentdir)
 
-import json
+# import json
 import numpy as np
-from random import choice
-from tqdm import tqdm
-from BiddingKG.dl.common.models import save,load
+# from random import choice
+# from tqdm import tqdm
+from BiddingKG.dl.common.models import *
 from itertools import groupby
 
 
@@ -63,6 +63,62 @@ def dilated_gated_conv1d(seq, mask, dilation_rate=1):
     seq = Lambda(lambda x: x[0] * x[1])([seq, mask])
     return seq
 
+class OurLayer(Layer):
+    """定义新的Layer,增加reuse方法,允许在定义Layer时调用现成的层
+    """
+    def reuse(self, layer, *args, **kwargs):
+        if not layer.built:
+            if len(args) > 0:
+                inputs = args[0]
+            else:
+                inputs = kwargs['inputs']
+            if isinstance(inputs, list):
+                input_shape = [K.int_shape(x) for x in inputs]
+            else:
+                input_shape = K.int_shape(inputs)
+            layer.build(input_shape)
+        outputs = layer.call(*args, **kwargs)
+        for w in layer.trainable_weights:
+            if w not in self._trainable_weights:
+                self._trainable_weights.append(w)
+        for w in layer.non_trainable_weights:
+            if w not in self._non_trainable_weights:
+                self._non_trainable_weights.append(w)
+        for u in layer.updates:
+            if not hasattr(self, '_updates'):
+                self._updates = []
+            if u not in self._updates:
+                self._updates.append(u)
+        return outputs
+
+class OurBidirectional(OurLayer):
+    """自己封装双向RNN,允许传入mask,保证对齐
+    """
+    def __init__(self, layer, **args):
+        super(OurBidirectional, self).__init__(**args)
+        self.forward_layer = layer.__class__.from_config(layer.get_config())
+        self.backward_layer = layer.__class__.from_config(layer.get_config())
+        self.forward_layer.name = 'forward_' + self.forward_layer.name
+        self.backward_layer.name = 'backward_' + self.backward_layer.name
+    def reverse_sequence(self, x, mask):
+        """这里的mask.shape是[batch_size, seq_len, 1]
+        """
+        seq_len = K.round(K.sum(mask, 1)[:, 0])
+        seq_len = K.cast(seq_len, 'int32')
+        return tf.reverse_sequence(x, seq_len, seq_dim=1)
+    def call(self, inputs):
+        x, mask = inputs
+        x_forward = self.reuse(self.forward_layer, x)
+        x_backward = self.reverse_sequence(x, mask)
+        x_backward = self.reuse(self.backward_layer, x_backward)
+        x_backward = self.reverse_sequence(x_backward, mask)
+        x = K.concatenate([x_forward, x_backward], -1)
+        if K.ndim(x) == 3:
+            return x * mask
+        else:
+            return x
+    def compute_output_shape(self, input_shape):
+        return input_shape[0][:-1] + (self.forward_layer.units * 2,)
 
 class Attention(Layer):
     """多头注意力机制
@@ -140,6 +196,19 @@ def position_id(x):
     pid = K.tile(pid, [K.shape(x)[0], 1])
     return K.abs(pid - K.cast(r, 'int32'))
 
+add_dict = load(os.path.dirname(__file__)+'/../relation_extraction/add_words_dict.pkl')
+add_words = ['<unk>','<company/org>','<location>','<phone>','<contact_person>']
+model_w2v = getModel_w2v()
+def get_words_matrix(words):
+    if words in add_words:
+        return add_dict[words]
+    else:
+        item_not_space = re.sub("\s*", "", words)
+        if item_not_space in model_w2v.vocab:
+            return model_w2v[item_not_space]
+        else:
+            return add_dict['<unk>']
+
 entity_type_dict = {
     'org': '<company/org>',
     'company': '<company/org>',
@@ -150,9 +219,9 @@ entity_type_dict = {
 class Relation_extraction():
     def __init__(self,is_train=False):
         self.is_train = is_train
-        self.words_vocab = load(os.path.dirname(__file__)+'/../relation_extraction/words_vocab.pkl')
-        id2word = {i: j for i, j in enumerate(self.words_vocab)}
-        self.words2id = {j: i for i, j in id2word.items()}
+        # self.words_vocab = load(os.path.dirname(__file__)+'/../relation_extraction/words_vocab.pkl')
+        # id2word = {i: j for i, j in enumerate(self.words_vocab)}
+        # self.words2id = {j: i for i, j in id2word.items()}
         self.words_size = 128
         self.id2predicate = {
             0: "rel_person",    # 公司——联系人
@@ -162,9 +231,9 @@ class Relation_extraction():
         self.predicate2id = dict({j: i for i, j in self.id2predicate.items()})
         self.num_classes = len(self.id2predicate)
         self.maxlen = 512
-        self.word2vec = None
-        if self.is_train:
-            self.word2vec = load('words2v_matrix.pkl')
+        # self.word2vec = None
+        # if self.is_train:
+        #     self.word2vec = load('words2v_matrix.pkl')
         self.model_path = os.path.dirname(__file__)+'/../relation_extraction/models/my_best_model_oneoutput.weights'
         self.get_model()
         if self.model_path:
@@ -173,23 +242,24 @@ class Relation_extraction():
 
     def get_model(self):
         words_size = self.words_size
-        t2_in = Input(shape=(None,))
+        t2_in = Input(shape=(None,words_size)) # 词向量
+        t3_in = Input(shape=(None,)) # mask列表
         s1_in = Input(shape=(None,))
         k1_in = Input(shape=(1,))
         o1_in = Input(shape=(None, self.num_classes))
-        t2, s1, k1, o1 = t2_in, s1_in, k1_in, o1_in
-        mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(t2)
+        t2, t3, s1, k1, o1 = t2_in, t3_in, s1_in, k1_in, o1_in
+        mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(t3)
 
         pid = Lambda(position_id)(t2)
         position_embedding = Embedding(self.maxlen, words_size, embeddings_initializer='zeros')
         pv = position_embedding(pid)
 
-        t2 = Embedding(len(self.words2id), words_size, weights=[self.word2vec] if self.is_train else None, trainable=True)(t2)
+        # t2 = Embedding(len(self.words2id), words_size, weights=[self.word2vec] if self.is_train else None, trainable=True,name="words_embedding")(t2)
         t = Add()([t2, pv])
         t = Dropout(0.25)(t)
         t = Lambda(lambda x: x[0] * x[1])([t, mask])
         if K.tensorflow_backend._get_available_gpus():
-            t = Bidirectional(CuDNNGRU(64, return_sequences=True))(t)
+            t = OurBidirectional(CuDNNGRU(64, return_sequences=True))([t,mask])
         else:
             t = Bidirectional(GRU(64,return_sequences=True,reset_after=True))(t)
         t_dim = K.int_shape(t)[-1]
@@ -201,7 +271,7 @@ class Relation_extraction():
         h = Conv1D(words_size, 3, activation='relu', padding='same')(h)
         ps1 = Dense(1, activation='sigmoid')(h)
         ps1 = Lambda(lambda x: x[0] * x[1])([ps1, pn1])
-        self.subject_model = Model([t2_in], [ps1])  # 预测subject的模型
+        self.subject_model = Model([t2_in,t3_in], [ps1])  # 预测subject的模型
 
         t_max = Lambda(seq_maxpool)([t, mask])
         pc = Dense(words_size, activation='relu')(t_max)
@@ -230,9 +300,9 @@ class Relation_extraction():
         po = Dense(1, activation='sigmoid')(h)
         po1 = Dense(self.num_classes, activation='sigmoid')(h)
         po1 = Lambda(lambda x: x[0] * x[1] * x[2] * x[3])([po, po1, pc, pn1])
-        self.object_model = Model([t2_in, k1_in], [po1])
+        self.object_model = Model([t2_in,t3_in,k1_in], [po1])
 
-        train_model = Model([t2_in, s1_in, k1_in, o1_in],
+        train_model = Model([t2_in,t3_in, s1_in, k1_in, o1_in],
                             [ps1, po1])
         # loss
         s1 = K.expand_dims(s1, 2)
@@ -253,9 +323,14 @@ class Relation_extraction():
     def extract_items(self,text_in, words, rate=0.5):
         text_words = text_in
         R = []
-        _t2 = [self.words2id.get(c, 1) for c in words]
+        # _t2 = [self.words2id.get(c, 1) for c in words]
+        _t2 = np.zeros((len(words), self.words_size))
+        for i in range(len(words)):
+            _t2[i] = np.array(get_words_matrix(words[i]))
         _t2 = np.array([_t2])
-        _k1 = self.subject_model.predict([_t2])
+        _t3 = [1 for _ in words]
+        _t3 = np.array([_t3])
+        _k1 = self.subject_model.predict([_t2,_t3])
         _k1 = _k1[0, :, 0]
         _k1 = np.where(_k1 > rate)[0]
         _subjects = []
@@ -264,8 +339,9 @@ class Relation_extraction():
             _subjects.append((_subject, i, i))
         if _subjects:
             _t2 = np.repeat(_t2, len(_subjects), 0)
+            _t3 = np.repeat(_t3, len(_subjects), 0)
             _k1, _ = np.array([_s[1:] for _s in _subjects]).T.reshape((2, -1, 1))
-            _o1 = self.object_model.predict([_t2, _k1])
+            _o1 = self.object_model.predict([_t2,_t3,_k1])
             for i, _subject in enumerate(_subjects):
                 _oo1 = np.where(_o1[i] > 0.5)
                 for _ooo1, _c1 in zip(*_oo1):
@@ -331,6 +407,17 @@ class Relation_extraction():
 
 if __name__ == '__main__':
     test_model = Relation_extraction()
+    # ['<pad>','<unk>','<company/org>','<location>','<phone>','<contact_person>']
+    # add_words = ['<unk>','<company/org>','<location>','<phone>','<contact_person>']
+    # add_dict = dict()
+    # for layer in test_model.train_model.layers:
+    #     if layer.name=="words_embedding":
+    #         save(layer.get_weights()[0],"trained_words.pkl")
+    #         for i,j in zip(add_words,layer.get_weights()[0][1:6]):
+    #             add_dict[i] = j
+    #             print(i,'\n',j)
+    #         print(layer.get_weights()[0][1:6])
+    # save(add_dict,"add_words_dict.pkl")
     text_in = "索引||号||:||014583788||/||2018-00038||,||成文||日期||:||2018-11-19||,||关于||国家税务总局都昌县税务局||办公楼||七||楼||会议室||维修||改造||项目||综合||比价||成交||公告||,||关于||国家税务总局都昌县税务局||办公楼七楼会议室||维修||改造||项目||(||比价||编号||:||JXXL2018-JJ-DC001||)||综合||比价||成交||公告||,||江西新立建设管理有限公司九江分公司||受||国家税务总局都昌县税务局||委托||,||就||其||办公楼||七||楼||会议室||维修||改造||项目||(||控制||价||:||294788.86||元||)||进行||综合||比价||方式||,||比价||活动||于||2018年||11月||16日||15:30||在||都昌县万里大道和平宾馆旁三楼||江西新立建设管理有限公司九江分公司||进行||,||经||比价||小组||评审||,||比价人||确定||,||现||将||比价||结果||公式||如下||:||序号||:||1||,||比价||编号||,||JXXL2018-JJ-DC001||,||项目||内容||名称||,||都昌县税务局||办公楼||七||楼||会议室||维修||改造||项目||,||数量||:||1||,||成交||供应商||名称||,||江西芙蓉建筑工程有限公司||,||成交价||(||元||)||,||284687.67||。||一||、||比价||小组||成员||:||杨忠辉||李燕杨瑾||,||本||公告||自||发布||之||日||起||1||个||工作日||内||若||无||异议||,||将||向||中标人||发出||《||成交||通知书||》||,||二||、||联系||方式||,||单位||:||国家税务总局都昌县税务局||,||比价||代理||机构||:||江西新立建设管理有限公司九江分公司||,||联系人||:||詹女士||,||电话||:||15979976088||,||江西新立建设管理有限公司九江分公司"
     words = "索引||号||:||014583788||/||2018-00038||,||成文||日期||:||2018-11-19||,||关于||国家税务总局都昌县税务局||" \
             "办公楼||七||楼||会议室||维修||改造||项目||综合||比价||成交||公告||,||关于||国家税务总局都昌县税务局||办公楼七楼会议室||" \
@@ -342,9 +429,8 @@ if __name__ == '__main__':
             "供应商||名称||,||<company/org>||,||成交价||(||元||)||,||284687.67||。||一||、||比价||小组||成员||:||杨忠辉||李燕杨瑾||," \
             "||本||公告||自||发布||之||日||起||1||个||工作日||内||若||无||异议||,||将||向||中标人||发出||《||成交||通知书||》||,||二||、||联系||方式||," \
             "||单位||:||<company/org>||,||比价||代理||机构||:||<company/org>||,||联系人||:||<contact_person>||,||电话||:||<phone>||,||江西新立建设管理有限公司九江分公司"
-    text_in = "索引"
-    words = "索引"
+    # text_in = "索引"
+    # words = "索引"
     res = test_model.predict(text_in.split("||"),words.split("||"))
     print(res)
-    print(test_model.predict(text_in.split("||"),words.split("||")))
-    print(test_model.predict(text_in.split("||"),words.split("||")))
+    # print(test_model.predict(text_in.split("||"),words.split("||")))

BIN
BiddingKG/dl/relation_extraction/models/my_best_model_oneoutput.weights


+ 2 - 2
BiddingKG/dl/test/测试整个要素提取流程.py

@@ -430,8 +430,8 @@ if __name__=="__main__":
     a = time.time()
     print("start")
     # print(predict("12",content))
-    # result = predict("12",text)
-    result = predict("12",content)
+    result = predict("12",text)
+    # result = predict("12",content)
     # print(json.loads(result))
     #test("12",text)
     print("takes",time.time()-a)