Utils.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784
  1. '''
  2. Created on 2018年12月20日
  3. @author: User
  4. '''
  5. import numpy as np
  6. import re
  7. import gensim
  8. from keras import backend as K
  9. import os
  10. from threading import RLock
  11. # from pai_tf_predict_proto import tf_predict_pb2
  12. import requests
# Word-level word2vec model, created on first use by getModel_w2v() while
# holding lock_model_w2v.
model_w2v = None
lock_model_w2v = RLock()
# NOTE(review): not read anywhere in this file; presumably switches callers to
# the remote PAI-EAS endpoints defined below — confirm.
USE_PAI_EAS = False
# Flag reported by getLazyLoad(); nothing in this file changes it.
Lazy_load = False
  17. def getw2vfilepath():
  18. w2vfile = os.path.dirname(__file__)+"/../wiki_128_word_embedding_new.vector"
  19. if os.path.exists(w2vfile):
  20. return w2vfile
  21. return "wiki_128_word_embedding_new.vector"
  22. def getLazyLoad():
  23. global Lazy_load
  24. return Lazy_load
  25. model_word_file = os.path.dirname(__file__)+"/../singlew2v_model.vector"
  26. model_word = None
  27. lock_model_word = RLock()
  28. from decimal import Decimal
  29. import logging
# NOTE(review): basicConfig runs at import time, so importing this module may
# configure the process-wide root logger (INFO level, timestamped format) as a
# side effect.
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module logger used by log() and debug().
logger = logging.getLogger(__name__)
  32. import pickle
  33. import os
  34. import json
  35. #自定义jsonEncoder
  36. class MyEncoder(json.JSONEncoder):
  37. def __init__(self):
  38. import numpy as np
  39. global np
  40. def default(self, obj):
  41. if isinstance(obj, np.ndarray):
  42. return obj.tolist()
  43. elif isinstance(obj, bytes):
  44. return str(obj, encoding='utf-8')
  45. elif isinstance(obj, (np.float_, np.float16, np.float32,
  46. np.float64)):
  47. return float(obj)
  48. elif isinstance(obj,(np.int64,np.int32)):
  49. return int(obj)
  50. return json.JSONEncoder.default(self, obj)
