
Adjust the product-field extraction and punishment-number extraction logic; cap the amount of data per batch to prevent out-of-memory errors

bidi 4 years ago
parent
commit
5c090fa4d6
2 changed files with 69 additions and 53 deletions
  1. +39 -30
      BiddingKG/dl/complaint/punish_predictor.py
  2. +30 -23
      BiddingKG/dl/interface/predictor.py
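
Both diffs below introduce the same batching pattern to keep memory bounded: instead of feeding every sentence to sess.run at once, sentences are processed in chunks. The longest sentence in the current chunk sets MAX_LEN (capped at MAX_AREA), and the chunk holds _LEN = MAX_AREA // MAX_LEN sentences, so each padded batch covers at most roughly MAX_AREA characters; punish_predictor.py additionally sorts sentences by length (longest first) so the budget is used efficiently. A minimal standalone sketch of the scheme, assuming plain strings as input (the iter_batches helper is illustrative and not part of the repository):

    # Sketch only: yield batches whose padded size (rows x columns) stays within a character budget.
    def iter_batches(sentences, MAX_AREA=5000):
        # longest sentences first, so each batch's MAX_LEN only shrinks
        sentences = sorted(sentences, key=len, reverse=True)
        _begin_index = 0
        while _begin_index < len(sentences):
            # column count capped at MAX_AREA; max(1, ...) guards against empty strings
            MAX_LEN = max(1, min(len(sentences[_begin_index]), MAX_AREA))
            _LEN = MAX_AREA // MAX_LEN          # number of rows that fit in the budget
            yield sentences[_begin_index:_begin_index + _LEN], MAX_LEN
            _begin_index += _LEN

    # With MAX_AREA=10, ["aaaaa", "bbbb", "ccc", "cc"] yields
    # (["aaaaa", "bbbb"], 5) and then (["ccc", "cc"], 3).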

+ 39 - 30
BiddingKG/dl/complaint/punish_predictor.py

@@ -52,11 +52,11 @@ class Punish_Extract():
                     self.logits = self.sess.graph.get_tensor_by_name("CRF/output/logits:0")
 
    # Punishment-number prediction
-    def predict_punishCode(self,list_sentences, MAXlLEN=5000):
+    def predict_punishCode(self,list_sentences, MAX_AREA=5000):
         '''
        Predict the punishment number for each sentence
        :param list_sentences: sentence lists for multiple articles [[sentence list of each article]]
-        :param MAXlLEN: caps the maximum length of each sentence; anything longer is truncated
+        :param MAX_AREA: caps the maximum length of each sentence; anything longer is truncated
        :return: punishment-number string; multiple numbers are separated by ;
         '''
         re_ner = re.compile("12+?3")
@@ -67,35 +67,44 @@ class Punish_Extract():
                 for sentences in list_sentences:
                     count += 1
                     # print(count)
-                    sentence_len = [len(sentence.sentence_text) for sentence in sentences]
-                    maxlen = min(max(sentence_len),MAXlLEN)
-                    sentences_x = []
-                    for sentence in sentences:
-                        sentence = sentence.sentence_text
-                        sentence = list(sentence)
-                        sentence2id = [getIndexOfWord(word) for word in sentence]
-                        sentences_x.append(sentence2id)
-                    sentences_x = pad_sequences(sentences_x, maxlen=maxlen, padding="post", truncating="post")
-                    sentences_x = [np.array(x) for x in sentences_x]
-                    _logits, _trans = self.sess.run([self.logits, self.trans],
-                                               feed_dict={self.char_input: np.array(sentences_x), self.length: sentence_len})
-                    viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
+                    sentences.sort(key=lambda x: len(x.sentence_text), reverse=True)
+                    _begin_index = 0
+                    while True:
+                        MAX_LEN = len(sentences[_begin_index].sentence_text)
+                        if MAX_LEN > MAX_AREA:
+                            MAX_LEN = MAX_AREA
+                        _LEN = MAX_AREA // MAX_LEN
+                        sentence_len = [len(sentence.sentence_text) for sentence in sentences[_begin_index:_begin_index+_LEN]]
+                        sentences_x = []
+                        for sentence in sentences[_begin_index:_begin_index+_LEN]:
+                            sentence = sentence.sentence_text
+                            sentence = list(sentence)
+                            sentence2id = [getIndexOfWord(word) for word in sentence]
+                            sentences_x.append(sentence2id)
+                        sentences_x = pad_sequences(sentences_x, maxlen=MAX_LEN, padding="post", truncating="post")
+                        sentences_x = [np.array(x) for x in sentences_x]
+                        _logits, _trans = self.sess.run([self.logits, self.trans],
+                                                   feed_dict={self.char_input: np.array(sentences_x), self.length: sentence_len})
+                        viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
 
-                    ner_list = []
-                    for _seq, sentence in zip(viterbi_sequence, sentences):
-                        sentence = sentence.sentence_text
-                        seq_id = ''.join([str(s) for s in _seq])
-                        if re_ner.search(seq_id):
-                            # print("sentence: ",sentence)
-                            for _ner in re_ner.finditer(seq_id):
-                                start = _ner.start()
-                                end = _ner.end()
-                                n = sentence[start:end]
-                                # print(n,'<==>',start,end)
-                                # ner_list.append((n, start, end))
-                                ner_list.append(n)  # 改为只返回实体字符
-                    # article_ner_list.append(ner_list)
-                    article_ner_list.append(';'.join(set(ner_list)))
+                        ner_list = []
+                        for _seq, sentence in zip(viterbi_sequence, sentences[_begin_index:_begin_index+_LEN]):
+                            sentence = sentence.sentence_text
+                            seq_id = ''.join([str(s) for s in _seq])
+                            if re_ner.search(seq_id):
+                                # print("sentence: ",sentence)
+                                for _ner in re_ner.finditer(seq_id):
+                                    start = _ner.start()
+                                    end = _ner.end()
+                                    n = sentence[start:end]
+                                    # print(n,'<==>',start,end)
+                                    # ner_list.append((n, start, end))
+                                    ner_list.append(n)  # changed to return only the entity string
+                        # article_ner_list.append(ner_list)
+                        article_ner_list.append(';'.join(set(ner_list)))
+                        if _begin_index+_LEN >= len(sentences):
+                            break
+                        _begin_index += _LEN
         return article_ner_list[0]
 
    # Punishment type
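
In the loop above, label 1 marks the first character of a punishment number, 2 a continuation, and 3 the last character; re_ner = re.compile("12+?3") then finds those spans in the joined Viterbi tag string, and the matching sentence slices are deduplicated and joined with ';'. A self-contained sketch of just that decoding step (the sentence and tag values are invented for illustration):

    import re

    re_ner = re.compile("12+?3")  # 1 = begin, 2 = inside, 3 = end of a punishment number

    def extract_punish_codes(sentence, tag_sequence):
        # tag_sequence is one Viterbi path aligned with the sentence's characters
        seq_id = ''.join(str(t) for t in tag_sequence)
        ner_list = [sentence[m.start():m.end()] for m in re_ner.finditer(seq_id)]
        return ';'.join(set(ner_list))

    # Hypothetical example: characters 4..10 are labelled as one number -> "ABC-001"
    print(extract_punish_codes("No. ABC-001 issued", [0,0,0,0,1,2,2,2,2,2,3,0,0,0,0,0,0,0]))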

+ 30 - 23
BiddingKG/dl/interface/predictor.py

@@ -1245,29 +1245,36 @@ class ProductPredictor():
                     _begin_index = 0
                     item = {"product":[]}
                     temp_list = []
-                    MAX_LEN = len(list_sentence[_begin_index].sentence_text)
-                    if MAX_LEN > MAX_AREA:
-                        MAX_LEN = MAX_AREA
-                    chars = process_data([sentence.sentence_text[:MAX_LEN] for sentence in list_sentence])
-                    lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran], feed_dict={self.char_input: np.asarray(chars),
-                                                                                        self.dropout: 1.0
-                                                                                        })
-                    batch_paths = decode(scores, lengths, tran_)
-                    for sentence, path, length in zip(list_sentence,batch_paths, lengths):
-                        tags = ''.join([str(it) for it in path[:length]])
-                        for it in re.finditer("12*3", tags):
-                            start = it.start()
-                            end = it.end()
-                            _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s" % (
-                            sentence.doc_id, sentence.sentence_index, start, end),
-                                             entity_text=sentence.sentence_text[start:end],
-                                             entity_type="product", sentence_index=sentence.sentence_index, begin_index=0,
-                                             end_index=0, wordOffset_begin=start,
-                                             wordOffset_end=end)
-                            list_entity.append(_entity)
-                            temp_list.append(sentence.sentence_text[start:end])
-                    item["product"] = list(set(temp_list))
-                    result.append(item)
+                    while True:
+                        MAX_LEN = len(list_sentence[_begin_index].sentence_text)
+                        if MAX_LEN > MAX_AREA:
+                            MAX_LEN = MAX_AREA
+                        _LEN = MAX_AREA//MAX_LEN
+                        chars = process_data([sentence.sentence_text[:MAX_LEN] for sentence in list_sentence[_begin_index:_begin_index+_LEN]])
+                        lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
+                                                          feed_dict={
+                                                                    self.char_input: np.asarray(chars),
+                                                                    self.dropout: 1.0
+                                                                    })
+                        batch_paths = decode(scores, lengths, tran_)
+                        for sentence, path, length in zip(list_sentence[_begin_index:_begin_index+_LEN],batch_paths, lengths):
+                            tags = ''.join([str(it) for it in path[:length]])
+                            for it in re.finditer("12*3", tags):
+                                start = it.start()
+                                end = it.end()
+                                _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s" % (
+                                sentence.doc_id, sentence.sentence_index, start, end),
+                                                 entity_text=sentence.sentence_text[start:end],
+                                                 entity_type="product", sentence_index=sentence.sentence_index,
+                                                 begin_index=0, end_index=0, wordOffset_begin=start,
+                                                 wordOffset_end=end)
+                                list_entity.append(_entity)
+                                temp_list.append(sentence.sentence_text[start:end])
+                        item["product"] = list(set(temp_list))
+                        result.append(item)
+                        if _begin_index+_LEN >= len(list_sentence):
+                            break
+                        _begin_index += _LEN
                 return result
 
 def getSavedModel():
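
The ProductPredictor loop above uses the same _LEN = MAX_AREA // MAX_LEN bound, so each padded batch handed to sess.run holds at most about MAX_AREA cells instead of one matrix covering every sentence at the longest (capped) length. A rough back-of-the-envelope sketch of the effect, with invented sentence lengths:

    # Rough illustration of the memory effect of the batching loops in this commit.
    # Before: one padded matrix of shape (num_sentences, capped longest length).
    # After:  several matrices, each bounded by about MAX_AREA cells.
    MAX_AREA = 5000
    lengths = sorted([4800, 3000, 120, 90, 80, 60, 45, 30], reverse=True)  # made-up lengths

    old_cells = len(lengths) * min(max(lengths), MAX_AREA)   # 8 * 4800 = 38400 padded cells

    new_cells = []
    i = 0
    while i < len(lengths):
        MAX_LEN = min(lengths[i], MAX_AREA)
        _LEN = MAX_AREA // MAX_LEN
        new_cells.append(len(lengths[i:i + _LEN]) * MAX_LEN)  # <= MAX_AREA cells per batch
        i += _LEN

    print(old_cells, max(new_cells))   # 38400 4800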