#! -*- coding:utf-8 -*-
import os, sys
parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parentdir)
import json
import numpy as np
from random import choice
from tqdm import tqdm
from itertools import groupby
from BiddingKG.dl.common.models import save, load
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.callbacks import Callback
from keras.optimizers import Adam


def seq_padding(X, padding=0):
    L = [len(x) for x in X]
    ML = max(L)
    return np.array([
        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
    ])
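
# A small worked example of seq_padding's behavior (illustrative values):
#     seq_padding([[1, 2], [3]])  ->  array([[1, 2], [3, 0]])
# every sequence is right-padded with `padding` up to the batch maximum length.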


def seq_gather(x):
    """seq has shape [None, seq_len, s_size] and idxs has shape [None, 1];
    pick the idxs[i]-th vector out of the i-th sequence in seq,
    returning a [None, s_size] tensor.
    """
    seq, idxs = x
    idxs = K.cast(idxs, 'int32')
    batch_idxs = K.arange(0, K.shape(seq)[0])
    batch_idxs = K.expand_dims(batch_idxs, 1)
    idxs = K.concatenate([batch_idxs, idxs], 1)
    return K.tf.gather_nd(seq, idxs)
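
# Shape sketch for seq_gather: with seq of shape (2, 5, 8) and idxs [[3], [0]],
# the result has shape (2, 8) -- row 3 of the first sequence and row 0 of the second.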


def seq_maxpool(x):
    """seq has shape [None, seq_len, s_size] and mask has shape
    [None, seq_len, 1]; masked positions are pushed down to -1e10 so
    padding never wins the max, then max-pooling is applied over time.
    """
    seq, mask = x
    seq -= (1 - mask) * 1e10
    return K.max(seq, 1, keepdims=True)


def dilated_gated_conv1d(seq, mask, dilation_rate=1):
    """Dilated gated convolution (residual-style)."""
    dim = K.int_shape(seq)[-1]
    h = Conv1D(dim * 2, 3, padding='same', dilation_rate=dilation_rate)(seq)

    def _gate(x):
        dropout_rate = 0.2
        s, h = x
        g, h = h[:, :, :dim], h[:, :, dim:]
        g = K.in_train_phase(K.dropout(g, dropout_rate), g)
        g = K.sigmoid(g)
        return g * s + (1 - g) * h

    seq = Lambda(_gate)([seq, h])
    seq = Lambda(lambda x: x[0] * x[1])([seq, mask])
    return seq
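
# The gate mixes the block input s with the convolution branch h:
#     output = g * s + (1 - g) * h,  with g = sigmoid(gate half of the conv)
# so a gate saturated at 1 reduces the block to an identity (residual) path;
# the final multiply re-applies the padding mask.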


class Attention(Layer):
    """Multi-head attention."""

    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.out_dim = nb_head * size_per_head
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        super(Attention, self).build(input_shape)
        q_in_dim = input_shape[0][-1]
        k_in_dim = input_shape[1][-1]
        v_in_dim = input_shape[2][-1]
        self.q_kernel = self.add_weight(name='q_kernel',
                                        shape=(q_in_dim, self.out_dim),
                                        initializer='glorot_normal')
        self.k_kernel = self.add_weight(name='k_kernel',
                                        shape=(k_in_dim, self.out_dim),
                                        initializer='glorot_normal')
        self.v_kernel = self.add_weight(name='v_kernel',
                                        shape=(v_in_dim, self.out_dim),
                                        initializer='glorot_normal')

    def mask(self, x, mask, mode='mul'):
        if mask is None:
            return x
        else:
            for _ in range(K.ndim(x) - K.ndim(mask)):
                mask = K.expand_dims(mask, K.ndim(mask))
            if mode == 'mul':
                return x * mask
            else:
                return x - (1 - mask) * 1e10

    def call(self, inputs):
        q, k, v = inputs[:3]
        v_mask, q_mask = None, None
        if len(inputs) > 3:
            v_mask = inputs[3]
            if len(inputs) > 4:
                q_mask = inputs[4]
        # linear projections
        qw = K.dot(q, self.q_kernel)
        kw = K.dot(k, self.k_kernel)
        vw = K.dot(v, self.v_kernel)
        # split into heads
        qw = K.reshape(qw, (-1, K.shape(qw)[1], self.nb_head, self.size_per_head))
        kw = K.reshape(kw, (-1, K.shape(kw)[1], self.nb_head, self.size_per_head))
        vw = K.reshape(vw, (-1, K.shape(vw)[1], self.nb_head, self.size_per_head))
        # move the head axis forward
        qw = K.permute_dimensions(qw, (0, 2, 1, 3))
        kw = K.permute_dimensions(kw, (0, 2, 1, 3))
        vw = K.permute_dimensions(vw, (0, 2, 1, 3))
        # scaled dot-product attention, with padded keys masked out
        a = K.batch_dot(qw, kw, [3, 3]) / self.size_per_head ** 0.5
        a = K.permute_dimensions(a, (0, 3, 2, 1))
        a = self.mask(a, v_mask, 'add')
        a = K.permute_dimensions(a, (0, 3, 2, 1))
        a = K.softmax(a)
        # weighted sum of values, then merge the heads back together
        o = K.batch_dot(a, vw, [3, 2])
        o = K.permute_dimensions(o, (0, 2, 1, 3))
        o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
        o = self.mask(o, q_mask, 'mul')
        return o

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.out_dim)
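
# Usage note: Attention(nb_head, size_per_head) is called as
#     Attention(8, 16)([q, k, v, v_mask])
# and returns a tensor of shape (batch, len_q, nb_head * size_per_head),
# here (batch, seq_len, 128); v_mask pushes attention logits on padded
# key positions down to -1e10 before the softmax.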


def position_id(x):
    if isinstance(x, list) and len(x) == 2:
        x, r = x
    else:
        r = 0
    pid = K.arange(K.shape(x)[1])
    pid = K.expand_dims(pid, 0)
    pid = K.tile(pid, [K.shape(x)[0], 1])
    return K.abs(pid - K.cast(r, 'int32'))
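
# position_id(x) yields absolute positions 0..seq_len-1 for every batch row;
# position_id([x, r]) yields |position - r|, i.e. each token's distance to the
# marked index r (used in get_model below to encode distance to the subject).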


# placeholder tokens that stand in for entity mentions in the model input
entity_type_dict = {
    'org': '<company/org>',
    'company': '<company/org>',
    'location': '<location>',
    'phone': '<phone>',
    'person': '<contact_person>'
}


class Relation_extraction():
    def __init__(self, is_train=False):
        self.is_train = is_train
        self.words_vocab = load(os.path.dirname(__file__) + '/../relation_extraction/words_vocab.pkl')
        id2word = {i: j for i, j in enumerate(self.words_vocab)}
        self.words2id = {j: i for i, j in id2word.items()}
        self.words_size = 128
        self.id2predicate = {
            0: "rel_person",   # company - contact person
            1: "rel_phone",    # contact person - phone
            2: "rel_address"   # company - address
        }
        self.predicate2id = {j: i for i, j in self.id2predicate.items()}
        self.num_classes = len(self.id2predicate)
        self.maxlen = 512
        self.word2vec = None
        if self.is_train:
            self.word2vec = load('words2v_matrix.pkl')
        self.model_path = os.path.dirname(__file__) + '/../relation_extraction/models/my_best_model_oneoutput.weights'
        self.get_model()
        if self.model_path:
            self.train_model.load_weights(self.model_path)

    def get_model(self):
        words_size = self.words_size
        t2_in = Input(shape=(None,))
        s1_in = Input(shape=(None,))
        k1_in = Input(shape=(1,))
        o1_in = Input(shape=(None, self.num_classes))
        t2, s1, k1, o1 = t2_in, s1_in, k1_in, o1_in
        mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(t2)
        pid = Lambda(position_id)(t2)
        position_embedding = Embedding(self.maxlen, words_size, embeddings_initializer='zeros')
        pv = position_embedding(pid)
        t2 = Embedding(len(self.words2id), words_size,
                       weights=[self.word2vec] if self.is_train else None,
                       trainable=True)(t2)
        t = Add()([t2, pv])
        t = Dropout(0.25)(t)
        t = Lambda(lambda x: x[0] * x[1])([t, mask])
        if K.tensorflow_backend._get_available_gpus():
            t = Bidirectional(CuDNNGRU(64, return_sequences=True))(t)
        else:
            t = Bidirectional(GRU(64, return_sequences=True, reset_after=True))(t)
        t_dim = K.int_shape(t)[-1]
        pn1 = Dense(words_size, activation='relu')(t)
        pn1 = Dense(1, activation='sigmoid')(pn1)
        h = Attention(8, 16)([t, t, t, mask])
        h = Concatenate()([t, h])
        h = Conv1D(words_size, 3, activation='relu', padding='same')(h)
        ps1 = Dense(1, activation='sigmoid')(h)
        ps1 = Lambda(lambda x: x[0] * x[1])([ps1, pn1])
        self.subject_model = Model([t2_in], [ps1])  # model that predicts subject positions

        t_max = Lambda(seq_maxpool)([t, mask])
        pc = Dense(words_size, activation='relu')(t_max)
        pc = Dense(self.num_classes, activation='sigmoid')(pc)

        def get_k_inter(x, n=6):
            seq, k1 = x
            # k_inter = [K.round(k1 * a + k2 * (1 - a)) for a in np.arange(n) / (n - 1.)]
            k_inter = [seq_gather([seq, k1])] * 2
            k_inter = [K.expand_dims(k, 1) for k in k_inter]
            k_inter = K.concatenate(k_inter, 1)
            return k_inter

        k = Lambda(get_k_inter, output_shape=(2, t_dim))([t, k1])
        if K.tensorflow_backend._get_available_gpus():
            k = Bidirectional(CuDNNGRU(t_dim))(k)
        else:
            k = Bidirectional(GRU(t_dim, reset_after=True))(k)
        k1v = position_embedding(Lambda(position_id)([t, k1]))
        kv = Concatenate()([k1v, k1v])
        k = Lambda(lambda x: K.expand_dims(x[0], 1) + x[1])([k, kv])
        h = Attention(8, 16)([t, t, t, mask])
        h = Concatenate()([t, h, k])
        h = Conv1D(words_size, 3, activation='relu', padding='same')(h)
        po = Dense(1, activation='sigmoid')(h)
        po1 = Dense(self.num_classes, activation='sigmoid')(h)
        po1 = Lambda(lambda x: x[0] * x[1] * x[2] * x[3])([po, po1, pc, pn1])
        self.object_model = Model([t2_in, k1_in], [po1])

        train_model = Model([t2_in, s1_in, k1_in, o1_in],
                            [ps1, po1])
        # masked binary cross-entropy losses for the subject and object heads
        s1 = K.expand_dims(s1, 2)
        s1_loss = K.binary_crossentropy(s1, ps1)
        s1_loss = K.sum(s1_loss * mask) / K.sum(mask)
        o1_loss = K.sum(K.binary_crossentropy(o1, po1), 2, keepdims=True)
        o1_loss = K.sum(o1_loss * mask) / K.sum(mask)
        loss = s1_loss + o1_loss
        train_model.add_loss(loss)
        train_model.compile(optimizer=Adam(1e-3))
        # train_model.summary()
        self.train_model = train_model
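
    # In short, get_model builds three views over one shared graph:
    #   subject_model: token ids -> per-token subject probability ps1
    #   object_model:  (token ids, subject index k1) -> per-token, per-relation
    #                  probabilities po1
    #   train_model:   takes all four inputs and is trained with the two
    #                  masked binary cross-entropy losses above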

    def extract_items(self, text_in, words, rate=0.5):
        R = []
        _t2 = [self.words2id.get(c, 1) for c in words]
        _t2 = np.array([_t2])
        _k1 = self.subject_model.predict([_t2])
        _k1 = _k1[0, :, 0]
        _k1 = np.where(_k1 > rate)[0]
        _subjects = []
        for i in _k1:
            _subject = text_in[i]
            _subjects.append((_subject, i, i))
        if _subjects:
            _t2 = np.repeat(_t2, len(_subjects), 0)
            _k1, _ = np.array([_s[1:] for _s in _subjects]).T.reshape((2, -1, 1))
            _o1 = self.object_model.predict([_t2, _k1])
            for i, _subject in enumerate(_subjects):
                _oo1 = np.where(_o1[i] > 0.5)
                for _ooo1, _c1 in zip(*_oo1):
                    _object = text_in[_ooo1]
                    _predicate = self.id2predicate[_c1]
                    R.append((_subject[0], _predicate, _object))
            return R
        else:
            return []
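
    # Decoding runs in two stages: subject_model proposes every position whose
    # score exceeds `rate` as a subject, then object_model is run once per
    # proposed subject; every (token, relation) cell above 0.5 yields one
    # (subject, predicate, object) triple, with tokens taken from text_in.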

    def predict(self, text, words):
        res = self.extract_items(text, words)
        return res

    @staticmethod
    def get_predata(entity_list, list_sentence):
        list_sentence = sorted(list_sentence, key=lambda x: x.sentence_index)
        entity_list = sorted(entity_list, key=lambda x: (x.sentence_index, x.begin_index))
        pre_data = []
        text_data = []
        last_sentence_index = -1
        for key, group in groupby(entity_list, key=lambda x: x.sentence_index):
            if key - last_sentence_index > 1:
                # copy over the entity-free sentences in between unchanged
                for i in range(last_sentence_index + 1, key):
                    pre_data.extend(list_sentence[i].tokens)
                    text_data.extend([0] * len(list_sentence[i].tokens))
            group = list(group)
            for i in range(len(group)):
                ent = group[i]
                _tokens = list_sentence[key].tokens
                if i == len(group) - 1:
                    # last entity of the sentence: emit its placeholder and the
                    # trailing tokens (plus the leading tokens if it is also first)
                    if i == 0:
                        pre_data.extend(_tokens[:ent.begin_index])
                        text_data.extend([0] * len(_tokens[:ent.begin_index]))
                    pre_data.append(entity_type_dict[ent.entity_type])
                    text_data.append(ent)
                    pre_data.extend(_tokens[ent.end_index + 1:])
                    text_data.extend([0] * len(_tokens[ent.end_index + 1:]))
                    break
                # otherwise emit the placeholder plus the tokens up to the next entity
                if i == 0:
                    pre_data.extend(_tokens[:ent.begin_index])
                    text_data.extend([0] * len(_tokens[:ent.begin_index]))
                pre_data.append(entity_type_dict[ent.entity_type])
                text_data.append(ent)
                pre_data.extend(_tokens[ent.end_index + 1:group[i + 1].begin_index])
                text_data.extend([0] * len(_tokens[ent.end_index + 1:group[i + 1].begin_index]))
            last_sentence_index = key
        return text_data, pre_data
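
    # get_predata returns two aligned sequences: pre_data is the token stream
    # with each entity mention replaced by its entity_type_dict placeholder,
    # and text_data holds 0 for plain tokens and the entity object itself at
    # each placeholder position, so predictions can be mapped back to entities.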


if __name__ == '__main__':
    test_model = Relation_extraction()
    text_in = "索引||号||:||014583788||/||2018-00038||,||成文||日期||:||2018-11-19||,||关于||国家税务总局都昌县税务局||办公楼||七||楼||会议室||维修||改造||项目||综合||比价||成交||公告||,||关于||国家税务总局都昌县税务局||办公楼七楼会议室||维修||改造||项目||(||比价||编号||:||JXXL2018-JJ-DC001||)||综合||比价||成交||公告||,||江西新立建设管理有限公司九江分公司||受||国家税务总局都昌县税务局||委托||,||就||其||办公楼||七||楼||会议室||维修||改造||项目||(||控制||价||:||294788.86||元||)||进行||综合||比价||方式||,||比价||活动||于||2018年||11月||16日||15:30||在||都昌县万里大道和平宾馆旁三楼||江西新立建设管理有限公司九江分公司||进行||,||经||比价||小组||评审||,||比价人||确定||,||现||将||比价||结果||公式||如下||:||序号||:||1||,||比价||编号||,||JXXL2018-JJ-DC001||,||项目||内容||名称||,||都昌县税务局||办公楼||七||楼||会议室||维修||改造||项目||,||数量||:||1||,||成交||供应商||名称||,||江西芙蓉建筑工程有限公司||,||成交价||(||元||)||,||284687.67||。||一||、||比价||小组||成员||:||杨忠辉||李燕杨瑾||,||本||公告||自||发布||之||日||起||1||个||工作日||内||若||无||异议||,||将||向||中标人||发出||《||成交||通知书||》||,||二||、||联系||方式||,||单位||:||国家税务总局都昌县税务局||,||比价||代理||机构||:||江西新立建设管理有限公司九江分公司||,||联系人||:||詹女士||,||电话||:||15979976088||,||江西新立建设管理有限公司九江分公司"
    words = "索引||号||:||014583788||/||2018-00038||,||成文||日期||:||2018-11-19||,||关于||国家税务总局都昌县税务局||" \
            "办公楼||七||楼||会议室||维修||改造||项目||综合||比价||成交||公告||,||关于||国家税务总局都昌县税务局||办公楼七楼会议室||" \
            "维修||改造||项目||(||比价||编号||:||JXXL2018-JJ-DC001||)||综合||比价||成交||公告||,||<company/org>||" \
            "受||国家税务总局都昌县税务局||委托||,||就||其||办公楼||七||楼||会议室||维修||改造||项目||(||控制||价||:||294788.86||元||)||" \
            "进行||综合||比价||方式||,||比价||活动||于||2018年||11月||16日||15:30||在||都昌县万里大道和平宾馆旁三楼||<company/org>||" \
            "进行||,||经||比价||小组||评审||,||比价人||确定||,||现||将||比价||结果||公式||如下||:||序号||:||1||,||比价||编号||," \
            "||JXXL2018-JJ-DC001||,||项目||内容||名称||,||都昌县税务局||办公楼||七||楼||会议室||维修||改造||项目||,||数量||:||1||,||成交||" \
            "供应商||名称||,||<company/org>||,||成交价||(||元||)||,||284687.67||。||一||、||比价||小组||成员||:||杨忠辉||李燕杨瑾||," \
            "||本||公告||自||发布||之||日||起||1||个||工作日||内||若||无||异议||,||将||向||中标人||发出||《||成交||通知书||》||,||二||、||联系||方式||," \
            "||单位||:||<company/org>||,||比价||代理||机构||:||<company/org>||,||联系人||:||<contact_person>||,||电话||:||<phone>||,||江西新立建设管理有限公司九江分公司"
    # quick single-token smoke test; uncomment to bypass the full example above
    # text_in = "索引"
    # words = "索引"
    res = test_model.predict(text_in.split("||"), words.split("||"))
    print(res)