@@ -52,11 +52,11 @@ class Punish_Extract():
         self.logits = self.sess.graph.get_tensor_by_name("CRF/output/logits:0")

     # Punishment number prediction
-    def predict_punishCode(self,list_sentences, MAXlLEN=5000):
+    def predict_punishCode(self,list_sentences, MAX_AREA=5000):
         '''
         Predict the punishment number for each sentence.
         :param list_sentences: sentence lists for multiple articles [[sentence list of one article]]
-        :param MAXlLEN: controls the maximum length of each sentence; longer sentences are truncated
+        :param MAX_AREA: controls the maximum length of each sentence; longer sentences are truncated
         :return: punishment number string; multiple numbers are separated by ';'
         '''
         re_ner = re.compile("12+?3")
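Note: re_ner matches over the string of predicted tag ids, not over the text itself, so an entity is any character span whose tags read "begin, inside..., end". A minimal sketch of that trick, assuming a 0=other / 1=begin / 2=inside / 3=end id scheme (consistent with tag_num=4 below; the actual label map is not shown in this diff):

    import re

    # Tag-id semantics (0=other, 1=begin, 2=inside, 3=end) are an assumption.
    re_ner = re.compile("12+?3")
    sentence = "文号:(2019)1号"
    viterbi_tags = [0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 3]  # one tag id per character
    seq_id = ''.join(str(t) for t in viterbi_tags)     # "00012222223"
    for m in re_ner.finditer(seq_id):
        # tag positions map 1:1 to character positions
        print(sentence[m.start():m.end()])             # -> (2019)1号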
@@ -67,35 +67,49 @@ class Punish_Extract():
         for sentences in list_sentences:
             count += 1
             # print(count)
-            sentence_len = [len(sentence.sentence_text) for sentence in sentences]
-            maxlen = min(max(sentence_len),MAXlLEN)
-            sentences_x = []
-            for sentence in sentences:
-                sentence = sentence.sentence_text
-                sentence = list(sentence)
-                sentence2id = [getIndexOfWord(word) for word in sentence]
-                sentences_x.append(sentence2id)
-            sentences_x = pad_sequences(sentences_x, maxlen=maxlen, padding="post", truncating="post")
-            sentences_x = [np.array(x) for x in sentences_x]
-            _logits, _trans = self.sess.run([self.logits, self.trans],
-                                            feed_dict={self.char_input: np.array(sentences_x), self.length: sentence_len})
-            viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
+            # Sort longest-first so each batch holds sentences of similar length
+            sentences.sort(key=lambda x: len(x.sentence_text), reverse=True)
+            ner_list = []
+            _begin_index = 0
+            while True:
+                # The longest remaining sentence sets the padded width of the batch
+                MAX_LEN = len(sentences[_begin_index].sentence_text)
+                if MAX_LEN > MAX_AREA:
+                    MAX_LEN = MAX_AREA
+                # Batch size keeps batch_size * MAX_LEN within the MAX_AREA budget
+                _LEN = MAX_AREA // MAX_LEN
+                # Clip lengths to MAX_LEN so they never exceed the padded input width
+                sentence_len = [min(len(sentence.sentence_text), MAX_LEN) for sentence in sentences[_begin_index:_begin_index+_LEN]]
+                sentences_x = []
+                for sentence in sentences[_begin_index:_begin_index+_LEN]:
+                    sentence = sentence.sentence_text
+                    sentence = list(sentence)
+                    sentence2id = [getIndexOfWord(word) for word in sentence]
+                    sentences_x.append(sentence2id)
+                sentences_x = pad_sequences(sentences_x, maxlen=MAX_LEN, padding="post", truncating="post")
+                sentences_x = [np.array(x) for x in sentences_x]
+                _logits, _trans = self.sess.run([self.logits, self.trans],
+                                                feed_dict={self.char_input: np.array(sentences_x), self.length: sentence_len})
+                viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)

-            ner_list = []
-            for _seq, sentence in zip(viterbi_sequence, sentences):
-                sentence = sentence.sentence_text
-                seq_id = ''.join([str(s) for s in _seq])
-                if re_ner.search(seq_id):
-                    # print("sentence: ",sentence)
-                    for _ner in re_ner.finditer(seq_id):
-                        start = _ner.start()
-                        end = _ner.end()
-                        n = sentence[start:end]
-                        # print(n,'<==>',start,end)
-                        # ner_list.append((n, start, end))
-                        ner_list.append(n)  # changed to return only the entity text
-            # article_ner_list.append(ner_list)
-            article_ner_list.append(';'.join(set(ner_list)))
+                for _seq, sentence in zip(viterbi_sequence, sentences[_begin_index:_begin_index+_LEN]):
+                    sentence = sentence.sentence_text
+                    seq_id = ''.join([str(s) for s in _seq])
+                    if re_ner.search(seq_id):
+                        # print("sentence: ",sentence)
+                        for _ner in re_ner.finditer(seq_id):
+                            start = _ner.start()
+                            end = _ner.end()
+                            n = sentence[start:end]
+                            # print(n,'<==>',start,end)
+                            # ner_list.append((n, start, end))
+                            ner_list.append(n)  # changed to return only the entity text
+                if _begin_index+_LEN >= len(sentences):
+                    break
+                _begin_index += _LEN
+            # Join once per article, after all of its batches have been decoded
+            # article_ner_list.append(ner_list)
+            article_ner_list.append(';'.join(set(ner_list)))
         return article_ner_list[0]

     # Punishment type
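Note: the rewrite replaces one padded batch per article with length-bucketed batches under an area budget. Sentences are sorted longest-first, the longest sentence in each batch sets the padded width MAX_LEN, and the batch size _LEN = MAX_AREA // MAX_LEN keeps batch_size * padded_length roughly constant, bounding per-batch memory and cutting padding waste. A standalone sketch of the bucketing (names are hypothetical, not from the repo):

    def area_batches(texts, max_area=5000):
        # Hypothetical helper: yield (slice, pad_width) pairs with
        # len(slice) * pad_width <= max_area.
        texts = sorted(texts, key=len, reverse=True)
        begin = 0
        while begin < len(texts):
            pad_width = max(1, min(len(texts[begin]), max_area))
            size = max_area // pad_width   # how many sentences fit the budget
            yield texts[begin:begin + size], pad_width
            begin += size

    # e.g. max_area=10 and lengths [9, 5, 4, 3] -> batches of size 1, 2, 1

Because the longest sentence comes first, every later sentence in a batch is no longer than pad_width, so only sentences longer than max_area itself are ever truncated.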