data_util.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:19
import re
import os
import math
import json
import random
import numpy as np
import pandas as pd
from BiddingKG.dl.common.Utils import getVocabAndMatrix, getModel_word, viterbi_decode, load

# BIE tag scheme: S marks characters outside an entity; 'pro' = product name,
# 'rea' = failure reason.
tag2id = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
id_to_tag = {v: k for k, v in tag2id.items()}

# Load the vocabulary and pretrained embedding matrix shipped with the interface package.
path1 = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + "/interface/codename_vocab.pk"
path2 = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + "/interface/codename_w2v_matrix.pk"
vocab = load(path1)
matrix = load(path2)
max_id = len(vocab)
word2id = {k: v for v, k in enumerate(vocab)}  # token -> row index in the embedding matrix
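
# Illustrative example of the tag scheme (not from the original source): for the
# product "挖掘机" inside "采购挖掘机一批", the character-level tags are
# [S, S, B-pro, I-pro, E-pro, S, S], i.e. the id sequence [0, 0, 1, 2, 3, 0, 0].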
def df2data(df):
    """Convert a labelled DataFrame (columns: docid, text, label) into per-sentence training samples."""
    datas = []
    for idx in df.index:
        docid = df.loc[idx, 'docid']
        text = df.loc[idx, 'text']
        tags = [0] * len(text)
        labels = json.loads(df.loc[idx, 'label'])
        for label in labels:
            # brat annotation line: "<id> <type> <begin> <end> <surface>"
            _, _, begin, end, _ = re.split(r'\s', label)
            begin = int(begin)
            end = int(end)
            if end - begin >= 2:
                tags[begin] = 1
                tags[end - 1] = 3
                for i in range(begin + 1, end - 1):
                    tags[i] = 2
        text_sentence = []
        ids_sentence = []
        tag_sentence = []
        for i in range(len(text)):
            text_sentence.append(text[i])
            ids_sentence.append(word2id.get(text[i], word2id.get('<unk>')))
            tag_sentence.append(tags[i])
            if text[i] in ['。', '!']:
                if text_sentence:
                    if 5 < len(text_sentence) < 1000:
                        datas.append([text_sentence, ids_sentence, tag_sentence])
                    else:
                        print('Sentence length outside (5, 1000): %d, doc ID: %s' % (len(text_sentence), docid))
                text_sentence = []
                ids_sentence = []
                tag_sentence = []
        if text_sentence:
            if 5 < len(text_sentence) < 1000:
                datas.append([text_sentence, ids_sentence, tag_sentence])
            else:
                print('Sentence length outside (5, 1000): %d, doc ID: %s' % (len(text_sentence), docid))
    return datas
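
# Usage sketch (illustrative; the 'label' strings follow the brat
# "<id> <type> <begin> <end> <surface>" convention parsed above):
#   df = pd.DataFrame([{'docid': 'd1', 'text': '采购挖掘机一批。',
#                       'label': json.dumps(['T1 product 2 5 挖掘机'])}])
#   df2data(df)
#   # -> [[['采','购','挖','掘','机','一','批','。'], [...ids...], [0, 0, 1, 2, 3, 0, 0, 0]]]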
def find_kw_from_text(kw, s):
    '''
    Given a keyword and a sentence, return every position where the keyword occurs.
    :param kw: keyword
    :param s: text to search
    :return: list of (begin, end) index pairs
    '''
    begin = s.find(kw, 0)
    kws = []
    while begin != -1:
        end = begin + len(kw)
        kws.append((begin, end))
        begin = s.find(kw, end)
    return kws
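
# Usage sketch: matches are non-overlapping, left to right.
#   find_kw_from_text('公告', '招标公告。变更公告。')  # -> [(2, 4), (7, 9)]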
def get_feature(text, lbs):
    '''
    Given the preprocessed text of a document and a list of product names, return the
    sentence list, the digitized sentence list, and the digitized tag list.
    :param text: document text
    :param lbs: list of product names
    :return: (sentences, ids_list, tags_list)
    '''
    # Match longer product names first so shorter ones cannot overwrite their tags.
    lbs = sorted(set(lbs), key=lambda x: len(x), reverse=True)
    sentences = []
    ids_list = []
    tags_list = []
    for sentence in text.split('。'):
        if len(sentence) < 5:
            continue
        if len(sentence) > 1000:
            sentence = sentence[:1000]
        tags = [0] * len(sentence)
        ids = [word2id.get(word, word2id.get('<unk>')) for word in sentence]
        for lb in lbs:
            kw_indexs = find_kw_from_text(lb, sentence)
            for indexs in kw_indexs:
                b, e = indexs
                # Only tag a span whose boundary characters are still untagged.
                if tags[b] == 0 and tags[e - 1] == 0:
                    tags[b] = 1
                    tags[e - 1] = 3
                    for i in range(b + 1, e - 1):
                        tags[i] = 2
        sentences.append(list(sentence))
        ids_list.append(ids)
        tags_list.append(tags)
    return sentences, ids_list, tags_list
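
# Usage sketch (illustrative):
#   sentences, ids_list, tags_list = get_feature('采购挖掘机一批及配件若干', ['挖掘机'])
#   # tags_list -> [[0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0]]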
def dfsearchlb(df):
    datas = []
    for i in df.index:
        text = df.loc[i, 'text']
        lbs = json.loads(df.loc[i, 'lbset'])
        sentences, ids_list, tags_list = get_feature(text, lbs)
        for sen, ids, tags in zip(sentences, ids_list, tags_list):
            datas.append([sen, ids, tags])
    return datas
def get_label_data():
    import psycopg2
    conn = psycopg2.connect(dbname='iepy_product', user='postgres', password='postgres', host='192.168.2.101')
    cursor = conn.cursor()
    sql = "select human_identifier, text from corpus_iedocument where edittime NOTNULL AND jump_signal=0 " \
          "and creation_date > to_timestamp('2021-01-14 00:00:00','yyyy-MM-dd HH24:mi:ss');"
    cursor.execute(sql)
    writer = open('label_data.txt', 'w', encoding='utf-8')
    datas = []
    for row in cursor.fetchall():
        docid = row[0]
        text = row[1]
        tags = [0] * len(text)
        sql_lb = "select b.value from brat_bratannotation as b where document_id = '{}' and b.value like 'T%product%';".format(docid)
        cursor.execute(sql_lb)
        for row_lb in cursor.fetchall():
            label = row_lb[0]
            _, _, begin, end, _ = re.split(r'\s', label)
            begin = int(begin)
            end = int(end)
            if end - begin >= 2:
                tags[begin] = 1
                tags[end - 1] = 3
                for i in range(begin + 1, end - 1):
                    tags[i] = 2
        text_sentence = []
        ids_sentence = []
        tag_sentence = []
        for i in range(len(text)):
            text_sentence.append(text[i])
            ids_sentence.append(word2id.get(text[i], word2id.get('<unk>')))
            tag_sentence.append(tags[i])
            writer.write("%s\t%s\n" % (text[i], tags[i]))
            if text[i] in ['。', '?', '!', ';']:
                writer.write('\n')
                if text_sentence:
                    if len(text_sentence) > 100:
                        datas.append([text_sentence, ids_sentence, tag_sentence])
                    elif len(text_sentence) > 5:
                        # `continue` skips the buffer reset below, so sentences of
                        # length 6..100 keep accumulating until the buffer exceeds 100.
                        continue
                    else:
                        print('Sentence length <= 5: %d, doc ID: %s' % (len(text_sentence), docid))
                text_sentence = []
                ids_sentence = []
                tag_sentence = []
        if text_sentence:
            if len(text_sentence) > 5:
                datas.append([text_sentence, ids_sentence, tag_sentence])
            else:
                print('Sentence length <= 5: %d, doc ID: %s' % (len(text_sentence), docid))
    writer.close()
    return datas
def input_from_line(line):
    string = list(line)
    ids = [word2id.get(k, word2id.get('<unk>')) for k in string]
    tags = []
    return [[string], [ids], [tags]]
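
# Usage sketch: wraps a single line in the [strings, ids, tags] batch shape.
#   input_from_line('招标公告')  # -> [[['招','标','公','告']], [[...ids...]], [[]]]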
def process_data(sentences):
    '''
    Digitize character strings and pad them to a uniform length.
    :param sentences: list of sentence strings, e.g. ['招标公告', '招标代理']
    :return: digitized id sequences, all padded to the length of the longest sentence
    '''
    maxLen = max([len(sentence) for sentence in sentences])
    ids = [[word2id.get(k, word2id.get('<unk>')) for k in sentence] for sentence in sentences]
    pad_ids = [id_[:maxLen] + [0] * (maxLen - len(id_)) for id_ in ids]
    return pad_ids
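
# Usage sketch: shorter sentences are right-padded with 0 to the longest length.
#   process_data(['招标公告', '招标'])  # -> [[id1, id2, id3, id4], [id1, id2, 0, 0]]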
def get_ner(BIE_tag):
    # Extract the (start, end) span of every maximal "B I* E" run in a tag string.
    ner = set()
    for it in re.finditer('BI*E', BIE_tag):
        ner.add((it.start(), it.end()))
    return ner
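
# Usage sketch: spans are (start, end) offsets into the tag string.
#   get_ner('SBIESBE')  # -> {(1, 4), (5, 7)}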
def decode(logits, lengths, matrix):
    paths = []
    small = -1000.0
    # start = np.asarray([[small]*4 + [0]])  # products only
    start = np.asarray([[small] * 7 + [0]])  # products and failure reasons
    for score, length in zip(logits, lengths):
        score = score[:length]
        # Append an extra column for the artificial start tag, then prepend the
        # start-tag row, so the score matrix becomes [length+1, num_tags+1].
        pad = small * np.ones([length, 1])
        padded = np.concatenate([score, pad], axis=1)
        padded = np.concatenate([start, padded], axis=0)
        path, _ = viterbi_decode(padded, matrix)
        paths.append(path[1:])
    return paths
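
# Usage sketch (illustrative; assumes viterbi_decode(score, transitions) follows the
# usual CRF signature with score [seq_len, num_tags] and transitions
# [num_tags, num_tags]). With 7 real tags plus the artificial start tag, the
# transition matrix here must be 8 x 8:
#   batch_logits = np.random.rand(2, 6, 7)      # [batch, seq_len, num_tags]
#   trans = np.random.rand(8, 8)
#   paths = decode(batch_logits, [6, 4], trans)  # two tag paths, lengths 6 and 4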
def result_to_json(line, tags):
    result = []
    ner = []
    tags = ''.join([str(it) for it in tags])
    # A product mention is a maximal "1 2* 3" run (B-pro I-pro* E-pro) in the tag string.
    for it in re.finditer("12*3", tags):
        start = it.start()
        end = it.end()
        ner.append([line[start:end], (start, end)])
    # for it in re.finditer("45*6", tags):  # failure reasons, currently disabled
    #     start = it.start()
    #     end = it.end()
    #     ner.append([line[start:end], (start, end)])
    result.append([line, ner])
    return result
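
# Usage sketch: tags use the numeric ids from tag2id (1/2/3 = B/I/E-pro).
#   result_to_json('采购挖掘机一批', [0, 0, 1, 2, 3, 0, 0])
#   # -> [['采购挖掘机一批', [['挖掘机', (2, 5)]]]]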
class BatchManager(object):
    def __init__(self, data, batch_size):
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)

    def sort_and_pad(self, data, batch_size):
        num_batch = int(math.ceil(len(data) / batch_size))
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        print('Shortest sentence: %d; longest sentence: %d' % (len(sorted_data[0][0]), len(sorted_data[-1][0])))  # temporary length diagnostics
        batch_data = list()
        for i in range(num_batch):
            batch_data.append(self.pad_data(sorted_data[i * int(batch_size):(i + 1) * int(batch_size)]))
        return batch_data

    @staticmethod
    def pad_data(data):
        strings = []
        chars = []
        targets = []
        max_length = max([len(sentence[0]) for sentence in data])
        for line in data:
            string, char, target = line
            padding = [0] * (max_length - len(string))
            strings.append(string + padding)
            chars.append(char + padding)
            targets.append(target + padding)
        return [strings, chars, targets]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]
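
# Usage sketch (illustrative): samples are sorted by sentence length and padded per batch.
#   data = [[list('招标公告'), [1, 2, 3, 4], [0, 0, 0, 0]],
#           [list('招标'), [1, 2], [0, 0]]]
#   bm = BatchManager(data, batch_size=2)
#   for strings, chars, targets in bm.iter_batch(shuffle=False):
#       print(len(strings[0]), len(strings[1]))  # both padded to length 4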
def 获取原始标注数据():  # fetch the raw human-labelled data and dump it to Excel
    import psycopg2
    conn = psycopg2.connect(dbname='iepy_product', user='postgres', password='postgres', host='192.168.2.103')
    cursor = conn.cursor()
    sql = "select human_identifier, text from corpus_iedocument where edittime NOTNULL AND jump_signal=0 ;"
    cursor.execute(sql)
    datas = []
    for row in cursor.fetchall():
        docid = row[0]
        text = row[1]
        sql_lb = "select b.value from brat_bratannotation as b where document_id = '{}' and b.value like 'T%product%';".format(docid)
        cursor.execute(sql_lb)
        rows = cursor.fetchall()
        print('len(rows)', len(rows))
        datas.append((docid, text, json.dumps(rows, ensure_ascii=False), len(rows)))
    df = pd.DataFrame(datas, columns=['docid', 'text', 'rows', 'product_num'])
    df.to_excel('data/产品数据自己人标注的原始数据.xlsx')
if __name__ == "__main__":
    # import os
    import pickle
    # with open('data/dev_data2.pkl', 'rb') as f:
    #     dev_data = pickle.load(f)
    # print(len(dev_data))
    # print(os.path.exists('data/testdata.xlsx'))
    # df = pd.read_excel('data/testdata.xlsx')
    # print(len(df))
    # data_test = df2data(df)
    # print(len(data_test), len(data_test[0][0]))
    # 获取原始标注数据()
    df = pd.read_excel('data/产品数据自己人标注的原始数据.xlsx')
    with open('data/dev_data2.pkl', 'rb') as f:
        dev_data = pickle.load(f)
    print(len(set(df['docid'])))
    print('')