''' Created on 2019年4月12日 @author: User ''' import BiddingKG.dl.interface.Preprocessing as Preprocessing from bs4 import BeautifulSoup import os import sys import codecs sys.path.append(os.path.abspath("../..")) from keras import models from keras import layers from keras_contrib.layers import CRF from keras.preprocessing.sequence import pad_sequences from keras import optimizers,losses,metrics from BiddingKG.dl.common.Utils import * import glob import pandas as pd def getArticle(article): return Preprocessing.segment(Preprocessing.tableToText(BeautifulSoup(article,"lxml"))) #编号名称模型 class CodeNamePredict(): def __init__(self,EMBED_DIM=None,BiRNN_UNITS=None): self.model = None self.MAX_LEN = None if EMBED_DIM is None: self.EMBED_DIM = 60 else: self.EMBED_DIM = EMBED_DIM if BiRNN_UNITS is None: self.BiRNN_UNITS = 40 else: self.BiRNN_UNITS = BiRNN_UNITS self.filepath = "../projectCode/models/model_project_"+str(self.EMBED_DIM)+"_"+str(self.BiRNN_UNITS)+".hdf5" vocabpath = "../projectCode/models/vocab.pk" classlabelspath = "../projectCode/models/classlabels.pk" self.vocab = load(vocabpath) self.class_labels = load(classlabelspath) #生成提取编号和名称的正则 id_PC_B = self.class_labels.index("PC_B") id_PC_M = self.class_labels.index("PC_M") id_PC_E = self.class_labels.index("PC_E") id_PN_B = self.class_labels.index("PN_B") id_PN_M = self.class_labels.index("PN_M") id_PN_E = self.class_labels.index("PN_E") self.PC_pattern = re.compile(str(id_PC_B)+str(id_PC_M)+"+"+str(id_PC_E)) self.PN_pattern = re.compile(str(id_PN_B)+str(id_PN_M)+"+"+str(id_PN_E)) self.word2index = dict((w,i) for i,w in enumerate(np.array(self.vocab))) def getModel(self): ''' @summary: 取得编号和名称模型 ''' if self.model is None: self.model = self.getBiLSTMCRFModel(self.MAX_LEN, self.vocab, self.EMBED_DIM, self.BiRNN_UNITS, self.class_labels,weights=None) self.model.load_weights(self.filepath) return self.model def getBiLSTMCRFModel(self,MAX_LEN,vocab,EMBED_DIM,BiRNN_UNITS,chunk_tags,weights): ''' model = models.Sequential() model.add(layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True)) # Random embedding model.add(layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True))) crf = CRF(len(chunk_tags), sparse_target=True) model.add(crf) model.summary() model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy]) return model ''' input = layers.Input(shape=(None,)) if weights is not None: embedding = layers.embeddings.Embedding(len(vocab),EMBED_DIM,mask_zero=True,weights=[weights],trainable=True)(input) else: embedding = layers.embeddings.Embedding(len(vocab),EMBED_DIM,mask_zero=True)(input) bilstm = layers.Bidirectional(layers.LSTM(BiRNN_UNITS//2,return_sequences=True))(embedding) bilstm_dense = layers.TimeDistributed(layers.Dense(len(chunk_tags)))(bilstm) crf = CRF(len(chunk_tags),sparse_target=True) crf_out = crf(bilstm_dense) model = models.Model(input=[input],output = [crf_out]) model.summary() model.compile(optimizer = 'adam', loss = crf.loss_function, metrics = [crf.accuracy]) return model #根据规则补全编号或名称两边的符号 def fitDataByRule(self,data): symbol_dict = {"(":")", "(":")", "[":"]", "【":"】", ")":"(", ")":"(", "]":"[", "】":"【"} leftSymbol_pattern = re.compile("[\((\[【]") rightSymbol_pattern = re.compile("[\))\]】]") leftfinds = re.findall(leftSymbol_pattern,data) rightfinds = re.findall(rightSymbol_pattern,data) result = data if len(leftfinds)+len(rightfinds)==0: return data elif len(leftfinds)==len(rightfinds): return data elif abs(len(leftfinds)-len(rightfinds))==1: if len(leftfinds)>len(rightfinds): if symbol_dict.get(data[0]) is not None: result = data[1:] else: #print(symbol_dict.get(leftfinds[0])) result = data+symbol_dict.get(leftfinds[0]) else: if symbol_dict.get(data[-1]) is not None: result = data[:-1] else: result = symbol_dict.get(rightfinds[0])+data return result def predict(self,articles,list_data,set_data,filename,MAX_LEN = None): sentences = [] for article in articles: for sentence in article.split("。"): sentences.append([sentence,"1"]) if MAX_LEN is None: sent_len = [len(sentence[0]) for sentence in sentences] MAX_LEN = max(sent_len) #print(MAX_LEN) #若为空,则直接返回空 result = [] if MAX_LEN==0: for article in articles: result.append(["1",{"code":[],"name":""}]) return result index_unk = self.word2index.get("") index_pad = self.word2index.get("") x = [[self.word2index.get(word,index_unk)for word in sentence[0]]for sentence in sentences] x = pad_sequences(x,maxlen=MAX_LEN,padding="post",truncating="post") predict_y = self.getModel().predict(x) last_doc_id = "" item = [] for sentence,predict in zip(sentences,np.argmax(predict_y,-1)): pad_sentence = sentence[0][:MAX_LEN] #print("11",pad_sentence) doc_id = sentence[1] join_predict = "".join([str(s) for s in predict]) if doc_id!=last_doc_id: if last_doc_id!="": result.append(item) item = [doc_id,{"code":[],"name":""}] code_set = set() for iter in re.finditer(self.PC_pattern,join_predict): get_len = 40 if iter.span()[0]