
Optimize the table-processing model; replace identical entity names within an announcement

rogel 4 years ago
parent
commit d8421be517

+ 1 - 1
BiddingKG/dl/common/models.py

@@ -937,7 +937,7 @@ def getBiLSTMModel(input_shape,vocab,embedding_weights,classes,use_am=False):
 
     model = models.Model(list_input,out)
 
-    model.compile(optimizer=optimizers.Adam(lr=0.0005),loss=losses.categorical_crossentropy,metrics=[precision,recall,f1_score])
+    model.compile(optimizer=optimizers.Adam(lr=0.01),loss=losses.categorical_crossentropy,metrics=[precision,recall,f1_score])
 
     model.summary()
 

+ 78 - 20
BiddingKG/dl/complaint/punish_rule.py

@@ -13,6 +13,70 @@ from keras.preprocessing.sequence import pad_sequences
 import BiddingKG.dl.interface.Preprocessing as Preprocessing
 from BiddingKG.dl.interface.Preprocessing import *
 
+def BiLSTM_CRF_tfmodel(sess,weights):
+    BiRNN_Units = 140
+    chunk_tags = {
+        'O': 0,
+        'PN_B': 1,
+        'PN_M': 2,
+        'PN_E': 3
+    }
+
+    def embedding_layer(input):
+        embedding = tf.get_variable("embedding",initializer=np.array(weights,dtype=np.float32) if weights is not None else None,dtype=tf.float32)
+        return tf.nn.embedding_lookup(params=embedding,ids=input)
+
+    def BiLSTM_Layer(input,length):
+        with tf.variable_scope("BiLSTM"):
+            forward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units//2,state_is_tuple=True)
+            backward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units//2,state_is_tuple=True)
+        output, _ = tf.nn.bidirectional_dynamic_rnn(forward_cell,backward_cell,input,dtype=tf.float32,sequence_length=length)
+        output = tf.concat(output,2)
+        return output
+
+    def CRF_layer(input,num_tags,BiRNN_Units,time_step):
+        with tf.variable_scope("CRF"):
+            with tf.variable_scope("hidden"):
+                w_hidden = tf.get_variable(name='w_hidden',shape=(BiRNN_Units,BiRNN_Units//2),dtype=tf.float32,
+                                           initializer=initializers.xavier_initializer(),regularizer=tf.contrib.layers.l2_regularizer(0.001))
+                b_hidden = tf.get_variable(name='b_hidden',shape=(BiRNN_Units//2),dtype=tf.float32,initializer=tf.zeros_initializer())
+                # print(input)
+                input_reshape = tf.reshape(input,shape=(-1,BiRNN_Units))
+                hidden = tf.tanh(tf.nn.xw_plus_b(input_reshape,w_hidden,b_hidden))
+            with tf.variable_scope("output"):
+                w_output = tf.get_variable(name='w_output',shape=(BiRNN_Units//2,num_tags),dtype=tf.float32,initializer=initializers.xavier_initializer(),regularizer=tf.contrib.layers.l2_regularizer(0.001))
+                b_output = tf.get_variable(name='b_output',shape=(num_tags),dtype=tf.float32,initializer=tf.zeros_initializer())
+                pred = tf.nn.xw_plus_b(hidden,w_output,b_output)
+                logits_ = tf.reshape(pred,shape=(-1,time_step,num_tags),name='logits')
+        return logits_
+
+    def layer_loss(input,true_target,num_tags,length):
+        with tf.variable_scope("crf_loss"):
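+            # note: the misspelled name 'transitons' is kept as-is so that
+            # saver.restore() can match the variable names stored in the trained checkpoint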
+            trans = tf.get_variable(name='transitons',shape=(num_tags,num_tags),dtype=tf.float32,initializer=initializers.xavier_initializer())
+            log_likelihood,trans = crf_log_likelihood(inputs=input,tag_indices=true_target,transition_params=trans,sequence_lengths=length)
+            return tf.reduce_mean(-log_likelihood),trans
+
+    with sess.graph.as_default():
+        char_input = tf.placeholder(name='char_input',shape=(None,None),dtype=tf.int32)
+        target = tf.placeholder(name='target',shape=(None,None),dtype=tf.int32)
+        length = tf.placeholder(name='length',shape=(None,),dtype=tf.int32)
+        # keepprob = tf.placeholder(name='keepprob',dtype=tf.float32)
+
+        _embedding = embedding_layer(char_input)
+        _shape = tf.shape(char_input)
+        batch_size = _shape[0]
+        step_size = _shape[-1]
+        bilstm = BiLSTM_Layer(_embedding,length)
+        _logits = CRF_layer(bilstm,num_tags=len(chunk_tags),BiRNN_Units=BiRNN_Units,time_step=step_size)
+        crf_loss,trans = layer_loss(_logits,true_target=target,num_tags=len(chunk_tags),length=length)
+        global_step = tf.Variable(0,trainable=False)
+        with tf.variable_scope("optimizer"):
+            opt = tf.train.AdamOptimizer(0.002)
+            grads_vars = opt.compute_gradients(crf_loss)
+            capped_grads_vars = [[tf.clip_by_value(g,-5,5),v] for g,v in grads_vars]
+            train_op = opt.apply_gradients(capped_grads_vars,global_step)
+            return char_input,_logits,target,length,crf_loss,trans,train_op
+
 def decode(logits, trans, sequence_lengths, tag_num):
     viterbi_sequences = []
     for logit, length in zip(logits, sequence_lengths):
@@ -22,8 +86,8 @@ def decode(logits, trans, sequence_lengths, tag_num):
     return viterbi_sequences
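
The hunk above elides decode's unchanged body. For orientation, a minimal sketch of the per-sentence Viterbi step such a decoder performs with TF 1.x's tf.contrib.crf (decode_sketch is our name, not the repo's):

    from tensorflow.contrib.crf import viterbi_decode

    def decode_sketch(logits, trans, sequence_lengths):
        # logits: [batch, max_len, tag_num]; trans: [tag_num, tag_num] CRF transitions
        sequences = []
        for logit, length in zip(logits, sequence_lengths):
            best_path, _ = viterbi_decode(logit[:length], trans)  # best-scoring tag path
            sequences.append(best_path)
        return sequences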
 
 class Punish_Extract():
-    def __init__(self, model_file = os.path.dirname(__file__)+"/models/complaint_code.pb"):
-        # print('model_file_path:',model_file)
+    def __init__(self, model_file = os.path.dirname(__file__)+"/models/21-0.9990081295021194-0.3647936/model.ckpt"):
+        print('model_file_path:',model_file)
         self.sess = tf.Session(graph=tf.Graph())
         self.code = ""
         self.punish_dicition = ""
@@ -34,23 +98,20 @@ class Punish_Extract():
     def load_model(self):
         with self.sess.as_default() as sess:
             with sess.graph.as_default():
-                output_graph_def = tf.GraphDef()
-                with open(self.model_file, 'rb') as f:
-                    output_graph_def.ParseFromString(f.read())
-                    tf.import_graph_def(output_graph_def, name="")
-                    sess.run(tf.global_variables_initializer())
-                    self.char_input = sess.graph.get_tensor_by_name('char_input:0')
-                    self.length = sess.graph.get_tensor_by_name('length:0')
-                    self.trans = sess.graph.get_tensor_by_name('crf_loss/transitons:0')
-                    self.logits = sess.graph.get_tensor_by_name('CRF/output/logits:0')
+                vocab_model = getModel_word()
+                vocab, w2v_matrix = getVocabAndMatrix(vocab_model, Embedding_size=60)
+                self.char_input, self.logits, self.target, self.length, self.crf_loss, self.trans, self.train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
+                sess.run(tf.global_variables_initializer())
+                saver = tf.train.Saver()
+                saver.restore(sess, self.model_file)
 
     # Predict the punishment document number
     def predict_punishCode(self,list_sentences):
         re_ner = re.compile("12+?3")
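         # per chunk_tags above (PN_B=1, PN_M=2, PN_E=3), "12+?3" matches one B,
         # one or more M, then E in the decoded tag string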
         article_ner_list = []
         count = 0
-        with self.sess.as_default() as sess:
-            with sess.graph.as_default():
+        with self.sess.as_default():
+            with self.sess.graph.as_default():
                 for sentences in list_sentences:
                     count += 1
                     # print(count)
@@ -64,7 +125,7 @@ class Punish_Extract():
                         sentences_x.append(sentence2id)
                     sentences_x = pad_sequences(sentences_x, maxlen=maxlen, padding="post", truncating="post")
                     sentences_x = [np.array(x) for x in sentences_x]
-                    _logits,_trans = self.sess.run([self.logits, self.trans],
+                    _logits, _trans = self.sess.run([self.logits, self.trans],
                                                feed_dict={self.char_input: np.array(sentences_x), self.length: sentence_len})
                     viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
 
@@ -419,8 +480,7 @@ class Punish_Extract():
         return punish_dic
 
 if __name__ == "__main__":
-    # punish = Punish_Extract(model_file='models/21-0.9990081295021194-0.3647936/model.ckpt')
-    punish = Punish_Extract()
+    punish = Punish_Extract(model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt")
 
     import pandas as pd
     # with open('G:/失信数据/ALLDATA_re2-3.xlsx') as f:
@@ -454,14 +514,12 @@ if __name__ == "__main__":
     #    'DETAILLINK', 'sentences', 'PAGE_TIME'])
     # t3 = time.time()
     # print('处理耗时:%.4f, 保存耗时:%.4f'%(t2-t1, t3-t2))
-    s = '投诉处理公告,投诉人:张三。编号:厦财企〔2020〕12号,各有关单位:341号。处罚编号:厦财企〔2020〕12号,文章编号:京财采投字(2018)第42号。公告编号:闽建筑招〔2018〕5号。处罚编号:松公管监[2020]2号,'
+    s = '编号:厦财企〔2020〕12号,各有关单位:341号。处罚编号:厦财企〔2020〕12号,文章编号:京财采投字(2018)第42号。公告编号:闽建筑招〔2018〕5号。处罚编号:松公管监[2020]2号,'
     # list_sentences = [s.split('。')]
     # punish_code= punish.predict_punishCode( list_sentences)
     # print(punish_code)
 
     # punish_code, punishType, punishDecision, complainants, punishPeople, punishWhether, institutions, punishTimes = \
     #             get_punish_extracts(text=s)
-    list_articles, list_sentences, list_entitys, _ = Preprocessing.get_preprocessed([['', s, "", "", ""]],
-                                                                                    useselffool=True)
-    punish_dic = punish.get_punish_extracts(list_sentences, list_entitys,text=s)
+    punish_dic = punish.get_punish_extracts_backup(text=s)
     print(punish_dic)

+ 10 - 2
BiddingKG/dl/entityLink/entityLink.py

@@ -4,6 +4,7 @@ Created on 2019年5月21日
 @author: User
 '''
 
+import re
 
 def edit_distance(source,target):
     dp = [["" for i in range(len(source)+1)] for j in range(len(target)+1)]
@@ -36,13 +37,20 @@ def link_entitys(list_entitys,on_value=0.8):
                 range_entity.append(_entity)
         range_entity = range_entity[:1000]
         for first_i in range(len(range_entity)):
-            _entity = list_entity[first_i]
+            _entity = range_entity[first_i]
             for second_i in range(first_i+1,len(range_entity)):
-                _ent = list_entity[second_i]
+                _ent = range_entity[second_i]
                 _score = jaccard_score(_entity.entity_text, _ent.entity_text)
                 if _entity.entity_text!=_ent.entity_text and _score>=on_value:
                     _entity.linked_entitys.append(_ent)
                     _ent.linked_entitys.append(_entity)
+        # Replace company names with the full linked name
+        for _entity in range_entity:
+            if re.search("公司",_entity.entity_text) is None:
+                for _ent in _entity.linked_entitys:
+                    if re.search("公司$",_ent.entity_text) is not None:
+                        if len(_ent.entity_text)>len(_entity.entity_text):
+                            _entity.entity_text = _ent.entity_text
 
                                 
 if __name__=="__main__":
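
The replacement loop above promotes the longest linked name ending in 公司 onto shorter aliases that lack it. A minimal standalone sketch of the same rule, with plain strings in place of the project's entity objects (canonicalize is a hypothetical helper, not part of the repo):

    import re

    def canonicalize(name, linked_names):
        # only aliases without 公司 are replaced, and only by a longer
        # linked name that ends in 公司, mirroring link_entitys above
        best = name
        if re.search("公司", best) is None:
            for cand in linked_names:
                if re.search("公司$", cand) is not None and len(cand) > len(best):
                    best = cand
        return best

    print(canonicalize("哈尔滨龙网", ["哈尔滨龙网电力设备有限公司"]))
    # -> 哈尔滨龙网电力设备有限公司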

BIN
BiddingKG/dl/form/form_context_savedmodel/saved_model.pb


BIN
BiddingKG/dl/form/form_context_savedmodel/variables/variables.data-00000-of-00001


BIN
BiddingKG/dl/form/form_context_savedmodel/variables/variables.index


+ 42 - 13
BiddingKG/dl/form/generateData.py

@@ -625,17 +625,29 @@ def getTrainData(percent=0.9):
 
 
 
-def getTrainData_jsonTable():
-    def padding_table(table):
+def getTrainData_jsonTable(begin,end,return_text=False):
+    def encode_table(inner_table,size=30):
         def encode_item(_table,i,j):
             _x = [_table[j-1][i-1],_table[j-1][i],_table[j-1][i+1],
                   _table[j][i-1],_table[j][i],_table[j][i+1],
                   _table[j+1][i-1],_table[j+1][i],_table[j+1][i+1]]
             e_x = [encodeInput_form(_temp[0],MAX_LEN=30) for _temp in _x]
             _label = _table[j][i][1]
-            return e_x,_label
+            # print(_x)
+            # print(_x[4],_label)
+            return e_x,_label,_x
 
-        padding = ["###",0]
+        def copytable(inner_table):
+            table = []
+            for line in inner_table:
+                list_line = []
+                for item in line:
+                    list_line.append([item[0][:size],item[1]])
+                table.append(list_line)
+            return table
+
+        table = copytable(inner_table)
+        padding = ["#"*30,0]
         width = len(table[0])
         height = len(table)
         table.insert(0,[padding for i in range(width)])
@@ -645,36 +657,53 @@ def getTrainData_jsonTable():
             item.append(padding.copy())
         data_x = []
         data_y = []
+        data_text = []
+        data_position = []
         for _i in range(1,width+1):
             for _j in range(1,height+1):
-                _x,_y = encode_item(table,_i,_j)
+                _x,_y,_text = encode_item(table,_i,_j)
                 data_x.append(_x)
                 _label = [0,0]
                 _label[_y] = 1
                 data_y.append(_label)
-        return data_x,data_y
+                data_text.append(_text)
+                data_position.append([_i-1,_j-1])
+                # input = table[_j][_i][0]
+                # item_y = [0,0]
+                # item_y[table[_j][_i][1]] = 1
+                # data_x.append(encodeInput([input], word_len=50, word_flag=True,userFool=False)[0])
+                # data_y.append(item_y)
+        return data_x,data_y,data_text,data_position
 
-    def getDataSet(list_json_table):
+    def getDataSet(list_json_table,return_text=False):
         _count = 0
         _sum = len(list_json_table)
         data_x = []
         data_y = []
+        data_text = []
         for json_table in list_json_table:
             _count += 1
             print("%d/%d"%(_count,_sum))
             table = json.loads(json_table)
             if table is not None:
-                list_x,list_y = padding_table(table)
+                list_x,list_y,list_text,_ = encode_table(table)  # encode_table also returns positions, unused here
                 data_x.extend(list_x)
                 data_y.extend(list_y)
-        return np.array(data_x),np.array(data_y)
+                if return_text:
+                    data_text.extend(list_text)
+        return np.array(data_x),np.array(data_y),data_text
 
+    save_path = "./traindata/websource_67000_table_%d-%d-%s.pk"%(begin,end,"1" if return_text else "0")
+    if os.path.exists(save_path):
+        data_x,data_y,data_text = load(save_path)
+    else:
+        df = pd.read_csv("./traindata/websource_67000_table.csv",encoding="GBK")
+        import json
+        data_x,data_y,data_text = getDataSet(df["json_table"][begin:end],return_text=return_text)
 
-    df = pd.read_csv("websource_67000_table.csv",encoding="GBK")
-    import json
-    train_x,train_y = getDataSet(df["json_table"][:-1000])
-    test_x,test_y = getDataSet(df["json_table"][-1000:])
+        save((data_x,data_y,data_text),save_path)
 
+    return data_x,data_y,data_text
 
 
 if __name__=="__main__":
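
encode_table above wraps the truncated table in one ring of "#"*30 padding cells so that every real cell has a full 3x3 neighbourhood for encode_item. A toy illustration of just the padding step, with plain strings standing in for the [text, label] cells:

    def pad_ring(table, pad_cell):
        # one padding row above and below, then one padding cell at both ends of every row
        width = len(table[0])
        padded = [[pad_cell] * width] + [list(row) for row in table] + [[pad_cell] * width]
        for row in padded:
            row.insert(0, pad_cell)
            row.append(pad_cell)
        return padded

    t = pad_ring([["a", "b"], ["c", "d"]], "#")
    # the 2x2 table becomes 4x4, and t[1][1] == "a" now has 8 neighbours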

BIN
BiddingKG/dl/form/log/ep001-loss0.203-val_loss0.112-f10.9525.h5


BIN
BiddingKG/dl/complaint/models/complaint_code.pb → BiddingKG/dl/form/log/ep001-loss0.555-val_loss0.539-f1nan.h5


BIN
BiddingKG/dl/form/log/ep002-loss0.102-val_loss0.088-f10.9640.h5


BIN
BiddingKG/dl/form/log/ep003-loss0.552-val_loss0.539-f1nan.h5


+ 70 - 12
BiddingKG/dl/form/model.py

@@ -13,6 +13,7 @@ from keras import layers
 from keras import models
 from keras import losses,optimizers
 import tensorflow as tf
+from BiddingKG.dl.common.models import transformer_model
 
 def getBiLSTMModel():
     input = layers.Input(shape=(100,))
@@ -55,29 +56,86 @@ def getTextCNNModel():
     model.summary()
     return model
 
-def get_context_form_model(vocab_len,char_dim,lstm_dim,context_dim):
-    input = tf.placeholder(shape=[None,9,30])
+def get_context_form_model(vocab_len,char_dim,lstm_dim,context_dim,res_dim,residual_stacks):
+    input = tf.placeholder(shape=[None,9,30],dtype=tf.int32)
+    label = tf.placeholder(shape=[None,2],dtype=tf.int32)
 
+    # input_center = tf.slice(input,[0,4,0],[-1,1,-1])
     with tf.variable_scope("embedding"):
         embedding = tf.get_variable("char_embedding",shape=[vocab_len,char_dim])
         embedding_input = tf.nn.embedding_lookup(embedding,input)
 
-    with tf.variable_scope("bi_rnn"):
-        resize_input = tf.reshape(embedding_input,[-1,30,char_dim])
 
-        forward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
-        backward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim,state_is_tuple=True)
-        outputs,forward_output,backward_output = tf.nn.static_bidirectional_rnn(forward_cell,backward_cell,resize_input)
-        bi_output = tf.concat([forward_output,backward_output],-1) #[-1,lstm_dim*2]
+    with tf.variable_scope("center"):
+        center_embedding = tf.slice(embedding_input,[0,4,0,0],[-1,1,-1,-1])
+        with tf.variable_scope("bert"):
+            bert_input = tf.reshape(center_embedding,[-1,30,char_dim])
+            bert_output = transformer_model(input_tensor=bert_input,name="bert")
 
+            forward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
+            backward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim,state_is_tuple=True)
+            outputs, final_out = tf.nn.bidirectional_dynamic_rnn(forward_cell,backward_cell,bert_output, dtype=tf.float32)
+            bi_output = tf.concat(final_out,-1) #[-1,lstm_dim*2]
+            bi_output = tf.reshape(tf.slice(bi_output,[1,0,0],[1,-1,-1]),[-1,lstm_dim*2])
+
+            center_output = tf.nn.relu(tf.matmul(bi_output,tf.get_variable("center_v",shape=[lstm_dim*2,res_dim])))
+            center_output = tf.nn.dropout(center_output,1)
 
     with tf.variable_scope("context"):
-        context_out = tf.reshape(bi_output,[-1,9,lstm_dim*2])
-        context_v = tf.get_variable("context_v",shape=[lstm_dim*2,context_dim])
+        with tf.variable_scope("bi_rnn"):
+            resize_input = tf.reshape(embedding_input,[-1,30,char_dim])
+            forward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
+            backward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim,state_is_tuple=True)
+            outputs, final_out = tf.nn.bidirectional_dynamic_rnn(forward_cell,backward_cell,resize_input, dtype=tf.float32)
+            bi_output = tf.concat(final_out,-1) #[-1,lstm_dim*2]
+            bi_output = tf.slice(bi_output,[1,0,0],[1,-1,-1])
+
+            context_input = tf.reshape(bi_output,[-1,lstm_dim*2])
+            context_v = tf.get_variable("context_v",shape=[lstm_dim*2,context_dim])
+
+            context_emb = tf.nn.relu(tf.matmul(context_input,context_v))
+
+            context_output = tf.reshape(context_emb,[-1,9*context_dim])
+
+            context_output = tf.nn.relu(tf.matmul(context_output,tf.get_variable("context_output_v",shape=[9*context_dim,res_dim])))
+
+            context_output = tf.nn.dropout(context_output,1)
+
+
+    with tf.variable_scope("residual"):
+        input_x = tf.concat([context_output,center_output],axis=-1)
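+        # each stack applies a plain residual update: x <- relu(x*W + b + x)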
+        for stack in range(residual_stacks):
+            stack_w = tf.get_variable("stack_w_%d"%stack,shape=[res_dim*2,res_dim*2])
+            stack_b = tf.get_variable("stack_b_%d"%stack,shape=[1,res_dim*2])
+            stack_out = tf.matmul(input_x,stack_w)+stack_b
+            input_x = tf.nn.relu(stack_out+input_x)
+
+
+
+    with tf.variable_scope("softmax"):
+        softmax_input = tf.identity(input_x)
+        softmax_w = tf.get_variable("softmax_w",shape=[res_dim*2,2])
+        softmax_output = tf.nn.softmax(tf.matmul(softmax_input,softmax_w))
+
+
+
+    # with tf.variable_scope("softmax"):
+    #     softmax_input = tf.reshape(bi_output,[-1,lstm_dim*2])
+    #     softmax_w = tf.get_variable("softmax_w",shape=[lstm_dim*2,2])
+    #     softmax_output = tf.nn.softmax(tf.matmul(softmax_input,softmax_w))
+
+    _p = precision(tf.cast(label,tf.float32),softmax_output)
+    _r = recall(tf.cast(label,tf.float32),softmax_output)
+    loss = -tf.reduce_mean(tf.cast(label,tf.float32)*tf.log(softmax_output))
+
+    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+    global_step = tf.Variable(0, trainable=False)
+    grads_vars = optimizer.compute_gradients(loss)
+    capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
+    train = optimizer.apply_gradients(capped_grads_vars, global_step)
 
-        context_emb = tf.nn.relu(tf.matmul(context_out,context_v))
+    return {"input":input,"output":softmax_output,"label":label,"train":train,"embedding":embedding,"loss":loss,"precision":_p,"recall":_r}
 
-        center_out = tf.slice(context_out,[0,4,0],[-1,1,-1])
 
 
 

+ 103 - 5
BiddingKG/dl/form/train.py

@@ -6,7 +6,8 @@ Created on 2019年3月26日
 
 import sys
 import os
-sys.path.append(os.path.abspath("../.."))
+os.environ['KERAS_BACKEND']='tensorflow'
+sys.path.append(os.path.abspath("../../.."))
 from BiddingKG.dl.common.models import *
 from keras.callbacks import ModelCheckpoint
 import numpy as np
@@ -16,6 +17,7 @@ import tensorflow as tf
 from generateData import *
 
 
+
 def train():
     model = getTextCNNModel()
     train_x, train_y = getData("train.xls")
@@ -28,12 +30,13 @@ def train():
 
 
 def train1():
-    data_pk = "all_data.pk"
+    data_pk = "./traindata/all_data.pk"
     if os.path.exists(data_pk):
         train_x, train_y, test_x, test_y, test_text = load(data_pk)
     else:
-        train_x, train_y, test_x, test_y, test_text = getTrainData()
-        save((train_x, train_y, test_x, test_y, test_text), data_pk)
+        # train_x, train_y, test_x, test_y, test_text = getTrainData()
+        # save((train_x, train_y, test_x, test_y, test_text), data_pk)
+        train_x, train_y, _ = getTrainData_jsonTable(begin=0, end=20000)
+        test_x, test_y, test_text = getTrainData_jsonTable(begin=-2000, end=-1, return_text=True)
 
     with tf.Session(graph=tf.Graph()) as sess:
         with sess.graph.as_default():
@@ -162,10 +165,105 @@ def save_form_model():
                                        inputs={"inputs":model.input},
                                        outputs = {"outputs":model.output})
 
+from BiddingKG.dl.form.generateData import getTrainData_jsonTable
+from BiddingKG.dl.form.model import get_context_form_model
+def train_context():
+
+
+    vocab,vocab_matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
+
+    sess = tf.Session(graph=tf.Graph())
+    epochs = 100
+    batch_size = 1000
+    with sess.graph.as_default():
+        node_dict = get_context_form_model(vocab_len=len(vocab_matrix),char_dim=60,lstm_dim=24,context_dim=36,res_dim=16,residual_stacks=5)
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.assign(node_dict["embedding"],np.array(vocab_matrix)))
+
+
+        train_x,train_y,_ = getTrainData_jsonTable(begin=0,end=20000)
+        test_x,test_y,text_text = getTrainData_jsonTable(begin=-2000,end=-1,return_text=True)
+        saver = tf.train.Saver(max_to_keep=epochs)
+        for _e in range(epochs):
+            batch_begin = 0
+            avg_loss = 0
+            avg_precision = 0
+            _count = 0
+            while(batch_begin<len(train_x)):
+                feed_dict = {node_dict["input"]:train_x[batch_begin:batch_begin+batch_size],
+                             node_dict["label"]:train_y[batch_begin:batch_begin+batch_size]}
+
+                _,_loss,_precision,_recall = sess.run([node_dict["train"],node_dict["loss"],node_dict["precision"],node_dict["recall"]],feed_dict=feed_dict)
+                _count += 1
+                avg_loss += _loss
+                avg_precision  += _precision
+                if _count %10==0:
+                    print("training %d epoch %d/%d loss:%5f precision:%5f"%(_e,batch_begin,len(train_x),avg_loss/_count,avg_precision/_count))
+                batch_begin += batch_size
+            train_loss = avg_loss/_count
+            train_precision = avg_precision/_count
+            batch_begin = 0
+            avg_loss = 0
+            avg_precision = 0
+            _count = 0
+            while(batch_begin<len(test_x)):
+                feed_dict = {node_dict["input"]:test_x[batch_begin:batch_begin+batch_size],
+                             node_dict["label"]:test_y[batch_begin:batch_begin+batch_size]}
+
+                _loss,_precision,_recall = sess.run([node_dict["loss"],node_dict["precision"],node_dict["recall"]],feed_dict=feed_dict)
+                _count += 1
+                avg_loss += _loss
+                avg_precision  += _precision
+                if _count %10==0:
+                    print("testing %d epoch %d/%d loss:%5f precision:%5f"%(_e,batch_begin,len(test_x),avg_loss/_count,avg_precision/_count))
+                batch_begin += batch_size
+            saver.save(sess,"./traindata/log/epoch%d-train_loss%5f-train_precision%5f-val_loss%5f-val_precision%5f.ckpt"%(_e,train_loss,train_precision,avg_loss/_count,avg_precision/_count))
+
+def predict_context():
+    import os
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    model_file = "./traindata/log/epoch18-train_loss0.075330-train_precision0.947963-val_loss0.088985-val_precision0.947385.ckpt"
+
+    sess = tf.Session(graph=tf.Graph())
+    with sess.graph.as_default():
+        vocab,vocab_matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
+        node_dict = get_context_form_model(vocab_len=len(vocab_matrix),char_dim=60,lstm_dim=24,context_dim=36,res_dim=16,residual_stacks=5)
+        saver = tf.train.Saver()
+        saver.restore(sess,model_file)
+
+        # tf.saved_model.simple_save(session=sess,
+        #                            export_dir="./form_context_savedmodel",
+        #                            inputs={"inputs":node_dict["input"]},
+        #                            outputs={"outputs":node_dict["output"]})
+        # return
+
+        test_x,test_y,test_text = getTrainData_jsonTable(begin=-2000,end=-1,return_text=True)
+
+        batch_begin = 0
+        batch_size = 2000
+        while batch_begin<len(test_x):
+            feed_dict = {node_dict["input"]:test_x[batch_begin:batch_begin+batch_size]}
+            batch_out = sess.run([node_dict["output"]],feed_dict=feed_dict)
+            batch_label = test_y[batch_begin:batch_begin+batch_size]
+            batch_text = test_text[batch_begin:batch_begin+batch_size]
+            for _out,_label,_text in zip(batch_out[0],batch_label,batch_text):
+                if np.argmax(_out)!=np.argmax(_label):
+                    print("======")
+                    print(_text)
+                    print(_text[4])
+                    print(_out)
+                    print(_label)
+                    print("============")
+            batch_begin += batch_size
+
+
 if __name__ == "__main__":
     # train()
     # print(test(["序号|项目名称|中选人"]))
     # getBestThreshold()
     # train1()
     # vali()
-    save_form_model()
+    # save_form_model()
+    # train_context()
+    predict_context()

+ 231 - 192
BiddingKG/dl/interface/Preprocessing.py

@@ -155,201 +155,175 @@ def tableToText(soup):
                 if inner_table[i][j][0]==pad_col:
                     inner_table[i][j][0] = inner_table[i-1][j][0]
                     inner_table[i][j][1] = inner_table[i-1][j][1]
-                    
-    def setHead_initem(inner_table,pat_head,fix_value="~~",prob_min=0.5):
-        def repairTable(inner_table,dye_set = set(),key_set = set()):
+
+    def repairTable(inner_table,dye_set = set(),key_set = set(),fix_value="~~"):
+        '''
+        @summary: repair header recognition, correcting obvious mistakes
+        '''
+        def repairNeeded(line):
+            first_1 = -1
+            last_1 = -1
+            first_0 = -1
+            last_0 = -1
+            count_1 = 0
+            count_0 = 0
+            for i in range(len(line)):
+                if line[i][0]==fix_value:
+                    continue
+                if line[i][1]==1:
+                    if first_1==-1:
+                        first_1 = i
+                    last_1 = i
+                    count_1 += 1
+                if line[i][1]==0:
+                    if first_0 == -1:
+                        first_0 = i
+                    last_0 = i
+                    count_0 += 1
+            if first_1 ==-1 or last_0 == -1:
+                return False
+            # Abnormal cases: the first cell is not a header; the last cell is a header; headers far outnumber value cells
+            if first_1-0>0 or last_0-len(line)+1<0 or last_1==len(line)-1 or count_1-count_0>=3:
+                return True
+            return False
+
+        def getsimilarity(line,line1):
+            same_count = 0
+            for item,item1 in zip(line,line1):
+                if item[1]==item1[1]:
+                    same_count += 1
+            return same_count/len(line)
+
+        def selfrepair(inner_table,index,dye_set,key_set):
             '''
-            @summary: repair header recognition, correcting obvious mistakes
+            @summary: compute the squeeze pressure on each node to decide whether it needs re-dyeing
             '''
-            def repairNeeded(line):
-                first_1 = -1
-                last_1 = -1
-                first_0 = -1
-                last_0 = -1
-                count_1 = 0
-                count_0 = 0
-                for i in range(len(line)):
-                    if line[i][0]==fix_value:
-                        continue
-                    if line[i][1]==1:
-                        if first_1==-1:
-                            first_1 = i
-                        last_1 = i
-                        count_1 += 1
-                    if line[i][1]==0:
-                        if first_0 == -1:
-                            first_0 = i
-                        last_0 = i
-                        count_0 += 1
-                if first_1 ==-1 or last_0 == -1:
-                    return False
-                # Abnormal cases: the first cell is not a header; the last cell is a header; headers far outnumber value cells
-                if first_1-0>0 or last_0-len(line)+1<0 or last_1==len(line)-1 or count_1-count_0>=3:
-                    return True
-                return False
-            
-            def getsimilarity(line,line1):
-                same_count = 0
-                for item,item1 in zip(line,line1):
-                    if item[1]==item1[1]:
-                        same_count += 1
-                return same_count/len(line)
-            
-            def selfrepair(inner_table,index,dye_set,key_set):
-                '''
-                @summary: compute the squeeze pressure on each node to decide whether it needs re-dyeing
-                '''
-                #print("B",inner_table[index])
-                min_presure = 3
-                list_dye = []
-                first = None
-                count = 0
-                temp_set = set()
-                _index = 0
-                for item in inner_table[index]:
-                    if first is None:
-                        first = item[1]
+            #print("B",inner_table[index])
+            min_presure = 3
+            list_dye = []
+            first = None
+            count = 0
+            temp_set = set()
+            _index = 0
+            for item in inner_table[index]:
+                if first is None:
+                    first = item[1]
+                    if item[0] not in temp_set:
+                        count += 1
+                        temp_set.add(item[0])
+                else:
+                    if first == item[1]:
                         if item[0] not in temp_set:
-                            count += 1
                             temp_set.add(item[0])
+                            count += 1
                     else:
-                        if first == item[1]:
-                            if item[0] not in temp_set:
-                                temp_set.add(item[0])
-                                count += 1
-                        else:
-                            list_dye.append([first,count,_index])
-                            first = item[1]
-                            temp_set.add(item[0])
-                            count = 1
-                    _index += 1
-                list_dye.append([first,count,_index])
-                if len(list_dye)>1:
-                    begin = 0
-                    end = 0
-                    for i in range(len(list_dye)):
-                        end = list_dye[i][2]
-                        dye_flag = False
-                        # at the two ends the pressure threshold is lowered by one
-                        if i==0:
-                            if list_dye[i+1][1]-list_dye[i][1]+1>=min_presure-1:
+                        list_dye.append([first,count,_index])
+                        first = item[1]
+                        temp_set.add(item[0])
+                        count = 1
+                _index += 1
+            list_dye.append([first,count,_index])
+            if len(list_dye)>1:
+                begin = 0
+                end = 0
+                for i in range(len(list_dye)):
+                    end = list_dye[i][2]
+                    dye_flag = False
+                    # at the two ends the pressure threshold is lowered by one
+                    if i==0:
+                        if list_dye[i+1][1]-list_dye[i][1]+1>=min_presure-1:
+                            dye_flag = True
+                            dye_type = list_dye[i+1][0]
+                    elif i==len(list_dye)-1:
+                        if list_dye[i-1][1]-list_dye[i][1]+1>=min_presure-1:
+                            dye_flag = True
+                            dye_type = list_dye[i-1][0]
+                    else:
+                        if list_dye[i][1]>1:
+                            if list_dye[i+1][1]-list_dye[i][1]+1>=min_presure:
                                 dye_flag = True
                                 dye_type = list_dye[i+1][0]
-                        elif i==len(list_dye)-1:
-                            if list_dye[i-1][1]-list_dye[i][1]+1>=min_presure-1:
+                            if list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
                                 dye_flag = True
                                 dye_type = list_dye[i-1][0]
                         else:
-                            if list_dye[i][1]>1:
-                                if list_dye[i+1][1]-list_dye[i][1]+1>=min_presure:
-                                    dye_flag = True
-                                    dye_type = list_dye[i+1][0]
-                                if list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
-                                    dye_flag = True
-                                    dye_type = list_dye[i-1][0]
-                            else:
-                                if list_dye[i+1][1]+list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
-                                    dye_flag = True
-                                    dye_type = list_dye[i+1][0]
-                                if list_dye[i+1][1]+list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
-                                    dye_flag = True
-                                    dye_type = list_dye[i-1][0]
-                        if dye_flag:
-                            for h in range(begin,end):
-                                inner_table[index][h][1] = dye_type
-                                dye_set.add((inner_table[index][h][0],dye_type))
-                                key_set.add(inner_table[index][h][0])
-                        begin = end
-                    #print("E",inner_table[index])
-                        
-                            
-                                
-                
-            def otherrepair(inner_table,index,dye_set,key_set):
-                list_provide_repair = []
-                if index==0 and len(inner_table)>1:
-                    list_provide_repair.append(index+1)
-                elif index==len(inner_table)-1:
-                    list_provide_repair.append(index-1)
-                else:
-                    list_provide_repair.append(index+1)
-                    list_provide_repair.append(index-1)
-                for provide_index in list_provide_repair:
-                    if not repairNeeded(inner_table[provide_index]):
-                        same_prob = getsimilarity(inner_table[index], inner_table[provide_index])
-                        if same_prob>=0.8:
-                            for i in range(len(inner_table[provide_index])):
-                                if inner_table[index][i][1]!=inner_table[provide_index][i][1]:
-                                    dye_set.add((inner_table[index][i][0],inner_table[provide_index][i][1]))
-                                    key_set.add(inner_table[index][i][0])
-                                inner_table[index][i][1] = inner_table[provide_index][i][1]
-                        elif same_prob<=0.2:
-                            for i in range(len(inner_table[provide_index])):
-                                if inner_table[index][i][1]==inner_table[provide_index][i][1]:
-                                    dye_set.add((inner_table[index][i][0],inner_table[provide_index][i][1]))
-                                    key_set.add(inner_table[index][i][0])
-                                inner_table[index][i][1] = 0 if inner_table[provide_index][i][1] ==1 else 1
-            len_dye_set = len(dye_set)
-            height = len(inner_table)
+                            if list_dye[i+1][1]+list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
+                                dye_flag = True
+                                dye_type = list_dye[i+1][0]
+                            if list_dye[i+1][1]+list_dye[i-1][1]-list_dye[i][1]+1>=min_presure:
+                                dye_flag = True
+                                dye_type = list_dye[i-1][0]
+                    if dye_flag:
+                        for h in range(begin,end):
+                            inner_table[index][h][1] = dye_type
+                            dye_set.add((inner_table[index][h][0],dye_type))
+                            key_set.add(inner_table[index][h][0])
+                    begin = end
+                #print("E",inner_table[index])
+
+
+
+
+        def otherrepair(inner_table,index,dye_set,key_set):
+            list_provide_repair = []
+            if index==0 and len(inner_table)>1:
+                list_provide_repair.append(index+1)
+            elif index==len(inner_table)-1:
+                list_provide_repair.append(index-1)
+            else:
+                list_provide_repair.append(index+1)
+                list_provide_repair.append(index-1)
+            for provide_index in list_provide_repair:
+                if not repairNeeded(inner_table[provide_index]):
+                    same_prob = getsimilarity(inner_table[index], inner_table[provide_index])
+                    if same_prob>=0.8:
+                        for i in range(len(inner_table[provide_index])):
+                            if inner_table[index][i][1]!=inner_table[provide_index][i][1]:
+                                dye_set.add((inner_table[index][i][0],inner_table[provide_index][i][1]))
+                                key_set.add(inner_table[index][i][0])
+                            inner_table[index][i][1] = inner_table[provide_index][i][1]
+                    elif same_prob<=0.2:
+                        for i in range(len(inner_table[provide_index])):
+                            if inner_table[index][i][1]==inner_table[provide_index][i][1]:
+                                dye_set.add((inner_table[index][i][0],inner_table[provide_index][i][1]))
+                                key_set.add(inner_table[index][i][0])
+                            inner_table[index][i][1] = 0 if inner_table[provide_index][i][1] ==1 else 1
+        len_dye_set = len(dye_set)
+        height = len(inner_table)
+        for i in range(height):
+            if repairNeeded(inner_table[i]):
+                selfrepair(inner_table,i,dye_set,key_set)
+                #otherrepair(inner_table,i,dye_set,key_set)
+        for h in range(len(inner_table)):
+            for w in range(len(inner_table[0])):
+                if inner_table[h][w][0] in key_set:
+                    for item in dye_set:
+                        if inner_table[h][w][0]==item[0]:
+                            inner_table[h][w][1] = item[1]
+        # If the two sets differ in length, some key was dyed repeatedly, which would cause infinite iteration
+        if len(dye_set)!=len(key_set):
             for i in range(height):
                 if repairNeeded(inner_table[i]):
                     selfrepair(inner_table,i,dye_set,key_set)
                     #otherrepair(inner_table,i,dye_set,key_set)
-            for h in range(len(inner_table)):
-                for w in range(len(inner_table[0])):
-                    if inner_table[h][w][0] in key_set:
-                        for item in dye_set:
-                            if inner_table[h][w][0]==item[0]:
-                                inner_table[h][w][1] = item[1]
-            # If the two sets differ in length, some key was dyed repeatedly, which would cause infinite iteration
-            if len(dye_set)!=len(key_set):
-                for i in range(height):
-                    if repairNeeded(inner_table[i]):
-                        selfrepair(inner_table,i,dye_set,key_set)
-                        #otherrepair(inner_table,i,dye_set,key_set)
-                return
-            if len(dye_set)==len_dye_set:
-                '''
-                for i in range(height):
-                    if repairNeeded(inner_table[i]):
-                        otherrepair(inner_table,i,dye_set,key_set)
-                '''
-                return
-            repairTable(inner_table, dye_set, key_set)
-            
-                    
-                        
-                
-                    
-        set_item = set()
+            return
+        if len(dye_set)==len_dye_set:
+            '''
+            for i in range(height):
+                if repairNeeded(inner_table[i]):
+                    otherrepair(inner_table,i,dye_set,key_set)
+            '''
+            return
+        repairTable(inner_table, dye_set, key_set)
+
+    def sliceTable(inner_table,fix_value="~~"):
+        # split the table into blocks
         height = len(inner_table)
         width = len(inner_table[0])
-        for i in range(height):
-            for j in range(width):
-                item = inner_table[i][j][0]
-                set_item.add(item)
-        list_item = list(set_item)
-        x = []
-        for item in list_item:
-            x.append(getPredictor("form").encode(item))
-        predict_y = getPredictor("form").predict(np.array(x),type="item")
-        _dict = dict()
-        
-        for item,values in zip(list_item,list(predict_y)):
-            _dict[item] = values[1]
-            # print("##",item,values)
-        #print(_dict)
-        for i in range(height):
-            for j in range(width):
-                item = inner_table[i][j][0]
-                inner_table[i][j][1] = 1 if _dict[item]>prob_min else (1 if re.search(pat_head,item) is not None and len(item)<8 else 0)
-        
-        repairTable(inner_table)
-        
-        # split the table into blocks
         head_list = []
         head_list.append(0)
         last_head = None
+        last_is_same_value = False
         for h in range(height):
             is_all_key = True  # whether the whole row is headers
             is_all_value = True  # whether the whole row is attribute values
@@ -361,13 +335,13 @@ def tableToText(soup):
                 if last_head is not None:
                     if inner_table[h-1][w][0]!=fix_value and inner_table[h-1][w][1] == 0:
                         is_all_key = False
-                        
+
                     if inner_table[h][w][0]==1:
                         is_all_value = False
-                    
+
                     if inner_table[h][w][1]!= inner_table[h-1][w][1]:
                         is_same_with_lastHead = False
-                    
+
                 if inner_table[h][w][0]!=fix_value and inner_table[h][w][0]!=same_value:
                     is_same_value = False
                 else:
@@ -375,20 +349,85 @@ def tableToText(soup):
                         is_same_value = False
             if h>0 and inner_table[h][0][0]!=inner_table[h-1][0][0]:
                 is_same_first_item = False
-                
+
             last_head = h
-                
+
+            if last_is_same_value:
+                last_is_same_value = is_same_value
+                continue
+
             if is_same_value:
                 head_list.append(h)
-                head_list.append(h+1)
+                last_is_same_value = is_same_value
                 continue
             if not is_all_key:
                 if not is_same_with_lastHead:
                     head_list.append(h)
-                    
-                    
+
+
         head_list.append(height)
+        return head_list
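
head_list holds each block's starting row plus the table height, so consecutive pairs delimit blocks, matching the commented-out consumer near the end of tableToText. A two-line illustration:

    head_list = [0, 3, 6]   # a 6-row table split at rows 0 and 3
    blocks = [inner_table[b:e] for b, e in zip(head_list, head_list[1:])]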
+                    
+    def setHead_initem(inner_table,pat_head,fix_value="~~",prob_min=0.5):
+
+
+        set_item = set()
+        height = len(inner_table)
+        width = len(inner_table[0])
+        for i in range(height):
+            for j in range(width):
+                item = inner_table[i][j][0]
+                set_item.add(item)
+        list_item = list(set_item)
+        x = []
+        for item in list_item:
+            x.append(getPredictor("form").encode(item))
+        predict_y = getPredictor("form").predict(np.array(x),type="item")
+        _dict = dict()
+        
+        for item,values in zip(list_item,list(predict_y)):
+            _dict[item] = values[1]
+            # print("##",item,values)
+        #print(_dict)
+        for i in range(height):
+            for j in range(width):
+                item = inner_table[i][j][0]
+                inner_table[i][j][1] = 1 if _dict[item]>prob_min else (1 if re.search(pat_head,item) is not None and len(item)<8 else 0)
+
+        # print("=====")
+        # for item in inner_table:
+        #     print(item)
+        # print("======")
+        
+        repairTable(inner_table)
+        head_list = sliceTable(inner_table)
+
         
+        return inner_table,head_list
+
+    def setHead_incontext(inner_table,pat_head,fix_value="~~",prob_min=0.5):
+
+        data_x,data_position = getPredictor("form").getModel("context").encode(inner_table)
+        predict_y = getPredictor("form").getModel("context").predict(data_x)
+
+        for _position,_y in zip(data_position,predict_y):
+            _w = _position[0]
+            _h = _position[1]
+            if _y[1]>prob_min:
+                inner_table[_h][_w][1] = 1
+            else:
+                inner_table[_h][_w][1] = 0
+
+
+        # print("=====")
+        # for item in inner_table:
+        #     print(item)
+        # print("======")
+
+        repairTable(inner_table)
+        head_list = sliceTable(inner_table)
+
+
         return inner_table,head_list
                     
     # set the table headers
@@ -588,7 +627,7 @@ def tableToText(soup):
     def getTableText(inner_table,head_list,key_direct=False):
         # packPattern = "(标包|[标包][号段名])"
         packPattern = "(标包|[标包][号段名]|((项目|物资|设备|场次|标段|标的|产品)(名称)))"  # 2020/11/23 大网站规则,补充采购类包名
-        rankPattern = "(排名|排序|名次|序号|评标结果|评审结果)"  # 2020/11/23 大网站规则,添加序号为排序
+        rankPattern = "(排名|排序|名次|序号|评标结果|评审结果|是否中标)"  # 2020/11/23 大网站规则,添加序号为排序
         entityPattern = "(候选|([中投]标|报价)(人|单位|候选)|单位名称|供应商)"
         height = len(inner_table)
         width = len(inner_table[0])
@@ -770,7 +809,8 @@ def tableToText(soup):
         if len(inner_table)>0 and len(inner_table[0])>0:
             #inner_table,head_list = setHead_withRule(inner_table,pat_head,pat_value,3)
             #inner_table,head_list = setHead_inline(inner_table)
-            inner_table,head_list = setHead_initem(inner_table,pat_head)
+            # inner_table,head_list = setHead_initem(inner_table,pat_head)
+            inner_table,head_list = setHead_incontext(inner_table,pat_head)
             # print(inner_table)
             # for begin in range(len(head_list[:-1])):
             #     for item in inner_table[head_list[begin]:head_list[begin+1]]:
@@ -779,12 +819,11 @@ def tableToText(soup):
 
             removeFix(inner_table)
             
-            '''
-            print("----")
-            print(head_list)
-            for item in inner_table:
-                print(item)
-            '''
+            # print("----")
+            # print(head_list)
+            # for item in inner_table:
+            #     print(item)
+
 
             tbody.string = getTableText(inner_table,head_list)
             #print(tbody.string)

BIN
BiddingKG/dl/interface/form_context_savedmodel/saved_model.pb


BIN
BiddingKG/dl/interface/form_context_savedmodel/variables/variables.data-00000-of-00001


BIN
BiddingKG/dl/interface/form_context_savedmodel/variables/variables.index


+ 88 - 0
BiddingKG/dl/interface/modelFactory.py

@@ -292,6 +292,94 @@ class Model_form_item():
           return self.getModel().predict(x)
       '''
 
+class Model_form_context():
+    def __init__(self,lazyLoad=getLazyLoad()):
+        self.model_form = None
+        self.sess_form = tf.Session(graph=tf.Graph())
+        if not lazyLoad:
+            self.getModel()
+
+
+    def getModel(self):
+        if self.model_form is None:
+            with self.sess_form.as_default() as sess:
+                with sess.graph.as_default():
+                    meta_graph_def = tf.saved_model.loader.load(sess,tags=["serve"],export_dir="%s/form_context_savedmodel"%(os.path.dirname(__file__)))
+                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+                    signature_def = meta_graph_def.signature_def
+
+                    inputs = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["inputs"].name)
+                    output = sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
+
+                    self.model_form = [[inputs],output]
+        return self.model_form
+        '''
+          if self.model_form is None:
+              with self.graph.as_default():
+                  self.model_form = models.load_model(self.model_file,custom_objects={"precision":precision,"recall":recall,"f1_score":f1_score})
+          return self.model_form
+        '''
+
+    def encode_table(self,inner_table,size=30):
+        def encode_item(_table,i,j):
+            _x = [_table[j-1][i-1],_table[j-1][i],_table[j-1][i+1],
+                  _table[j][i-1],_table[j][i],_table[j][i+1],
+                  _table[j+1][i-1],_table[j+1][i],_table[j+1][i+1]]
+            e_x = [encodeInput_form(_temp[0],MAX_LEN=30) for _temp in _x]
+            _label = _table[j][i][1]
+            # print(_x)
+            # print(_x[4],_label)
+            return e_x,_label,_x
+
+        def copytable(inner_table):
+            table = []
+            for line in inner_table:
+                list_line = []
+                for item in line:
+                    list_line.append([item[0][:size],item[1]])
+                table.append(list_line)
+            return table
+
+        table = copytable(inner_table)
+        padding = ["#"*30,0]
+        width = len(table[0])
+        height = len(table)
+        table.insert(0,[padding for i in range(width)])
+        table.append([padding for i in range(width)])
+        for item in table:
+            item.insert(0,padding.copy())
+            item.append(padding.copy())
+        data_x = []
+        data_y = []
+        data_text = []
+        data_position = []
+        for _i in range(1,width+1):
+            for _j in range(1,height+1):
+                _x,_y,_text = encode_item(table,_i,_j)
+                data_x.append(_x)
+                _label = [0,0]
+                _label[_y] = 1
+                data_y.append(_label)
+                data_text.append(_text)
+                data_position.append([_i-1,_j-1])
+                # input = table[_j][_i][0]
+                # item_y = [0,0]
+                # item_y[table[_j][_i][1]] = 1
+                # data_x.append(encodeInput([input], word_len=50, word_flag=True,userFool=False)[0])
+                # data_y.append(item_y)
+        return data_x,data_y,data_text,data_position
+
+    def encode(self,inner_table,**kwargs):
+        data_x,_,_,data_position = self.encode_table(inner_table)
+        return data_x,data_position
+
+    def predict(self,x):
+
+        model_form = self.getModel()
+
+        list_result = limitRun(self.sess_form,[model_form[1]],feed_dict={model_form[0][0]:x})[0]
+        return list_result
+
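
A hedged usage sketch for the new wrapper, assuming the form_context_savedmodel directory is present; cells are [text, head_label] pairs as elsewhere in this commit, and predictions come back in (column, row) order via data_position:

    import numpy as np

    model = Model_form_context(lazyLoad=True)   # defer loading the saved model
    inner_table = [[["序号", 0], ["项目名称", 0]],
                   [["1", 0], ["哈尔滨龙网电力设备有限公司", 0]]]
    data_x, data_position = model.encode(inner_table)
    probs = model.predict(np.array(data_x))     # the model is loaded on first use
    for (w, h), p in zip(data_position, probs):
        print(w, h, "head" if p[1] > 0.5 else "value")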
 # class Model_form_item():
 #     def __init__(self,lazyLoad=False):
 #         self.model_file = os.path.dirname(__file__)+"/ep039-loss0.038-val_loss0.064-f10.9783.h5"

+ 11 - 6
BiddingKG/dl/interface/predictor.py

@@ -788,14 +788,17 @@ class FormPredictor():
         self.model_file_line = os.path.dirname(__file__)+"/../form/model/model_form.model_line.hdf5"
         self.model_file_item = os.path.dirname(__file__)+"/../form/model/model_form.model_item.hdf5"
         self.model_form_item = Model_form_item()
+        self.model_form_context = Model_form_context()
         self.model_dict = {"line":[None,self.model_file_line]}
         
         
     def getModel(self,type):
-        assert type in self.model_dict.keys()
-        if self.model_dict[type][0] is None:
-            self.model_dict[type][0] = models.load_model(self.model_dict[type][1],custom_objects={"precision":precision,"recall":recall,"f1_score":f1_score})
-        return self.model_dict[type][0]
+        if type=="item":
+            return self.model_form_item
+        elif type=="context":
+            return self.model_form_context
+        else:
+            # lazily load the remaining Keras models (e.g. "line") from model_dict
+            if self.model_dict[type][0] is None:
+                self.model_dict[type][0] = models.load_model(self.model_dict[type][1],custom_objects={"precision":precision,"recall":recall,"f1_score":f1_score})
+            return self.model_dict[type][0]
 
     def encode(self,data,**kwargs):
         return encodeInput([data], word_len=50, word_flag=True,userFool=False)[0]
@@ -804,9 +807,11 @@ class FormPredictor():
     
     def predict(self,form_datas,type):
         if type=="item":
-          return self.model_form_item.predict(form_datas)
+            return self.model_form_item.predict(form_datas)
+        elif type=="context":
+            return self.model_form_context.predict(form_datas)
         else:
-          return self.getModel(type).predict(form_datas)
+            return self.getModel(type).predict(form_datas)
     
     
 #角色规则

BIN
BiddingKG/dl/test/list_sentence_entity.pk


+ 9 - 48
BiddingKG/dl/test/test4.py

@@ -23,7 +23,6 @@ import BiddingKG.dl.interface.predictor as predictor
 import BiddingKG.dl.interface.Preprocessing as Preprocessing
 import BiddingKG.dl.interface.getAttributes as getAttributes
 import BiddingKG.dl.entityLink.entityLink as entityLink
-import BiddingKG.dl.complaint.punish_rule as punish_rule
 import json
 
 
@@ -49,9 +48,6 @@ codeNamePredict = predictor.CodeNamePredict()
 premPredict = predictor.PREMPredict()
 epcPredict = predictor.EPCPredict()
 roleRulePredict = predictor.RoleRulePredictor()
-timePredict = predictor.TimePredictor()
-punish = punish_rule.Punish_Extract()
-
 
 #自定义jsonEncoder
 class MyEncoder(json.JSONEncoder):
@@ -83,43 +79,18 @@ def predict(doc_id,text):
     print("epcPredict")
     epcPredict.predict(list_sentences,list_entitys)
     print("entityLink")
-    timePredict.predict(list_sentences, list_entitys)
-    print("timePredict")
     entityLink.link_entitys(list_entitys)
     print("getPREMs")
     prem = getAttributes.getPREMs(list_sentences,list_entitys,list_articles)
     print("getPREMs")
-    punish_dic = punish.get_punish_extracts(list_sentences, list_entitys, title='投诉处理 ', text=text)
-    print(punish_dic)
-    prem[0][1]['punish'] = punish_dic
-    # bidding method
-    bidway = [entity.entity_text for entity in list_entitys[0] if entity.entity_type=='bidway']
-    # funding source
-    moneySource = [entity.entity_text for entity in list_entitys[0] if entity.entity_type=='moneySource']
-    # service period
-    servicetime = [entity.entity_text for entity in list_entitys[0] if entity.entity_type=='servicetime']
-    # release time  time_release:1
-    time_release = [entity.entity_text for entity in list_entitys[0] if entity.entity_type == 'time' and entity.label==1]
-    # bid-opening time  'time_bidopen':2,
-    time_bidopen = [entity.entity_text for entity in list_entitys[0] if entity.entity_type == 'time' and entity.label==2]
-    # bid-closing time 'time_bidclose':3
-    time_bidclose = [entity.entity_text for entity in list_entitys[0] if entity.entity_type == 'time' and entity.label == 3]
-    prem[0][1]['bidway'] = ';'.join(set(bidway))
-    prem[0][1]['moneySource'] = ';'.join(set(moneySource))
-    prem[0][1]['servicetime'] = ';'.join(set(servicetime))
-    prem[0][1]['time_release'] = ';'.join(set(time_release))
-    prem[0][1]['time_bidopen'] = ';'.join(set(time_bidopen))
-    prem[0][1]['time_bidclose'] = ';'.join(set(time_bidclose))
-
-
-
+    
     
     ''''''
     
-
+    
     for entitys in list_entitys:
         for entity in entitys:
-            print(entity.entity_text,entity.entity_type,entity.label,entity.values,entity.sentence_index,entity.begin_index,entity.end_index,entity.wordOffset_begin,entity.wordOffset_end)
+            print(entity.entity_text,entity.entity_type,entity.label,entity.values,entity.sentence_index,entity.wordOffset_begin,entity.wordOffset_end)
 
     #print(prem)
     return json.dumps(Preprocessing.union_result(codeName, prem)[0][1],cls=MyEncoder,sort_keys=True,indent=4,ensure_ascii=False)
@@ -138,10 +109,10 @@ def test(name,content):
 
 
 if __name__=="__main__":
-    # filename = "比地_52_79929693.html"
-    # #text = codecs.open("C:\\Users\\User\\Desktop\\数据20191014\\"+filename,"r",encoding="utf8").read()
-    # text = codecs.open("C:\\Users\\User\\Desktop\\2.html","r",encoding="utf8").read()
-    # content = str(BeautifulSoup(text).find("div",id="pcontent"))
+    filename = "比地_52_79929693.html"
+    #text = codecs.open("C:\\Users\\User\\Desktop\\数据20191014\\"+filename,"r",encoding="utf8").read()
+    text = codecs.open("C:\\Users\\User\\Desktop\\2.html","r",encoding="utf8").read()
+    content = str(BeautifulSoup(text).find("div",id="pcontent"))
     # df_a = {"html":[]}
     # df_a["html"].append(re.sub('\r|\n|\r\n',"",content))
     # import pandas as pd
@@ -149,20 +120,10 @@ if __name__=="__main__":
     # df.to_csv("C:\\Users\\User\\Desktop\\ba.csv")
     # print()
     #text = codecs.open("C:\\Users\\User\\Desktop\\a.html","r",encoding="utf8").read()
-    # text = "张家港保税区宏宇建设咨询有限公司受张家港市给排水公司委托,就二次供水泵房浊度仪进行国内组织公开招标采购,欢迎符合条件的供应商参加投标。"
-    # text = 'a建设资金来源及性质:资本金40%,自筹60%,,xx.=建设资金来源自筹,项目出资比例为100%,as,建设资金来自呜呜呜。'
-    # text = '张家港保税区宏宇建设咨询有限公司受张家港市给排水公司委托,就二次供水泵房浊度仪进行国内组织公开招标采购,时间:2020-05-26,15:15:00,竞价结束时间:2020-05-26,15:45:00允许延时:是,延时规则:在剩余数量小于最小购买数量时,竞价进'
-    # text = '''大庆禾工煤炭分质清洁利用项目-临时用电二期工程设备、物资采购中标候选人公示,更多咨询报价请点击:http://bulletin.cebpubservice.com/candidateBulletin/2020-03-31/2678597.html,大庆禾工煤炭分质清洁利用顶目-临时用电二期工程设备、物资釆购中标候选人,(招标编号:XYwZ-20200309-5),公示结束时间:2020年04月03日,、评标情况,标段(包)[001大庆禾工煤嶽分质清洁利用项目-临时用屯二期工程设备、物资采购,中标候选人基本情况,
-    # 中标候选人第1名:哈尔滨龙网电力设备有限公司,投标报价:19.98万元,质量,合格,工期/交货期/服务期:30天,中标候选人第2名:
-    # 哈尔滨昊龙电气没备制造有限公司,投标报价:19.87万元,质,量:合格,工期/交货期/服务期:30天,'''
-    text = '中标候选人第1名:哈尔滨龙网电力设备有限公司,投标报价:19.98万元,质量,合格,工期/交货期/服务期:30天。\
-    投诉处理公告,投诉人:张三。文章编号:京财采投字(2018)第42号。政府采购项目招标方式:公开招标,联系人:黎明。\
-    建设资金来源及性质:资本金40%,自筹60%,,xx.=建设资金来源自筹,项目出资比例为100%,\
-    二次供水泵房浊度仪进行国内组织公开招标采购,时间:2020-05-26,15:15:00,竞价结束时间:2020-05-26,15:45:00允许延时:是,'
+    #text = "张家港保税区宏宇建设咨询有限公司受张家港市给排水公司委托,就二次供水泵房浊度仪进行国内组织公开招标采购,欢迎符合条件的供应商参加投标。"
     a = time.time()
     print("start")
-    # print(predict("12",content))
-    print(predict("投诉处理公告", text))
+    print(predict("12",content))
     #test("12",text)
     print("takes",time.time()-a)
     pass