'''
Created on 2018-12-20
@author: User
'''
import numpy as np
import re
import gensim
from keras import backend as K
import os,sys
import time
from threading import RLock
# from pai_tf_predict_proto import tf_predict_pb2
import requests

model_w2v = None
lock_model_w2v = RLock()
USE_PAI_EAS = False
Lazy_load = False
# API_URL = "http://192.168.2.103:8802"
API_URL = "http://127.0.0.1:888"
# USE_API = True
USE_API = False

def getCurrent_date(format="%Y-%m-%d %H:%M:%S"):
    _time = time.strftime(format,time.localtime())
    return _time

def getw2vfilepath():
    filename = "wiki_128_word_embedding_new.vector"
    w2vfile = getFileFromSysPath(filename)
    if w2vfile is not None:
        return w2vfile
    return filename

def getLazyLoad():
    global Lazy_load
    return Lazy_load

def getFileFromSysPath(filename):
    for _path in sys.path:
        if os.path.isdir(_path):
            for _file in os.listdir(_path):
                _abspath = os.path.join(_path,_file)
                if os.path.isfile(_abspath):
                    if _file==filename:
                        return _abspath
    return None

model_word_file = os.path.dirname(__file__)+"/../singlew2v_model.vector"
model_word = None
lock_model_word = RLock()

from decimal import Decimal
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
import pickle
import json

# Custom JSONEncoder that can serialize numpy arrays, numpy scalars and bytes.
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        elif isinstance(obj, (np.float_, np.float16, np.float32,
                              np.float64)):
            return float(obj)
        elif isinstance(obj, (np.int64, np.int32)):
            return int(obj)
        return json.JSONEncoder.default(self, obj)

vocab_word = None
vocab_words = None

file_vocab_word = "vocab_word.pk"
file_vocab_words = "vocab_words.pk"

selffool_authorization = "NjlhMWFjMjVmNWYyNzI0MjY1OGQ1M2Y0ZmY4ZGY0Mzg3Yjc2MTVjYg=="
selffool_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_gpu"
selffool_seg_authorization = "OWUwM2Q0ZmE3YjYxNzU4YzFiMjliNGVkMTA3MzJkNjQ2MzJiYzBhZg=="
selffool_seg_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_seg_gpu"
codename_authorization = "Y2M5MDUxMzU1MTU4OGM3ZDk2ZmEzYjkxYmYyYzJiZmUyYTgwYTg5NA=="
codename_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codename_gpu"
form_item_authorization = "ODdkZWY1YWY0NmNhNjU2OTI2NWY4YmUyM2ZlMDg1NTZjOWRkYTVjMw=="
form_item_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/form"
person_authorization = "N2I2MDU2N2Q2MGQ0ZWZlZGM3NDkyNTA1Nzc4YmM5OTlhY2MxZGU1Mw=="
person_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/person"
role_authorization = "OWM1ZDg5ZDEwYTEwYWI4OGNjYmRlMmQ1NzYwNWNlZGZkZmRmMjE4OQ=="
role_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/role"
money_authorization = "MDQyNjc2ZDczYjBhYmM4Yzc4ZGI4YjRmMjc3NGI5NTdlNzJiY2IwZA=="
money_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/money"
codeclasses_authorization = "MmUyNWIxZjQ2NjAzMWJlMGIzYzkxMjMzNWY5OWI3NzJlMWQ1ZjY4Yw=="
codeclasses_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codeclasses"

def viterbi_decode(score, transition_params):
    """Decode the highest scoring sequence of tags outside of TensorFlow.
    This should only be used at test time.
    Args:
        score: A [seq_len, num_tags] matrix of unary potentials.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.
    Returns:
        viterbi: A [seq_len] list of integers containing the highest scoring tag
            indices.
        viterbi_score: A float containing the score for the Viterbi sequence.
    """
    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    trellis[0] = score[0]
    for t in range(1, score.shape[0]):
        v = np.expand_dims(trellis[t - 1], 1) + transition_params
        trellis[t] = score[t] + np.max(v, 0)
        backpointers[t] = np.argmax(v, 0)
    viterbi = [np.argmax(trellis[-1])]
    for bp in reversed(backpointers[1:]):
        viterbi.append(bp[viterbi[-1]])
    viterbi.reverse()
    viterbi_score = np.max(trellis[-1])
    return viterbi, viterbi_score

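# A worked toy example (sketch, not from the original source): with zero
# transition scores the decoder simply follows the per-step argmax.
#   score = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
#   transition_params = np.zeros((2, 2))
#   viterbi_decode(score, transition_params)  # -> ([0, 1, 0], 3.0)
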
def limitRun(sess,list_output,feed_dict,MAX_BATCH=1024):
    len_sample = 0
    if len(feed_dict.keys())>0:
        len_sample = len(feed_dict[list(feed_dict.keys())[0]])
    if len_sample>MAX_BATCH:
        list_result = [[] for _ in range(len(list_output))]
        _begin = 0
        while(_begin<len_sample):
            new_dict = dict()
            for _key in feed_dict.keys():
                if isinstance(feed_dict[_key],(float,int,np.int32,np.float_,np.float16,np.float32,np.float64)):
                    new_dict[_key] = feed_dict[_key]
                else:
                    new_dict[_key] = feed_dict[_key][_begin:_begin+MAX_BATCH]
            _output = sess.run(list_output,feed_dict=new_dict)
            for _index in range(len(list_output)):
                list_result[_index].extend(_output[_index])
            _begin += MAX_BATCH
    else:
        list_result = sess.run(list_output,feed_dict=feed_dict)
    return list_result

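# Usage sketch (hypothetical names): split an oversized feed into
# MAX_BATCH-sized chunks so a single sess.run() never sees the full sample;
# scalar feeds (e.g. a dropout rate) pass through unchanged, array feeds
# are sliced per batch.
#   logits, lengths = limitRun(sess, [logits_op, lengths_op],
#                              {input_ph: char_ids, dropout_ph: 1.0})
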
def get_values(response,output_name):
    """
    Get the value of a specified output tensor
    :param output_name: name of the output tensor
    :return: the content of the output tensor
    """
    # NOTE: this function (and vpc_requests below) requires the
    # pai_tf_predict_proto import that is commented out at the top of this file.
    output = response.outputs[output_name]
    if output.dtype == tf_predict_pb2.DT_FLOAT:
        _value = output.float_val
    elif output.dtype == tf_predict_pb2.DT_INT8 or output.dtype == tf_predict_pb2.DT_INT16 or \
            output.dtype == tf_predict_pb2.DT_INT32:
        _value = output.int_val
    elif output.dtype == tf_predict_pb2.DT_INT64:
        _value = output.int64_val
    elif output.dtype == tf_predict_pb2.DT_DOUBLE:
        _value = output.double_val
    elif output.dtype == tf_predict_pb2.DT_STRING:
        _value = output.string_val
    elif output.dtype == tf_predict_pb2.DT_BOOL:
        _value = output.bool_val
    return np.array(_value).reshape(response.outputs[output_name].array_shape.dim)

def vpc_requests(url,authorization,request_data,list_outputs):
    headers = {"Authorization": authorization}
    dict_outputs = dict()
    resp = requests.post(url, data=request_data, headers=headers)
    if resp.status_code != 200:
        print(resp.status_code,resp.content)
        log("error calling the pai-eas API, authorization:"+str(authorization))
        return None
    else:
        response = tf_predict_pb2.PredictResponse()
        response.ParseFromString(resp.content)
        for _output in list_outputs:
            dict_outputs[_output] = get_values(response, _output)
        return dict_outputs

def encodeInput(data,word_len,word_flag=True,userFool=False):
    result = []
    out_index = 0
    for item in data:
        if out_index in [0]:
            list_word = item[-word_len:]
        else:
            list_word = item[:word_len]
        temp = []
        if word_flag:
            for word in list_word:
                if userFool:
                    temp.append(getIndexOfWord_fool(word))
                else:
                    temp.append(getIndexOfWord(word))
            list_append = []
            temp_len = len(temp)
            while(temp_len<word_len):
                if userFool:
                    list_append.append(0)
                else:
                    list_append.append(getIndexOfWord("<pad>"))
                temp_len += 1
            if out_index in [0]:
                temp = list_append+temp
            else:
                temp = temp+list_append
        else:
            for words in list_word:
                temp.append(getIndexOfWords(words))
            list_append = []
            temp_len = len(temp)
            while(temp_len<word_len):
                list_append.append(getIndexOfWords("<pad>"))
                temp_len += 1
            if out_index in [0,1]:
                temp = list_append+temp
            else:
                temp = temp+list_append
        result.append(temp)
        out_index += 1
    return result

def encodeInput_form(input,MAX_LEN=30):
    x = np.zeros([MAX_LEN])
    for i in range(len(input)):
        if i>=MAX_LEN:
            break
        x[i] = getIndexOfWord(input[i])
    return x

def getVocabAndMatrix(model,Embedding_size = 60):
    '''
    @summary: build the vocabulary and embedding matrix from an embedding model
    '''
    vocab = ["<pad>"]+model.index2word
    embedding_matrix = np.zeros((len(vocab),Embedding_size))
    for i in range(1,len(vocab)):
        embedding_matrix[i] = model[vocab[i]]
    return vocab,embedding_matrix

def getIndexOfWord(word):
    global vocab_word,file_vocab_word
    if vocab_word is None:
        if os.path.exists(file_vocab_word):
            vocab = load(file_vocab_word)
            vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
        else:
            model = getModel_word()
            vocab,_ = getVocabAndMatrix(model, Embedding_size=60)
            vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
            save(vocab,file_vocab_word)
    if word in vocab_word.keys():
        return vocab_word[word]
    else:
        return vocab_word['<pad>']

def changeIndexFromWordToWords(tokens,word_index):
    '''
    @summary: convert a character offset into a token (word) offset
    '''
    before_index = 0
    after_index = 0
    for i in range(len(tokens)):
        after_index = after_index+len(tokens[i])
        if before_index<=word_index and after_index>word_index:
            return i
        before_index = after_index

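# Example (sketch): character offset 2 falls inside the second token.
#   changeIndexFromWordToWords(["招标", "公告"], 2)  # -> 1
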
def getIndexOfWords(words):
    global vocab_words,file_vocab_words
    if vocab_words is None:
        if os.path.exists(file_vocab_words):
            vocab = load(file_vocab_words)
            vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
        else:
            model = getModel_w2v()
            vocab,_ = getVocabAndMatrix(model, Embedding_size=128)
            vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
            save(vocab,file_vocab_words)
    if words in vocab_words.keys():
        return vocab_words[words]
    else:
        return vocab_words["<pad>"]

def log(msg):
    '''
    @summary: log an info-level message
    '''
    logger.info(msg)

def debug(msg):
    '''
    @summary: log a debug-level message
    '''
    logger.debug(msg)

def save(object_to_save, path):
    '''
    Save an object to disk with pickle.
    @Args:
        object_to_save: the object to save
        path: the path to save to
    '''
    with open(path, 'wb') as f:
        pickle.dump(object_to_save, f)

def load(path):
    '''
    Load a pickled object from disk.
    @Args:
        path: the path to read from
    @Return:
        the loaded object
    '''
    with open(path, 'rb') as f:
        object1 = pickle.load(f)
    return object1

fool_char_to_id = load(os.path.dirname(__file__)+"/fool_char_to_id.pk")

def getIndexOfWord_fool(word):
    if word in fool_char_to_id.keys():
        return fool_char_to_id[word]
    else:
        return fool_char_to_id["[UNK]"]

def find_index(list_tofind,text):
    '''
    @summary: find the first occurrence position of each term in a string
    @param:
        list_tofind: the terms to look for
        text: the string to search
    @return: list, first occurrence index of each term (-1 if absent)
    '''
    result = []
    for item in list_tofind:
        index = text.find(item)
        if index>=0:
            result.append(index)
        else:
            result.append(-1)
    return result

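# Example (sketch):
#   find_index(["bc", "x"], "abcd")  # -> [1, -1]
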
def combine(list1,list2):
    '''
    @summary: concatenate every string in list1 with every string in list2
    @param:
        list1: list of strings
        list2: list of strings
    @return: list of concatenated results
    '''
    result = []
    for item1 in list1:
        for item2 in list2:
            result.append(str(item1)+str(item2))
    return result

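# Example (sketch): the Cartesian concatenation of two string lists.
#   combine(["a", "b"], ["1", "2"])  # -> ['a1', 'a2', 'b1', 'b2']
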
def getDigitsDic(unit):
    '''
    @summary: map a Chinese digit character to its numeric value
    '''
    DigitsDic = {"零":0, "壹":1, "贰":2, "叁":3, "肆":4, "伍":5, "陆":6, "柒":7, "捌":8, "玖":9,
                 "〇":0, "一":1, "二":2, "三":3, "四":4, "五":5, "六":6, "七":7, "八":8, "九":9}
    return DigitsDic.get(unit)

def getMultipleFactor(unit):
    '''
    @summary: map a Chinese unit character to its multiplier
    '''
    MultipleFactor = {"兆":Decimal(1000000000000),"亿":Decimal(100000000),"万":Decimal(10000),
                      "仟":Decimal(1000),"千":Decimal(1000),"佰":Decimal(100),"百":Decimal(100),
                      "拾":Decimal(10),"十":Decimal(10),"元":Decimal(1),"圆":Decimal(1),
                      "角":round(Decimal(0.1),1),"分":round(Decimal(0.01),2)}
    return MultipleFactor.get(unit)

def getUnifyMoney(money):
    '''
    @summary: convert a Chinese money string into a numeric amount
    @param:
        money: the money string (digits and/or Chinese digit and unit characters)
    @return: Decimal, the numeric amount
    '''
    MAX_MONEY = 1000000000000
    MAX_NUM = 12
    # strip thousands separators
    money = re.sub("[,,]","",money)
    money = re.sub("[^0-9.零壹贰叁肆伍陆柒捌玖拾佰仟萬億圆十百千万亿元角分]","",money)
    result = Decimal(0)
    chnDigits = ["零", "壹", "贰", "叁", "肆", "伍", "陆", "柒", "捌", "玖"]
    # chnFactorUnits = ["兆", "亿", "万", "仟", "佰", "拾","圆","元","角","分"]
    chnFactorUnits = ["兆", "亿", "万", "仟", "佰", "拾", "圆", "元", "角", "分", '十', '百', '千']
    LowMoneypattern = re.compile(r"^[\d,]+(\.\d+)?$")
    BigMoneypattern = re.compile("^零?(?P<BigMoney>[%s])$"%("".join(chnDigits)))
    try:
        if re.search(LowMoneypattern,money) is not None:
            return Decimal(money)
        elif re.search(BigMoneypattern,money) is not None:
            return getDigitsDic(re.search(BigMoneypattern,money).group("BigMoney"))
        for factorUnit in chnFactorUnits:
            if re.search(re.compile(".*%s.*"%(factorUnit)),money) is not None:
                subMoneys = re.split(re.compile("%s(?!.*%s.*)"%(factorUnit,factorUnit)),money)
                if re.search(re.compile(r"^(\d+)(\.\d+)?$"),subMoneys[0]) is not None:
                    if MAX_MONEY/getMultipleFactor(factorUnit)<Decimal(subMoneys[0]):
                        return Decimal(0)
                    result += Decimal(subMoneys[0])*(getMultipleFactor(factorUnit))
                elif len(subMoneys[0])==1:
                    if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[0]) is not None:
                        result += Decimal(getDigitsDic(subMoneys[0]))*(getMultipleFactor(factorUnit))
                # no unit character left in subMoneys[0]: it cannot be split further
                elif re.search(re.compile("[%s]"%("".join(chnFactorUnits))),subMoneys[0]) is None:
                    subMoneys[0] = subMoneys[0][0]
                    result += Decimal(getDigitsDic(subMoneys[0])) * (getMultipleFactor(factorUnit))
                else:
                    result += Decimal(getUnifyMoney(subMoneys[0]))*(getMultipleFactor(factorUnit))
                if len(subMoneys)>1:
                    if re.search(re.compile(r"^(\d+(,)?)+(\.\d+)?[百千万亿]?\s?(元)?$"),subMoneys[1]) is not None:
                        result += Decimal(subMoneys[1])
                    elif len(subMoneys[1])==1:
                        if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[1]) is not None:
                            result += Decimal(getDigitsDic(subMoneys[1]))
                    else:
                        result += Decimal(getUnifyMoney(subMoneys[1]))
                break
    except Exception as e:
        return Decimal(0)
    return result

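# Worked example (sketch): "壹拾柒万元" is split on the last "万" into
# ["壹拾柒", "元"]; the left part recurses to 17 and is scaled by 10000,
# while the trailing "元" contributes nothing.
#   getUnifyMoney("壹拾柒万元")  # -> Decimal(170000)
#   getUnifyMoney("1,234.5元")   # -> Decimal('1234.5')
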
def getModel_w2v():
    '''
    @summary: load the word-level embedding model
    '''
    global model_w2v,lock_model_w2v
    with lock_model_w2v:
        if model_w2v is None:
            model_w2v = gensim.models.KeyedVectors.load_word2vec_format(getw2vfilepath(),binary=True)
        return model_w2v

def getModel_word():
    '''
    @summary: load the character-level embedding model
    '''
    global model_word,lock_model_word
    with lock_model_word:
        if model_word is None:
            model_word = gensim.models.KeyedVectors.load_word2vec_format(model_word_file,binary=True)
        return model_word

# getModel_w2v()
# getModel_word()

def findAllIndex(substr,wholestr):
    '''
    @summary: find every begin_index of a substring within a string
    @param:
        substr: the substring
        wholestr: the full string containing the substring
    @return: list, every begin_index of substr in wholestr
    '''
    copystr = wholestr
    result = []
    indexappend = 0
    while(True):
        index = copystr.find(substr)
        if index<0:
            break
        else:
            result.append(indexappend+index)
            indexappend += index+len(substr)
            copystr = copystr[index+len(substr):]
    return result

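# Example (sketch): non-overlapping occurrences, left to right.
#   findAllIndex("ab", "abcab")  # -> [0, 3]
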
def spanWindow(tokens,begin_index,end_index,size,center_include=False,word_flag = False,use_text = False,text = None):
    '''
    @summary: get the context tokens around an entity
    @param:
        tokens: the tokenized sentence (list)
        begin_index: start index of the entity
        end_index: end index of the entity
        size: number of tokens to take on each side
        center_include: whether to include the entity itself
        word_flag: word/character mode (False keeps token lists, True joins them into strings)
    @return: list, the entity's context tokens
    '''
    if use_text:
        assert text is not None
    length_tokens = len(tokens)
    if begin_index>size:
        begin = begin_index-size
    else:
        begin = 0
    if end_index+size<length_tokens:
        end = end_index+size+1
    else:
        end = length_tokens
    result = []
    if not word_flag:
        result.append(tokens[begin:begin_index])
        if center_include:
            if use_text:
                result.append(text)
            else:
                result.append(tokens[begin_index:end_index+1])
        result.append(tokens[end_index+1:end])
    else:
        result.append("".join(tokens[begin:begin_index]))
        if center_include:
            if use_text:
                result.append(text)
            else:
                result.append("".join(tokens[begin_index:end_index+1]))
        result.append("".join(tokens[end_index+1:end]))
    #print(result)
    return result

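# Example (sketch): a window of one token on each side of the entity tokens[2:3].
#   spanWindow(["a", "b", "c", "d", "e"], 2, 2, 1, center_include=True)
#   # -> [['b'], ['c'], ['d']]
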
# complete, by rule, the paired bracket symbols on either side of a code or name
def fitDataByRule(data):
    symbol_dict = {"(":")",
                   "(":")",
                   "[":"]",
                   "【":"】",
                   ")":"(",
                   ")":"(",
                   "]":"[",
                   "】":"【"}
    leftSymbol_pattern = re.compile(r"[\((\[【]")
    rightSymbol_pattern = re.compile(r"[\))\]】]")
    leftfinds = re.findall(leftSymbol_pattern,data)
    rightfinds = re.findall(rightSymbol_pattern,data)
    result = data
    if len(leftfinds)+len(rightfinds)==0:
        return data
    elif len(leftfinds)==len(rightfinds):
        return data
    elif abs(len(leftfinds)-len(rightfinds))==1:
        if len(leftfinds)>len(rightfinds):
            if symbol_dict.get(data[0]) is not None:
                result = data[1:]
            else:
                #print(symbol_dict.get(leftfinds[0]))
                result = data+symbol_dict.get(leftfinds[0])
        else:
            if symbol_dict.get(data[-1]) is not None:
                result = data[:-1]
            else:
                result = symbol_dict.get(rightfinds[0])+data
    result = re.sub("[。]","",result)
    return result

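# Example (sketch): one unmatched left bracket gets its closing partner appended.
#   fitDataByRule("第一包(合同包1")  # -> '第一包(合同包1)'
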
time_format_pattern = re.compile(r"((?P<year>\d{4}|\d{2})\s*[-\/年\.]\s*(?P<month>\d{1,2})\s*[-\/月\.]\s*(?P<day>\d{1,2}))")

def timeFormat(_time):
    current_year = time.strftime("%Y",time.localtime())
    all_match = re.finditer(time_format_pattern,_time)
    for _match in all_match:
        if len(_match.group())>0:
            legal = True
            year = ""
            month = ""
            day = ""
            for k,v in _match.groupdict().items():
                if k=="year":
                    year = v
                if k=="month":
                    month = v
                if k=="day":
                    day = v
            if year!="":
                if len(year)==2:
                    year = "20"+year
                if int(year)>int(current_year):
                    legal = False
            else:
                legal = False
            if month!="":
                if int(month)>12:
                    legal = False
            else:
                legal = False
            if day!="":
                if int(day)>31:
                    legal = False
            else:
                legal = False
            if legal:
                return "%s-%s-%s"%(year,month.rjust(2,"0"),day.rjust(2,"0"))
    return ""

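# Examples (sketch, assuming the machine clock is past 2018):
#   timeFormat("开标时间:2018年12月20日9时")  # -> '2018-12-20'
#   timeFormat("2018/13/20")                  # -> '' (month 13 is rejected)
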
def embedding(datas,shape):
    '''
    @summary: look up the word embeddings for a batch of token lists
    @param:
        datas: list of token lists
        shape: shape of the result array
    @return: array of the given shape holding the word embeddings
    '''
    model_w2v = getModel_w2v()
    embed = np.zeros(shape)
    length = shape[1]
    out_index = 0
    #print(datas)
    for data in datas:
        index = 0
        for item in data:
            item_not_space = re.sub(r"\s*","",item)
            if index>=length:
                break
            if item_not_space in model_w2v.vocab:
                embed[out_index][index] = model_w2v[item_not_space]
                index += 1
            else:
                #embed[out_index][index] = model_w2v['unk']
                index += 1
        out_index += 1
    return embed

def embedding_word(datas,shape):
    '''
    @summary: look up the character embeddings, keeping the last shape[1] characters of each item
    @param:
        datas: list of items
        shape: shape of the result array
    @return: array of the given shape holding the character embeddings
    '''
    model_w2v = getModel_word()
    embed = np.zeros(shape)
    length = shape[1]
    out_index = 0
    #print(datas)
    for data in datas:
        index = 0
        for item in str(data)[-shape[1]:]:
            if index>=length:
                break
            if item in model_w2v.vocab:
                embed[out_index][index] = model_w2v[item]
                index += 1
            else:
                # embed[out_index][index] = model_w2v['unk']
                index += 1
        out_index += 1
    return embed

def embedding_word_forward(datas,shape):
    '''
    @summary: look up the character embeddings, keeping the first shape[1] characters of each item
    @param:
        datas: list of items
        shape: shape of the result array
    @return: array of the given shape holding the character embeddings
    '''
    model_w2v = getModel_word()
    embed = np.zeros(shape)
    length = shape[1]
    out_index = 0
    #print(datas)
    for data in datas:
        index = 0
        for item in str(data)[:shape[1]]:
            if index>=length:
                break
            if item in model_w2v.vocab:
                embed[out_index][index] = model_w2v[item]
                index += 1
            else:
                # embed[out_index][index] = model_w2v['unk']
                index += 1
        out_index += 1
    return embed

def formEncoding(text,shape=(100,60),expand=False):
    embedding = np.zeros(shape)
    word_model = getModel_word()
    for i in range(len(text)):
        if i>=shape[0]:
            break
        if text[i] in word_model.vocab:
            embedding[i] = word_model[text[i]]
    if expand:
        embedding = np.expand_dims(embedding,0)
    return embedding

def partMoney(entity_text,input2_shape = [7]):
    '''
    @summary: bucket a money amount into a one-hot magnitude class
    @param:
        entity_text: the numeric amount
        input2_shape: number of classes
    @return: array, one-hot encoding of the magnitude bucket
    '''
    money = float(entity_text)
    parts = np.zeros(input2_shape)
    if money<100:
        parts[0] = 1
    elif money<1000:
        parts[1] = 1
    elif money<10000:
        parts[2] = 1
    elif money<100000:
        parts[3] = 1
    elif money<1000000:
        parts[4] = 1
    elif money<10000000:
        parts[5] = 1
    else:
        parts[6] = 1
    return parts

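# Example (sketch): 5000 falls in the [1000, 10000) bucket.
#   partMoney("5000")  # -> array([0., 0., 1., 0., 0., 0., 0.])
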
def recall(y_true, y_pred):
    '''
    Compute recall.
    @Args:
        y_true: ground-truth labels
        y_pred: predicted labels
    @Return:
        recall
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
    # c3 is a symbolic tensor, so a Python `if c3 == 0` check can never fire;
    # guard the division with K.epsilon() instead
    recall = c1 / (c3 + K.epsilon())
    return recall

def f1_score(y_true, y_pred):
    '''
    Compute F1.
    @Args:
        y_true: ground-truth labels
        y_pred: predicted labels
    @Return:
        F1 value
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
    # guard all divisions with K.epsilon(); tensor `== 0` checks cannot fire here
    precision = c1 / (c2 + K.epsilon())
    recall = c1 / (c3 + K.epsilon())
    f1_score = 2 * (precision * recall) / (precision + recall + K.epsilon())
    return f1_score

def precision(y_true, y_pred):
    '''
    Compute precision.
    @Args:
        y_true: ground-truth labels
        y_pred: predicted labels
    @Return:
        precision
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = c1 / (c2 + K.epsilon())
    return precision

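# Usage sketch (hypothetical model): these functions follow the Keras custom
# metric signature, so they can be passed directly to compile().
#   model.compile(optimizer="adam", loss="binary_crossentropy",
#                 metrics=[precision, recall, f1_score])
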
# def print_metrics(history):
#     '''
#     Plot how each metric changes across training iterations.
#
#     @Args:
#         history: the history object returned by model training
#     '''
#     import matplotlib.pyplot as plt
#
#     # loss plot
#     loss = history.history['loss']
#     val_loss = history.history['val_loss']
#     epochs = range(1, len(loss) + 1)
#     plt.subplot(2, 2, 1)
#     plt.plot(epochs, loss, 'bo', label='Training loss')
#     plt.plot(epochs, val_loss, 'b', label='Validation loss')
#     plt.title('Training and validation loss')
#     plt.xlabel('Epochs')
#     plt.ylabel('Loss')
#     plt.legend()
#
#     # f1 plot
#     f1 = history.history['f1_score']
#     val_f1 = history.history['val_f1_score']
#     plt.subplot(2, 2, 2)
#     plt.plot(epochs, f1, 'bo', label='Training f1')
#     plt.plot(epochs, val_f1, 'b', label='Validation f1')
#     plt.title('Training and validation f1')
#     plt.xlabel('Epochs')
#     plt.ylabel('F1')
#     plt.legend()
#
#     # precision plot
#     prec = history.history['precision']
#     val_prec = history.history['val_precision']
#     plt.subplot(2, 2, 3)
#     plt.plot(epochs, prec, 'bo', label='Training precision')
#     plt.plot(epochs, val_prec, 'b', label='Validation precision')
#     plt.title('Training and validation precision')
#     plt.xlabel('Epochs')
#     plt.ylabel('Precision')
#     plt.legend()
#
#     # recall plot
#     recall = history.history['recall']
#     val_recall = history.history['val_recall']
#     plt.subplot(2, 2, 4)
#     plt.plot(epochs, recall, 'bo', label='Training recall')
#     plt.plot(epochs, val_recall, 'b', label='Validation recall')
#     plt.title('Training and validation recall')
#     plt.xlabel('Epochs')
#     plt.ylabel('Recall')
#     plt.legend()
#
#     plt.show()

if __name__=="__main__":
    # print(fool_char_to_id[">"])
    print(getUnifyMoney('壹拾柒万元'))
    # model = getModel_w2v()
    # vocab,matrix = getVocabAndMatrix(model, Embedding_size=128)
    # save([vocab,matrix],"vocabMatrix_words.pk")