#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:19
import re
import math
import random
import numpy as np
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_word, viterbi_decode

# BIE tagging scheme for product-name spans.
id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
word_model = getModel_word()
vocab, matrix = getVocabAndMatrix(word_model, Embedding_size=60)
word2id = {k: v for v, k in enumerate(vocab)}
max_id = len(vocab)  # id assigned to out-of-vocabulary characters

def get_label_data():
    """Read labeled documents from the iepy database and convert brat
    'product' annotations into character-level BIE tag sequences."""
    import psycopg2
    conn = psycopg2.connect(dbname='iepy_product', user='postgres',
                            password='postgres', host='192.168.2.101')
    cursor = conn.cursor()
    sql = "select human_identifier, text from corpus_iedocument where edittime NOTNULL AND jump_signal=0 \
          and creation_date > to_timestamp('2021-01-14 00:00:00','yyyy-MM-dd HH24:mi:ss');"
    cursor.execute(sql)
    writer = open('label_data.txt', 'w', encoding='utf-8')
    datas = []
    for row in cursor.fetchall():
        docid = row[0]
        text = row[1]
        tags = [0] * len(text)
        sql_lb = "select b.value from brat_bratannotation as b where document_id = '{}' and b.value like 'T%product%';".format(docid)
        cursor.execute(sql_lb)
        for row_lb in cursor.fetchall():
            label = row_lb[0]
            # maxsplit=4 keeps entity text that itself contains whitespace in one piece
            _, _, begin, end, _ = re.split(r'\s', label, maxsplit=4)
            begin = int(begin)
            end = int(end)
            if end - begin >= 2:
                tags[begin] = 1
                tags[end - 1] = 3
                for i in range(begin + 1, end - 1):
                    tags[i] = 2
        text_sentence = []
        ids_sentence = []
        tag_sentence = []
        for i in range(len(text)):
            text_sentence.append(text[i])
            ids_sentence.append(word2id.get(text[i], max_id))
            tag_sentence.append(tags[i])
            writer.write("%s\t%s\n" % (text[i], tags[i]))
            if text[i] in ['。', '?', '!', ';']:
                writer.write('\n')
                if text_sentence:
                    if len(text_sentence) > 100:
                        datas.append([text_sentence, ids_sentence, tag_sentence])
                    elif len(text_sentence) > 5:
                        continue  # 6-100 chars: skip the reset below and merge into the next sentence
                    else:
                        print('Sentence length below 5 or above 100; length: %d, document ID: %s' % (len(text_sentence), docid))
                    text_sentence = []
                    ids_sentence = []
                    tag_sentence = []
        if text_sentence:  # tail of the document with no terminal punctuation
            if len(text_sentence) > 5:
                datas.append([text_sentence, ids_sentence, tag_sentence])
            else:
                print('Sentence length below 5 or above 100; length: %d, document ID: %s' % (len(text_sentence), docid))
    writer.close()
    return datas

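# Usage sketch (assumptions): the brat annotation value parsed above is expected
# to follow the standoff format 'T1<TAB>product 17 21<TAB>entity text', i.e.
# tag id, type, begin offset, end offset, surface text; running this requires
# access to the PostgreSQL host configured above.
# datas = get_label_data()
# chars, char_ids, tag_ids = datas[0]
# assert len(chars) == len(char_ids) == len(tag_ids)
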
def input_from_line(line):
    string = list(line)
    ids = [word2id.get(k, max_id) for k in string]
    tags = []
    return [[string], [ids], [tags]]

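# Example: wrap a single raw string as a one-element batch for prediction;
# the tag list stays empty at inference time.
# >>> strings, ids, tags = input_from_line('招标公告')
# >>> strings
# [['招', '标', '公', '告']]
# >>> tags
# [[]]
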
def process_data(sentences):
    '''
    Convert sentence strings to id sequences and pad them to a uniform length.
    :param sentences: list of sentence strings, e.g. ['招标公告', '招标代理']
    :return: padded id sequences of equal length
    '''
    maxLen = max([len(sentence) for sentence in sentences])
    ids = [[word2id.get(k, max_id) for k in sentence] for sentence in sentences]
    pad_ids = [seq[:maxLen] + [0] * (maxLen - len(seq)) for seq in ids]
    return pad_ids

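# Example: every sequence is padded (with id 0) to the longest length in the batch.
# >>> padded = process_data(['招标公告', '招标'])
# >>> [len(seq) for seq in padded]
# [4, 4]
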
def get_ner(BIE_tag):
    ner = set()
    for it in re.finditer('BI*E', BIE_tag):
        ner.add((it.start(), it.end()))
    return ner

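# Example: extract (start, end) entity spans from a BIE tag string;
# 'BI*E' also matches the minimal two-character entity 'BE'.
# >>> get_ner('OBIEOBE')
# {(1, 4), (5, 7)}
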
def decode(logits, lengths, matrix):
    paths = []
    small = -1000.0
    # Virtual start step: only the extra 5th (start/pad) state is allowed.
    start = np.asarray([[small] * 4 + [0]])
    for score, length in zip(logits, lengths):
        score = score[:length]
        pad = small * np.ones([length, 1])
        padded = np.concatenate([score, pad], axis=1)  # append the start/pad column
        padded = np.concatenate([start, padded], axis=0)
        path, _ = viterbi_decode(padded, matrix)
        paths.append(path[1:])  # drop the virtual start step
    return paths

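# Usage sketch (assumptions): `trans` below stands for a trained (5, 5) CRF
# transition matrix over the 4 tags plus the start/pad state added above (not
# the embedding `matrix` defined at module level), and viterbi_decode returns
# a (path, score) pair, as it is used here.
# batch_logits = np.random.rand(2, 50, 4)        # [batch, steps, num_tags]
# paths = decode(batch_logits, [50, 32], trans)
# assert len(paths[1]) == 32                     # one tag id per real step
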
def result_to_json(line, tags):
    # Note: despite the name, this returns nested lists rather than a JSON string.
    result = []
    ner = []
    tags = ''.join([str(it) for it in tags])
    # '1', '2', '3' are the B, I, E tag ids, so '12*3' finds complete entity spans.
    for it in re.finditer("12*3", tags):
        start = it.start()
        end = it.end()
        ner.append([line[start:end], (start, end)])
    result.append([line, ner])
    print(tags)  # debug output of the tag id string
    return result

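# Example: tag ids [0, 0, 1, 3] join into the string '0013'; the single
# '12*3' match covers characters 2-3.
# >>> result_to_json('招标公告', [0, 0, 1, 3])
# [['招标公告', [['公告', (2, 4)]]]]
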
class BatchManager(object):
    def __init__(self, data, batch_size):
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)

    def sort_and_pad(self, data, batch_size):
        num_batch = int(math.ceil(len(data) / batch_size))
        # Sorting by length keeps the padding needed per batch to a minimum.
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        print('Min sentence length: %d; max sentence length: %d' % (len(sorted_data[0][0]), len(sorted_data[-1][0])))  # temporary logging of sentence lengths
        batch_data = list()
        for i in range(num_batch):
            batch_data.append(self.pad_data(sorted_data[i * int(batch_size):(i + 1) * int(batch_size)]))
        return batch_data

    @staticmethod
    def pad_data(data):
        strings = []
        chars = []
        targets = []
        max_length = max([len(sentence[0]) for sentence in data])
        for line in data:
            string, char, target = line
            padding = [0] * (max_length - len(string))
            strings.append(string + padding)
            chars.append(char + padding)
            targets.append(target + padding)
        return [strings, chars, targets]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]

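# Usage sketch (assumption: `datas` is the list of [chars, char_ids, tag_ids]
# triples produced by get_label_data()):
# manager = BatchManager(datas, batch_size=32)
# for strings, char_ids, tag_ids in manager.iter_batch(shuffle=True):
#     pass  # each item holds up to 32 sequences, padded to equal length per batch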