
Attribute-linking model optimization

znj 3 years ago
Parent
Current commit
c7fd07753e

+ 4 - 4
BiddingKG/dl/interface/getAttributes.py

@@ -760,8 +760,8 @@ def getPackagesFromArticle(list_sentence,list_entity):
                 PackageList.append(copy_pack)
     return PackageList,PackageSet,dict_packageCode
 
-from BiddingKG.dl.relation_extraction.model import Relation_extraction
-relationExtraction_model = Relation_extraction()
+from BiddingKG.dl.interface.modelFactory import Model_relation_extraction
+relationExtraction_model = Model_relation_extraction()
 def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity,list_sentence,on_value = 0.5,on_value_person=0.5,sentence_len=4):
     '''
     @param:
@@ -1134,12 +1134,12 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
         if (ent.entity_type in ['company','org','phone'] and is_company(ent,list_sentence[ent.sentence_index].sentence_text)) or (ent.entity_type=='person' and ent.label in [1,2,3]) \
                 or (ent.entity_type=='location' and len(ent.entity_text)>5):
             pre_entity.append(ent)
-    text_data,pre_data = relationExtraction_model.get_predata(pre_entity+phone_entitys,list_sentence)
+    text_data,pre_data = relationExtraction_model.encode(pre_entity + phone_entitys, list_sentence)
     # print(pre_data)
     maxlen = 512
     relation_list = []
     if 0<len(text_data)<=maxlen:
-        relation_list = relationExtraction_model.predict(text_data,pre_data)
+        relation_list = relationExtraction_model.predict(text_data, pre_data)
     else:
        # when the announcement is longer than maxlen, predict in segments
         start = 0
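
The hunk is cut off just as the segmented branch begins. A minimal sketch of what such windowed prediction could look like, assuming only the encode/predict interface shown above (this loop is illustrative, not the repository's actual continuation):

    # hypothetical sketch: predict in fixed-size windows and merge the results
    start = 0
    while start < len(text_data):
        end = min(start + maxlen, len(text_data))
        # text_data and pre_data are parallel sequences, so slice them together
        relation_list.extend(
            relationExtraction_model.predict(text_data[start:end], pre_data[start:end]))
        start = end

A hard cut can separate a subject from its object, so the real code may well overlap windows or break at sentence boundaries instead.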

+ 142 - 0
BiddingKG/dl/interface/modelFactory.py

@@ -155,6 +155,148 @@ class Model_money_classify():
         with self.graph.as_default():
             return self.getModel().predict([x[0],x[1],x[2]])
         '''
+
+from itertools import groupby
+from BiddingKG.dl.relation_extraction.model import get_words_matrix
+class Model_relation_extraction():
+    def __init__(self,lazyLoad=getLazyLoad()):
+        if USE_PAI_EAS:
+            lazyLoad = True
+        self.subject_model_file = os.path.dirname(__file__)+"/../relation_extraction/models2/subject_model"
+        self.object_model_file = os.path.dirname(__file__)+"/../relation_extraction/models2/object_model"
+        self.model_subject = None
+        self.model_object = None
+        self.sess_subject = tf.Session(graph=tf.Graph())
+        self.sess_object = tf.Session(graph=tf.Graph())
+        if not lazyLoad:
+            self.getModel1()
+            self.getModel2()
+        self.entity_type_dict = {
+            'org': '<company/org>',
+            'company': '<company/org>',
+            'location': '<location>',
+            'phone': '<phone>',
+            'person': '<contact_person>'
+        }
+        self.id2predicate = {
+            0: "rel_person",  # company -> contact person
+            1: "rel_phone",  # contact person -> phone
+            2: "rel_address"  # company -> address
+        }
+        self.words_size = 128
+
+    # subject_model
+    def getModel1(self):
+        if self.model_subject is None:
+            with self.sess_subject.as_default() as sess:
+                with sess.graph.as_default():
+                    meta_graph_def = tf.saved_model.loader.load(sess,tags=["serve"],export_dir=self.subject_model_file)
+                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+                    signature_def = meta_graph_def.signature_def
+                    input0 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input0"].name)
+                    input1 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input1"].name)
+                    output = sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
+                    self.model_subject = [[input0,input1],output]
+        return self.model_subject
+    # object_model
+    def getModel2(self):
+        if self.model_object is None:
+            with self.sess_object.as_default() as sess:
+                with sess.graph.as_default():
+                    meta_graph_def = tf.saved_model.loader.load(sess,tags=["serve"],export_dir=self.object_model_file)
+                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+                    signature_def = meta_graph_def.signature_def
+                    input0 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input0"].name)
+                    input1 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input1"].name)
+                    input2 = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input2"].name)
+                    output = sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
+                    self.model_object = [[input0,input1,input2],output]
+        return self.model_object
+
+    def encode(self,entity_list,list_sentence):
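+        # Build two parallel sequences over the article:
+        # pre_data  -- tokens, with every entity span collapsed to its type placeholder;
+        # text_data -- 0 for ordinary tokens, the entity object at placeholder positions.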
+        list_sentence = sorted(list_sentence, key=lambda x: x.sentence_index)
+        entity_list = sorted(entity_list, key=lambda x: (x.sentence_index, x.begin_index))
+        pre_data = []
+        text_data = []
+        last_sentence_index = -1
+        for key, group in groupby(entity_list, key=lambda x: x.sentence_index):
+            if key - last_sentence_index > 1:
+                for i in range(last_sentence_index + 1, key):
+                    pre_data.extend(list_sentence[i].tokens)
+                    text_data.extend([0] * len(list_sentence[i].tokens))
+            group = list(group)
+            for i in range(len(group)):
+                ent = group[i]
+                _tokens = list_sentence[key].tokens
+                if i == len(group) - 1:
+                    if i == 0:
+                        pre_data.extend(_tokens[:ent.begin_index])
+                        text_data.extend([0] * len(_tokens[:ent.begin_index]))
+                        pre_data.append(self.entity_type_dict[ent.entity_type])
+                        text_data.append(ent)
+                        pre_data.extend(_tokens[ent.end_index + 1:])
+                        text_data.extend([0] * len(_tokens[ent.end_index + 1:]))
+                        break
+                    else:
+                        pre_data.append(self.entity_type_dict[ent.entity_type])
+                        text_data.append(ent)
+                        pre_data.extend(_tokens[ent.end_index + 1:])
+                        text_data.extend([0] * len(_tokens[ent.end_index + 1:]))
+                        break
+                if i == 0:
+                    pre_data.extend(_tokens[:ent.begin_index])
+                    text_data.extend([0] * len(_tokens[:ent.begin_index]))
+                    pre_data.append(self.entity_type_dict[ent.entity_type])
+                    text_data.append(ent)
+                    pre_data.extend(_tokens[ent.end_index + 1:group[i + 1].begin_index])
+                    text_data.extend([0] * len(_tokens[ent.end_index + 1:group[i + 1].begin_index]))
+                else:
+                    pre_data.append(self.entity_type_dict[ent.entity_type])
+                    text_data.append(ent)
+                    pre_data.extend(_tokens[ent.end_index + 1:group[i + 1].begin_index])
+                    text_data.extend([0] * len(_tokens[ent.end_index + 1:group[i + 1].begin_index]))
+            last_sentence_index = key
+        return text_data, pre_data
+
+    def predict(self,text_in, words, rate=0.5):
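+        # Two-stage extraction: the subject model scores every position and those
+        # above `rate` become candidate subjects; the object model then runs once
+        # per subject to locate its objects and their relation types.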
+        # text_words = text_in
+        R = []
+        # _t2 = [self.words2id.get(c, 1) for c in words]
+        _t2 = np.zeros((len(words), self.words_size))
+        for i in range(len(words)):
+            _t2[i] = np.array(get_words_matrix(words[i]))
+        _t2 = np.array([_t2])
+        _t3 = [1 for _ in words]
+        _t3 = np.array([_t3])
+        # _k1 = self.model_subject.predict([_t2, _t3])
+        model_subject = self.getModel1()  # load lazily if __init__ skipped it
+        _k1 = limitRun(self.sess_subject,[model_subject[1]],feed_dict={model_subject[0][0]:_t2,
+                                                                       model_subject[0][1]:_t3})[0]
+        _k1 = _k1[0, :, 0]
+        _k1 = np.where(_k1 > rate)[0]
+        # print('k1',_k1)
+        _subjects = []
+        for i in _k1:
+            _subject = text_in[i]
+            _subjects.append((_subject, i, i))
+        if _subjects:
+            _t2 = np.repeat(_t2, len(_subjects), 0)
+            _t3 = np.repeat(_t3, len(_subjects), 0)
+            _k1, _ = np.array([_s[1:] for _s in _subjects]).T.reshape((2, -1, 1))
+            # _o1 = self.model_object.predict([_t2, _t3, _k1])
+            model_object = self.getModel2()  # load lazily if __init__ skipped it
+            _o1 = limitRun(self.sess_object, [model_object[1]], feed_dict={model_object[0][0]: _t2,
+                                                                           model_object[0][1]: _t3,
+                                                                           model_object[0][2]: _k1})[0]
+            for i, _subject in enumerate(_subjects):
+                _oo1 = np.where(_o1[i] > 0.5)
+                # print('_oo1', _oo1)
+                for _ooo1, _c1 in zip(*_oo1):
+                    _object = text_in[_ooo1]
+                    _predicate = self.id2predicate[_c1]
+                    R.append((_subject[0], _predicate, _object))
+            return R
+        else:
+            return []
+
     
 class Model_person_classify():
     def __init__(self,lazyLoad=getLazyLoad()):
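
Taken together, the new class is driven in two steps from getAttributes.py. A minimal usage sketch, assuming pre_entity and list_sentence are prepared as in the first hunk:

    model = Model_relation_extraction()           # loads both SavedModels unless lazyLoad
    text_data, pre_data = model.encode(pre_entity, list_sentence)
    # triples of the form (subject_entity, "rel_person"/"rel_phone"/"rel_address", object_entity)
    relation_list = model.predict(text_data, pre_data)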

+ 21 - 3
BiddingKG/dl/relation_extraction/model.py

@@ -401,14 +401,32 @@ class Relation_extraction():
                     text_data.append(ent)
                     pre_data.extend(_tokens[ent.end_index+1:group[i + 1].begin_index])
                     text_data.extend([0] * len(_tokens[ent.end_index+1:group[i+1].begin_index]))
-            last_sentencee_index = key
+            last_sentence_index = key
         return text_data,pre_data
 
 
+def save_model():
+    graph = tf.Graph()
+    with graph.as_default() as graph:
+        with tf.Session(graph=graph).as_default() as sess:
+            test_model = Relation_extraction()
+            tf.saved_model.simple_save(sess,
+                                       "models2/object_model/",
+                                       inputs={"input0": test_model.object_model.input[0],
+                                               "input1": test_model.object_model.input[1],
+                                               "input2": test_model.object_model.input[2]},
+                                       outputs={"outputs": test_model.object_model.output})
+            tf.saved_model.simple_save(sess,
+                                       "models2/subject_model/",
+                                       inputs={"input0": test_model.subject_model.input[0],
+                                               "input1": test_model.subject_model.input[1]},
+                                       outputs={"outputs": test_model.subject_model.output})
 
 
 if __name__ == '__main__':
     test_model = Relation_extraction()
+    # save_model()
+
     # ['<pad>','<unk>','<company/org>','<location>','<phone>','<contact_person>']
     # add_words = ['<unk>','<company/org>','<location>','<phone>','<contact_person>']
     # add_dict = dict()
@@ -433,6 +451,6 @@ if __name__ == '__main__':
             "||单位||:||<company/org>||,||比价||代理||机构||:||<company/org>||,||联系人||:||<contact_person>||,||电话||:||<phone>||,||江西新立建设管理有限公司九江分公司"
     # text_in = "索引"
     # words = "索引"
-    res = test_model.predict(text_in.split("||"),words.split("||"))
-    print(res)
+    # res = test_model.predict(text_in.split("||"),words.split("||"))
+    # print(res)
     # print(test_model.predict(text_in.split("||"),words.split("||")))
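
save_model() writes the two Keras sub-models out as TF1 SavedModels under models2/, the same directories that Model_relation_extraction.getModel1/getModel2 read back with tf.saved_model.loader.load. A minimal round-trip check, assuming the TF1-style API used throughout the repo:

    import tensorflow as tf

    # load the exported subject model into a fresh session and inspect its signature
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        meta = tf.saved_model.loader.load(sess, tags=["serve"],
                                          export_dir="models2/subject_model/")
    print(meta.signature_def["serving_default"].inputs)  # expect input0, input1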

Binary
BiddingKG/dl/relation_extraction/models2/object_model/saved_model.pb


Binary
BiddingKG/dl/relation_extraction/models2/object_model/variables/variables.data-00000-of-00001


Binary
BiddingKG/dl/relation_extraction/models2/object_model/variables/variables.index


Binary
BiddingKG/dl/relation_extraction/models2/subject_model/saved_model.pb


Binary
BiddingKG/dl/relation_extraction/models2/subject_model/variables/variables.data-00000-of-00001


Binary
BiddingKG/dl/relation_extraction/models2/subject_model/variables/variables.index