data_util.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:19
import re
import math
import random
import psycopg2
import numpy as np
from tensorflow.contrib.crf import viterbi_decode
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_word

# Tag scheme: O = outside, B = begin, I = inside, E = end of a product span.
id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}

# Character vocabulary and embedding matrix shared with the rest of the repo.
word_model = getModel_word()
vocab, matrix = getVocabAndMatrix(word_model, Embedding_size=60)
word2id = {k: v for v, k in enumerate(vocab)}
max_id = len(vocab)  # id reserved for out-of-vocabulary characters

conn = psycopg2.connect(dbname='iepy_product', user='postgres', password='postgres', host='192.168.2.101')
cursor = conn.cursor()
def get_label_data():
    """Fetch labelled documents from the corpus DB and turn them into
    per-sentence [chars, char_ids, BIOE tags] training triples."""
    sql = "select human_identifier, text from corpus_iedocument where edittime NOTNULL AND jump_signal=0 " \
          "and creation_date > to_timestamp('2021-01-14 00:00:00','yyyy-MM-dd HH24:mi:ss');"
    cursor.execute(sql)
    writer = open('label_data.txt', 'w', encoding='utf-8')
    datas = []
    for row in cursor.fetchall():
        docid = row[0]
        text = row[1]
        tags = [0] * len(text)  # default: every character is 'O'
        sql_lb = "select b.value from brat_bratannotation as b where document_id = '{}' and b.value like 'T%product%';".format(docid)
        cursor.execute(sql_lb)
        for row_lb in cursor.fetchall():
            label = row_lb[0]
            _, _, begin, end, _ = re.split(r'\s', label)
            begin = int(begin)
            end = int(end)
            if end - begin >= 2:
                tags[begin] = 1      # B
                tags[end - 1] = 3    # E
                for i in range(begin + 1, end - 1):
                    tags[i] = 2      # I
        text_sentence = []
        ids_sentence = []
        tag_sentence = []
        for i in range(len(text)):
            text_sentence.append(text[i])
            ids_sentence.append(word2id.get(text[i], max_id))
            tag_sentence.append(tags[i])
            writer.write("%s\t%s\n" % (text[i], tags[i]))
            if text[i] in ['。', '?', '!', ';']:
                writer.write('\n')
                if text_sentence:
                    # Keep sentences of 6-100 characters; report and drop the rest.
                    if 5 < len(text_sentence) <= 100:
                        datas.append([text_sentence, ids_sentence, tag_sentence])
                    else:
                        print('Sentence shorter than 5 or longer than 100 characters; length: %d, document ID: %s' % (len(text_sentence), docid))
                    text_sentence = []
                    ids_sentence = []
                    tag_sentence = []
        if text_sentence:
            if len(text_sentence) > 5:
                datas.append([text_sentence, ids_sentence, tag_sentence])
            else:
                print('Sentence shorter than 5 or longer than 100 characters; length: %d, document ID: %s' % (len(text_sentence), docid))
    writer.close()
    return datas
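
# An illustrative example of the scheme above (annotation value is hypothetical):
# a brat annotation 'T1 product 2 5 某产品' on the text '采购某产品一批' marks
# characters 2..4 ('某产品') as B, I, E, giving tags [0, 0, 1, 2, 3, 0, 0].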
def input_from_line(line):
    """Wrap a raw string as a single-element batch for prediction."""
    string = list(line)
    ids = [word2id.get(k, max_id) for k in string]
    tags = []  # unknown at inference time
    return [[string], [ids], [tags]]
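
# A minimal usage sketch (the input line is illustrative):
#   strings, ids, tags = input_from_line('采购笔记本电脑一批')
#   # strings == [['采', '购', ...]], ids == [[...vocab indices...]], tags == [[]]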
def process_data(sentences):
    '''
    Digitize strings and pad them to a uniform length.
    :param sentences: list of sentence strings, e.g. ['招标公告', '招标代理']
    :return: id sequences padded to the length of the longest sentence
    '''
    maxLen = max([len(sentence) for sentence in sentences])
    ids = [[word2id.get(k, max_id) for k in sentence] for sentence in sentences]
    pad_ids = [id_seq[:maxLen] + [0] * (maxLen - len(id_seq)) for id_seq in ids]
    return pad_ids
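
# Example: process_data(['招标公告', '招标']) returns two id sequences of
# length 4, the second right-padded with two zeros.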
def get_ner(BIE_tag):
    """Return the (start, end) span of every B-I*-E run in a tag string."""
    ner = set()
    for it in re.finditer('BI*E', BIE_tag):
        ner.add((it.start(), it.end()))
    return ner
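
# Example: 'OBIEOBE' contains two entity spans.
#   get_ner('OBIEOBE')  ->  {(1, 4), (5, 7)}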
def decode(logits, lengths, trans_matrix):
    """Viterbi-decode a batch of per-character tag scores.

    Each score matrix is widened with a fifth 'impossible' column and
    prefixed with a synthetic start token so that the 5 x 5 transition
    matrix (four tags plus the start state) applies uniformly.
    """
    paths = []
    small = -1000.0
    start = np.asarray([[small] * 4 + [0]])
    for score, length in zip(logits, lengths):
        score = score[:length]
        pad = small * np.ones([length, 1])
        padded_logits = np.concatenate([score, pad], axis=1)
        padded_logits = np.concatenate([start, padded_logits], axis=0)
        path, _ = viterbi_decode(padded_logits, trans_matrix)
        paths.append(path[1:])  # drop the synthetic start token
    return paths
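
# A sketch of the expected shapes (names here are illustrative): `logits` is
# the [batch, sentence_len, 4] network output, `lengths` the true sentence
# lengths, and `trans_matrix` a learned 5 x 5 CRF transition matrix covering
# the four tags plus the synthetic start class added above.
#   paths = decode(logits, lengths, trans_matrix)  # one tag-id list per sentence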
def result_to_json(line, tags):
    """Collect the entity spans that a tag sequence marks in `line`."""
    result = []
    ner = []
    tags = ''.join([str(it) for it in tags])
    for it in re.finditer("12*3", tags):  # 1 = B, 2 = I, 3 = E
        start = it.start()
        end = it.end()
        ner.append([line[start:end], (start, end)])
    result.append([line, ner])
    print(tags)  # debug: the joined tag string
    return result
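
# Example: tags [0, 1, 2, 3] mark characters 1..3 of the line as one span.
#   result_to_json('招标公告', [0, 1, 2, 3])
#   ->  [['招标公告', [['标公告', (1, 4)]]]]  (also prints '0123')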
class BatchManager(object):
    def __init__(self, data, batch_size):
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)

    def sort_and_pad(self, data, batch_size):
        """Sort sentences by length so each batch needs minimal padding."""
        num_batch = int(math.ceil(len(data) / batch_size))
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        print('Shortest sentence length: %d; longest sentence length: %d' % (len(sorted_data[0][0]), len(sorted_data[-1][0])))  # temporary: log sentence lengths
        batch_data = list()
        for i in range(num_batch):
            batch_data.append(self.pad_data(sorted_data[i * int(batch_size):(i + 1) * int(batch_size)]))
        return batch_data

    @staticmethod
    def pad_data(data):
        strings = []
        chars = []
        targets = []
        max_length = max([len(sentence[0]) for sentence in data])
        for line in data:
            string, char, target = line
            padding = [0] * (max_length - len(string))  # right-pad to the batch maximum
            strings.append(string + padding)
            chars.append(char + padding)
            targets.append(target + padding)
        return [strings, chars, targets]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]
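
if __name__ == '__main__':
    # Minimal usage sketch; batch_size=64 is an illustrative choice, not a
    # value fixed anywhere in this module.
    data = get_label_data()
    manager = BatchManager(data, batch_size=64)
    for strings, char_ids, tag_ids in manager.iter_batch(shuffle=True):
        pass  # feed char_ids / tag_ids to the sequence-labelling model here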