# token -> index lookups, built on first use: vocab_word by getIndexOfWord()
# (from the character model), vocab_words by getIndexOfWords() (from the word model).
vocab_word = None
vocab_words = None
# Pickle caches for the two vocabulary lists. These are relative paths, so they
# resolve against the current working directory.
file_vocab_word = "vocab_word.pk"
file_vocab_words = "vocab_words.pk"
# Endpoints of models served on Alibaba Cloud PAI-EAS (cn-beijing VPC hosts).
# vpc_requests() sends each *_authorization value as the "Authorization" header.
# SECURITY NOTE(review): these tokens are hard-coded secrets committed to source
# control. They should come from configuration or the environment and be rotated.
selffool_authorization = "NjlhMWFjMjVmNWYyNzI0MjY1OGQ1M2Y0ZmY4ZGY0Mzg3Yjc2MTVjYg=="
selffool_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_gpu"
selffool_seg_authorization = "OWUwM2Q0ZmE3YjYxNzU4YzFiMjliNGVkMTA3MzJkNjQ2MzJiYzBhZg=="
selffool_seg_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_seg_gpu"
codename_authorization = "Y2M5MDUxMzU1MTU4OGM3ZDk2ZmEzYjkxYmYyYzJiZmUyYTgwYTg5NA=="
codename_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codename_gpu"
form_item_authorization = "ODdkZWY1YWY0NmNhNjU2OTI2NWY4YmUyM2ZlMDg1NTZjOWRkYTVjMw=="
form_item_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/form"
person_authorization = "N2I2MDU2N2Q2MGQ0ZWZlZGM3NDkyNTA1Nzc4YmM5OTlhY2MxZGU1Mw=="
person_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/person"
role_authorization = "OWM1ZDg5ZDEwYTEwYWI4OGNjYmRlMmQ1NzYwNWNlZGZkZmRmMjE4OQ=="
role_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/role"
money_authorization = "MDQyNjc2ZDczYjBhYmM4Yzc4ZGI4YjRmMjc3NGI5NTdlNzJiY2IwZA=="
money_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/money"
codeclasses_authorization = "MmUyNWIxZjQ2NjAzMWJlMGIzYzkxMjMzNWY5OWI3NzJlMWQ1ZjY4Yw=="
codeclasses_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codeclasses"
  71. def viterbi_decode(score, transition_params):
  72. """Decode the highest scoring sequence of tags outside of TensorFlow.
  73. This should only be used at test time.
  74. Args:
  75. score: A [seq_len, num_tags] matrix of unary potentials.
  76. transition_params: A [num_tags, num_tags] matrix of binary potentials.
  77. Returns:
  78. viterbi: A [seq_len] list of integers containing the highest scoring tag
  79. indices.
  80. viterbi_score: A float containing the score for the Viterbi sequence.
  81. """
  82. trellis = np.zeros_like(score)
  83. backpointers = np.zeros_like(score, dtype=np.int32)
  84. trellis[0] = score[0]
  85. for t in range(1, score.shape[0]):
  86. v = np.expand_dims(trellis[t - 1], 1) + transition_params
  87. trellis[t] = score[t] + np.max(v, 0)
  88. backpointers[t] = np.argmax(v, 0)
  89. viterbi = [np.argmax(trellis[-1])]
  90. for bp in reversed(backpointers[1:]):
  91. viterbi.append(bp[viterbi[-1]])
  92. viterbi.reverse()
  93. viterbi_score = np.max(trellis[-1])
  94. return viterbi, viterbi_score
  95. def limitRun(sess,list_output,feed_dict,MAX_BATCH=1024):
  96. len_sample = 0
  97. if len(feed_dict.keys())>0:
  98. len_sample = len(feed_dict[list(feed_dict.keys())[0]])
  99. if len_sample>MAX_BATCH:
  100. list_result = [[] for _ in range(len(list_output))]
  101. _begin = 0
  102. while(_begin<len_sample):
  103. new_dict = dict()
  104. for _key in feed_dict.keys():
  105. new_dict[_key] = feed_dict[_key][_begin:_begin+MAX_BATCH]
  106. _output = sess.run(list_output,feed_dict=new_dict)
  107. for _index in range(len(list_output)):
  108. list_result[_index].extend(_output[_index])
  109. _begin += MAX_BATCH
  110. else:
  111. list_result = sess.run(list_output,feed_dict=feed_dict)
  112. return list_result
  113. def get_values(response,output_name):
  114. """
  115. Get the value of a specified output tensor
  116. :param output_name: name of the output tensor
  117. :return: the content of the output tensor
  118. """
  119. output = response.outputs[output_name]
  120. if output.dtype == tf_predict_pb2.DT_FLOAT:
  121. _value = output.float_val
  122. elif output.dtype == tf_predict_pb2.DT_INT8 or output.dtype == tf_predict_pb2.DT_INT16 or \
  123. output.dtype == tf_predict_pb2.DT_INT32:
  124. _value = output.int_val
  125. elif output.dtype == tf_predict_pb2.DT_INT64:
  126. _value = output.int64_val
  127. elif output.dtype == tf_predict_pb2.DT_DOUBLE:
  128. _value = output.double_val
  129. elif output.dtype == tf_predict_pb2.DT_STRING:
  130. _value = output.string_val
  131. elif output.dtype == tf_predict_pb2.DT_BOOL:
  132. _value = output.bool_val
  133. return np.array(_value).reshape(response.outputs[output_name].array_shape.dim)
  134. def vpc_requests(url,authorization,request_data,list_outputs):
  135. headers = {"Authorization": authorization}
  136. dict_outputs = dict()
  137. response = tf_predict_pb2.PredictResponse()
  138. resp = requests.post(url, data=request_data, headers=headers)
  139. if resp.status_code != 200:
  140. print(resp.status_code,resp.content)
  141. log("调用pai-eas接口出错,authorization:"+str(authorization))
  142. return None
  143. else:
  144. response = tf_predict_pb2.PredictResponse()
  145. response.ParseFromString(resp.content)
  146. for _output in list_outputs:
  147. dict_outputs[_output] = get_values(response, _output)
  148. return dict_outputs
  149. def encodeInput(data,word_len,word_flag=True,userFool=False):
  150. result = []
  151. out_index = 0
  152. for item in data:
  153. if out_index in [0]:
  154. list_word = item[-word_len:]
  155. else:
  156. list_word = item[:word_len]
  157. temp = []
  158. if word_flag:
  159. for word in list_word:
  160. if userFool:
  161. temp.append(getIndexOfWord_fool(word))
  162. else:
  163. temp.append(getIndexOfWord(word))
  164. list_append = []
  165. temp_len = len(temp)
  166. while(temp_len<word_len):
  167. if userFool:
  168. list_append.append(0)
  169. else:
  170. list_append.append(getIndexOfWord("<pad>"))
  171. temp_len += 1
  172. if out_index in [0]:
  173. temp = list_append+temp
  174. else:
  175. temp = temp+list_append
  176. else:
  177. for words in list_word:
  178. temp.append(getIndexOfWords(words))
  179. list_append = []
  180. temp_len = len(temp)
  181. while(temp_len<word_len):
  182. list_append.append(getIndexOfWords("<pad>"))
  183. temp_len += 1
  184. if out_index in [0,1]:
  185. temp = list_append+temp
  186. else:
  187. temp = temp+list_append
  188. result.append(temp)
  189. out_index += 1
  190. return result
  191. def encodeInput_form(input,MAX_LEN=30):
  192. x = np.zeros([MAX_LEN])
  193. for i in range(len(input)):
  194. if i>=MAX_LEN:
  195. break
  196. x[i] = getIndexOfWord(input[i])
  197. return x
  198. def getVocabAndMatrix(model,Embedding_size = 60):
  199. '''
  200. @summary:获取子向量的词典和子向量矩阵
  201. '''
  202. vocab = ["<pad>"]+model.index2word
  203. embedding_matrix = np.zeros((len(vocab),Embedding_size))
  204. for i in range(1,len(vocab)):
  205. embedding_matrix[i] = model[vocab[i]]
  206. return vocab,embedding_matrix
  207. def getIndexOfWord(word):
  208. global vocab_word,file_vocab_word
  209. if vocab_word is None:
  210. if os.path.exists(file_vocab_word):
  211. vocab = load(file_vocab_word)
  212. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  213. else:
  214. model = getModel_word()
  215. vocab,_ = getVocabAndMatrix(model, Embedding_size=60)
  216. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  217. save(vocab,file_vocab_word)
  218. if word in vocab_word.keys():
  219. return vocab_word[word]
  220. else:
  221. return vocab_word['<pad>']
  222. def changeIndexFromWordToWords(tokens,word_index):
  223. '''
  224. @summary:转换某个字的字偏移为词偏移
  225. '''
  226. before_index = 0
  227. after_index = 0
  228. for i in range(len(tokens)):
  229. after_index = after_index+len(tokens[i])
  230. if before_index<=word_index and after_index>=word_index:
  231. return i
  232. before_index = after_index
  233. def getIndexOfWords(words):
  234. global vocab_words,file_vocab_words
  235. if vocab_words is None:
  236. if os.path.exists(file_vocab_words):
  237. vocab = load(file_vocab_words)
  238. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  239. else:
  240. model = getModel_w2v()
  241. vocab,_ = getVocabAndMatrix(model, Embedding_size=128)
  242. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  243. save(vocab,file_vocab_words)
  244. if words in vocab_words.keys():
  245. return vocab_words[words]
  246. else:
  247. return vocab_words["<pad>"]
  248. def log(msg):
  249. '''
  250. @summary:打印信息
  251. '''
  252. logger.info(msg)
  253. def debug(msg):
  254. '''
  255. @summary:打印信息
  256. '''
  257. logger.debug(msg)
  258. def save(object_to_save, path):
  259. '''
  260. 保存对象
  261. @Arugs:
  262. object_to_save: 需要保存的对象
  263. @Return:
  264. 保存的路径
  265. '''
  266. with open(path, 'wb') as f:
  267. pickle.dump(object_to_save, f)
  268. def load(path):
  269. '''
  270. 读取对象
  271. @Arugs:
  272. path: 读取的路径
  273. @Return:
  274. 读取的对象
  275. '''
  276. with open(path, 'rb') as f:
  277. object1 = pickle.load(f)
  278. return object1
  279. fool_char_to_id = load(os.path.dirname(__file__)+"/fool_char_to_id.pk")
  280. def getIndexOfWord_fool(word):
  281. if word in fool_char_to_id.keys():
  282. return fool_char_to_id[word]
  283. else:
  284. return fool_char_to_id["[UNK]"]
  285. def find_index(list_tofind,text):
  286. '''
  287. @summary: 查找所有词汇在字符串中第一次出现的位置
  288. @param:
  289. list_tofind:待查找词汇
  290. text:字符串
  291. @return: list,每个词汇第一次出现的位置
  292. '''
  293. result = []
  294. for item in list_tofind:
  295. index = text.find(item)
  296. if index>=0:
  297. result.append(index)
  298. else:
  299. result.append(-1)
  300. return result
  301. def combine(list1,list2):
  302. '''
  303. @summary:将两个list中的字符串两两拼接
  304. @param:
  305. list1:字符串list
  306. list2:字符串list
  307. @return:拼接结果list
  308. '''
  309. result = []
  310. for item1 in list1:
  311. for item2 in list2:
  312. result.append(str(item1)+str(item2))
  313. return result
  314. def getDigitsDic(unit):
  315. '''
  316. @summary:拿到中文对应的数字
  317. '''
  318. DigitsDic = {"零":0, "壹":1, "贰":2, "叁":3, "肆":4, "伍":5, "陆":6, "柒":7, "捌":8, "玖":9,
  319. "〇":0, "一":1, "二":2, "三":3, "四":4, "五":5, "六":6, "七":7, "八":8, "九":9}
  320. return DigitsDic.get(unit)
  321. def getMultipleFactor(unit):
  322. '''
  323. @summary:拿到单位对应的值
  324. '''
  325. MultipleFactor = {"兆":Decimal(1000000000000),"亿":Decimal(100000000),"万":Decimal(10000),"仟":Decimal(1000),"千":Decimal(1000),"佰":Decimal(100),"百":Decimal(100),"拾":Decimal(10),"十":Decimal(10),"元":Decimal(1),"圆":Decimal(1),"角":round(Decimal(0.1),1),"分":round(Decimal(0.01),2)}
  326. return MultipleFactor.get(unit)
  327. def getUnifyMoney(money):
  328. '''
  329. @summary:将中文金额字符串转换为数字金额
  330. @param:
  331. money:中文金额字符串
  332. @return: decimal,数据金额
  333. '''
  334. MAX_MONEY = 1000000000000
  335. MAX_NUM = 12
  336. #去掉逗号
  337. money = re.sub("[,,]","",money)
  338. money = re.sub("[^0-9.零壹贰叁肆伍陆柒捌玖拾佰仟萬億圆十百千万亿元角分]","",money)
  339. result = Decimal(0)
  340. chnDigits = ["零", "壹", "贰", "叁", "肆", "伍", "陆", "柒", "捌", "玖"]
  341. chnFactorUnits = ["兆", "亿", "万", "仟", "佰", "拾","圆","元","角","分"]
  342. LowMoneypattern = re.compile("^[\d,]+(\.\d+)?$")
  343. BigMoneypattern = re.compile("^零?(?P<BigMoney>[%s])$"%("".join(chnDigits)))
  344. try:
  345. if re.search(LowMoneypattern,money) is not None:
  346. return Decimal(money)
  347. elif re.search(BigMoneypattern,money) is not None:
  348. return getDigitsDic(re.search(BigMoneypattern,money).group("BigMoney"))
  349. for factorUnit in chnFactorUnits:
  350. if re.search(re.compile(".*%s.*"%(factorUnit)),money) is not None:
  351. subMoneys = re.split(re.compile("%s(?!.*%s.*)"%(factorUnit,factorUnit)),money)
  352. if re.search(re.compile("^(\d+)(\.\d+)?$"),subMoneys[0]) is not None:
  353. if MAX_MONEY/getMultipleFactor(factorUnit)<Decimal(subMoneys[0]):
  354. return Decimal(0)
  355. result += Decimal(subMoneys[0])*(getMultipleFactor(factorUnit))
  356. elif len(subMoneys[0])==1:
  357. if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[0]) is not None:
  358. result += Decimal(getDigitsDic(subMoneys[0]))*(getMultipleFactor(factorUnit))
  359. else:
  360. result += Decimal(getUnifyMoney(subMoneys[0]))*(getMultipleFactor(factorUnit))
  361. if len(subMoneys)>1:
  362. if re.search(re.compile("^(\d+(,)?)+(\.\d+)?[百千万亿]?\s?(元)?$"),subMoneys[1]) is not None:
  363. result += Decimal(subMoneys[1])
  364. elif len(subMoneys[1])==1:
  365. if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[1]) is not None:
  366. result += Decimal(getDigitsDic(subMoneys[1]))
  367. else:
  368. result += Decimal(getUnifyMoney(subMoneys[1]))
  369. break
  370. except Exception as e:
  371. return Decimal(0)
  372. return result
  373. def getModel_w2v():
  374. '''
  375. @summary:加载词向量
  376. '''
  377. global model_w2v,lock_model_w2v
  378. with lock_model_w2v:
  379. if model_w2v is None:
  380. model_w2v = gensim.models.KeyedVectors.load_word2vec_format(getw2vfilepath(),binary=True)
  381. return model_w2v
  382. def getModel_word():
  383. '''
  384. @summary:加载字向量
  385. '''
  386. global model_word,lock_model_w2v
  387. with lock_model_word:
  388. if model_word is None:
  389. model_word = gensim.models.KeyedVectors.load_word2vec_format(model_word_file,binary=True)
  390. return model_word
  391. # getModel_w2v()
  392. # getModel_word()
  393. def findAllIndex(substr,wholestr):
  394. '''
  395. @summary: 找到字符串的子串的所有begin_index
  396. @param:
  397. substr:子字符串
  398. wholestr:子串所在完整字符串
  399. @return: list,字符串的子串的所有begin_index
  400. '''
  401. copystr = wholestr
  402. result = []
  403. indexappend = 0
  404. while(True):
  405. index = copystr.find(substr)
  406. if index<0:
  407. break
  408. else:
  409. result.append(indexappend+index)
  410. indexappend += index+len(substr)
  411. copystr = copystr[index+len(substr):]
  412. return result
  413. def spanWindow(tokens,begin_index,end_index,size,center_include=False,word_flag = False,use_text = False,text = None):
  414. '''
  415. @summary:取得某个实体的上下文词汇
  416. @param:
  417. tokens:句子分词list
  418. begin_index:实体的开始index
  419. end_index:实体的结束index
  420. size:左右两边各取多少个词
  421. center_include:是否包含实体
  422. word_flag:词/字,默认是词
  423. @return: list,实体的上下文词汇
  424. '''
  425. if use_text:
  426. assert text is not None
  427. length_tokens = len(tokens)
  428. if begin_index>size:
  429. begin = begin_index-size
  430. else:
  431. begin = 0
  432. if end_index+size<length_tokens:
  433. end = end_index+size+1
  434. else:
  435. end = length_tokens
  436. result = []
  437. if not word_flag:
  438. result.append(tokens[begin:begin_index])
  439. if center_include:
  440. if use_text:
  441. result.append(text)
  442. else:
  443. result.append(tokens[begin_index:end_index+1])
  444. result.append(tokens[end_index+1:end])
  445. else:
  446. result.append("".join(tokens[begin:begin_index]))
  447. if center_include:
  448. if use_text:
  449. result.append(text)
  450. else:
  451. result.append("".join(tokens[begin_index:end_index+1]))
  452. result.append("".join(tokens[end_index+1:end]))
  453. #print(result)
  454. return result
  455. #根据规则补全编号或名称两边的符号
  456. def fitDataByRule(data):
  457. symbol_dict = {"(":")",
  458. "(":")",
  459. "[":"]",
  460. "【":"】",
  461. ")":"(",
  462. ")":"(",
  463. "]":"[",
  464. "】":"【"}
  465. leftSymbol_pattern = re.compile("[\((\[【]")
  466. rightSymbol_pattern = re.compile("[\))\]】]")
  467. leftfinds = re.findall(leftSymbol_pattern,data)
  468. rightfinds = re.findall(rightSymbol_pattern,data)
  469. result = data
  470. if len(leftfinds)+len(rightfinds)==0:
  471. return data
  472. elif len(leftfinds)==len(rightfinds):
  473. return data
  474. elif abs(len(leftfinds)-len(rightfinds))==1:
  475. if len(leftfinds)>len(rightfinds):
  476. if symbol_dict.get(data[0]) is not None:
  477. result = data[1:]
  478. else:
  479. #print(symbol_dict.get(leftfinds[0]))
  480. result = data+symbol_dict.get(leftfinds[0])
  481. else:
  482. if symbol_dict.get(data[-1]) is not None:
  483. result = data[:-1]
  484. else:
  485. result = symbol_dict.get(rightfinds[0])+data
  486. result = re.sub("[。]","",result)
  487. return result
  488. def embedding(datas,shape):
  489. '''
  490. @summary:查找词汇对应的词向量
  491. @param:
  492. datas:词汇的list
  493. shape:结果的shape
  494. @return: array,返回对应shape的词嵌入
  495. '''
  496. model_w2v = getModel_w2v()
  497. embed = np.zeros(shape)
  498. length = shape[1]
  499. out_index = 0
  500. #print(datas)
  501. for data in datas:
  502. index = 0
  503. for item in data:
  504. item_not_space = re.sub("\s*","",item)
  505. if index>=length:
  506. break
  507. if item_not_space in model_w2v.vocab:
  508. embed[out_index][index] = model_w2v[item_not_space]
  509. index += 1
  510. else:
  511. #embed[out_index][index] = model_w2v['unk']
  512. index += 1
  513. out_index += 1
  514. return embed
  515. def embedding_word(datas,shape):
  516. '''
  517. @summary:查找词汇对应的词向量
  518. @param:
  519. datas:词汇的list
  520. shape:结果的shape
  521. @return: array,返回对应shape的词嵌入
  522. '''
  523. model_w2v = getModel_word()
  524. embed = np.zeros(shape)
  525. length = shape[1]
  526. out_index = 0
  527. #print(datas)
  528. for data in datas:
  529. index = 0
  530. for item in str(data)[-shape[1]:]:
  531. if index>=length:
  532. break
  533. if item in model_w2v.vocab:
  534. embed[out_index][index] = model_w2v[item]
  535. index += 1
  536. else:
  537. # embed[out_index][index] = model_w2v['unk']
  538. index += 1
  539. out_index += 1
  540. return embed
  541. def formEncoding(text,shape=(100,60),expand=False):
  542. embedding = np.zeros(shape)
  543. word_model = getModel_word()
  544. for i in range(len(text)):
  545. if i>=shape[0]:
  546. break
  547. if text[i] in word_model.vocab:
  548. embedding[i] = word_model[text[i]]
  549. if expand:
  550. embedding = np.expand_dims(embedding,0)
  551. return embedding
  552. def partMoney(entity_text,input2_shape = [7]):
  553. '''
  554. @summary:对金额分段
  555. @param:
  556. entity_text:数值金额
  557. input2_shape:分类数
  558. @return: array,分段之后的独热编码
  559. '''
  560. money = float(entity_text)
  561. parts = np.zeros(input2_shape)
  562. if money<100:
  563. parts[0] = 1
  564. elif money<1000:
  565. parts[1] = 1
  566. elif money<10000:
  567. parts[2] = 1
  568. elif money<100000:
  569. parts[3] = 1
  570. elif money<1000000:
  571. parts[4] = 1
  572. elif money<10000000:
  573. parts[5] = 1
  574. else:
  575. parts[6] = 1
  576. return parts
  577. def recall(y_true, y_pred):
  578. '''
  579. 计算召回率
  580. @Argus:
  581. y_true: 正确的标签
  582. y_pred: 模型预测的标签
  583. @Return
  584. 召回率
  585. '''
  586. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  587. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  588. if c3 == 0:
  589. return 0
  590. recall = c1 / c3
  591. return recall
  592. def f1_score(y_true, y_pred):
  593. '''
  594. 计算F1
  595. @Argus:
  596. y_true: 正确的标签
  597. y_pred: 模型预测的标签
  598. @Return
  599. F1值
  600. '''
  601. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  602. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  603. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  604. precision = c1 / c2
  605. if c3 == 0:
  606. recall = 0
  607. else:
  608. recall = c1 / c3
  609. f1_score = 2 * (precision * recall) / (precision + recall)
  610. return f1_score
  611. def precision(y_true, y_pred):
  612. '''
  613. 计算精确率
  614. @Argus:
  615. y_true: 正确的标签
  616. y_pred: 模型预测的标签
  617. @Return
  618. 精确率
  619. '''
  620. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  621. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  622. precision = c1 / c2
  623. return precision
  624. # def print_metrics(history):
  625. # '''
  626. # 制作每次迭代的各metrics变化图片
  627. #
  628. # @Arugs:
  629. # history: 模型训练迭代的历史记录
  630. # '''
  631. # import matplotlib.pyplot as plt
  632. #
  633. # # loss图
  634. # loss = history.history['loss']
  635. # val_loss = history.history['val_loss']
  636. # epochs = range(1, len(loss) + 1)
  637. # plt.subplot(2, 2, 1)
  638. # plt.plot(epochs, loss, 'bo', label='Training loss')
  639. # plt.plot(epochs, val_loss, 'b', label='Validation loss')
  640. # plt.title('Training and validation loss')
  641. # plt.xlabel('Epochs')
  642. # plt.ylabel('Loss')
  643. # plt.legend()
  644. #
  645. # # f1图
  646. # f1 = history.history['f1_score']
  647. # val_f1 = history.history['val_f1_score']
  648. # plt.subplot(2, 2, 2)
  649. # plt.plot(epochs, f1, 'bo', label='Training f1')
  650. # plt.plot(epochs, val_f1, 'b', label='Validation f1')
  651. # plt.title('Training and validation f1')
  652. # plt.xlabel('Epochs')
  653. # plt.ylabel('F1')
  654. # plt.legend()
  655. #
  656. # # precision图
  657. # prec = history.history['precision']
  658. # val_prec = history.history['val_precision']
  659. # plt.subplot(2, 2, 3)
  660. # plt.plot(epochs, prec, 'bo', label='Training precision')
  661. # plt.plot(epochs, val_prec, 'b', label='Validation pecision')
  662. # plt.title('Training and validation precision')
  663. # plt.xlabel('Epochs')
  664. # plt.ylabel('Precision')
  665. # plt.legend()
  666. #
  667. # # recall图
  668. # recall = history.history['recall']
  669. # val_recall = history.history['val_recall']
  670. # plt.subplot(2, 2, 4)
  671. # plt.plot(epochs, recall, 'bo', label='Training recall')
  672. # plt.plot(epochs, val_recall, 'b', label='Validation recall')
  673. # plt.title('Training and validation recall')
  674. # plt.xlabel('Epochs')
  675. # plt.ylabel('Recall')
  676. # plt.legend()
  677. #
  678. # plt.show()
if __name__=="__main__":
    # Manual smoke check: print the id that fool_char_to_id assigns to ">".
    print(fool_char_to_id[">"])
    # Disabled helper: dump the word-level vocabulary and embedding matrix.
    # model = getModel_w2v()
    # vocab,matrix = getVocabAndMatrix(model, Embedding_size=128)
    # save([vocab,matrix],"vocabMatrix_words.pk")