Utils.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. '''
  2. Created on 2018年12月20日
  3. @author: User
  4. '''
  5. import numpy as np
  6. import re
  7. import gensim
  8. import os
  9. from threading import RLock
  10. # from pai_tf_predict_proto import tf_predict_pb2
  11. import requests
  12. import time
  13. from bs4 import BeautifulSoup
  14. model_w2v = None
  15. lock_model_w2v = RLock()
  16. USE_PAI_EAS = False
  17. Lazy_load = False
  18. ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]|\x00')
  19. import smtplib
  20. from email.mime.application import MIMEApplication
  21. from email.mime.multipart import MIMEMultipart
  22. from email.utils import formataddr
  23. ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
  24. def getLegal_str(_str):
  25. if _str is not None:
  26. return ILLEGAL_CHARACTERS_RE.sub("",str(_str))
  27. import traceback
  28. dict_server = {}
  29. def getServer(host,username,password,reconnect=False):
  30. key = "%s-%s-%s"%(host,username,password)
  31. if key in dict_server:
  32. server = dict_server[key]
  33. if reconnect:
  34. server = smtplib.SMTP_SSL(host, 465)
  35. server.login(username,password)
  36. else:
  37. server = smtplib.SMTP_SSL(host, 465)
  38. server.login(username,password)
  39. dict_server[key] = server
  40. return server
  41. from email.mime.text import MIMEText
  42. def sendEmail(host,username,password,receivers,subject="数据导出",content="",attachs=[]):
  43. try:
  44. #处理附件
  45. msg = MIMEMultipart()
  46. msg["From"] = formataddr(["广州比地数据科技有限公司",username])
  47. msg["To"] = formataddr(["客户",receivers[0]])
  48. msg["Subject"] = subject
  49. message = MIMEText(content, 'plain', 'utf-8')
  50. for at in attachs:
  51. xlsfile = MIMEApplication(open(at,"rb").read())
  52. xlsfile.add_header("Content-Disposition","attachment",filename=('gbk', '', at.split("/")[-1]))
  53. log(at.split("/")[-1])
  54. msg.attach(xlsfile)
  55. server = getServer(host,username,password)
  56. server.sendmail(username,receivers,msg.as_string())
  57. log("发送邮件成功%s"%str(attachs))
  58. except smtplib.SMTPServerDisconnected as e:
  59. server = getServer(host,username,password,reconnect=True)
  60. server.sendmail(username,receivers,msg.as_string())
  61. log("发送邮件成功%s"%str(attachs))
  62. except Exception as e:
  63. traceback.print_exc()
  64. log("发送邮件错误%s"%str(e))
  65. finally:
  66. server.close()
  67. mobile_pattern = re.compile("^1\d{10}$")
  68. def recog_likeType(phone):
  69. if re.search(mobile_pattern,phone) is not None:
  70. return "mobile"
  71. else:
  72. return "phone"
def article_limit(soup, limit_words=30000):
    """Truncate an HTML soup so its visible text stays near limit_words characters.

    The main text and an optional attachment block (<div class="richTextFetch">)
    are each limited to limit_words separately; whitespace is ignored when
    counting. Returns the (possibly truncated) HTML as a string.
    """
    sub_space = re.compile("\s+")
    def soup_limit(_soup, _count, max_count=30000, max_gap=500):
        """
        :param _soup: soup
        :param _count: current accumulated character count
        :param max_count: maximum character count
        :param max_gap: allowed overshoot past the limit
        :return: (updated count, overshoot, next subtree to descend into or None)
        """
        _gap = _count - max_count
        _is_skip = False
        next_soup = None
        # unwrap single-child wrappers whose text equals the child's text
        while len(_soup.find_all(recursive=False)) == 1 and \
                _soup.get_text(strip=True) == _soup.find_all(recursive=False)[0].get_text(strip=True):
            _soup = _soup.find_all(recursive=False)[0]
        if len(_soup.find_all(recursive=False)) == 0:
            # leaf node: cut its text to the remaining budget
            _soup.string = str(_soup.get_text())[:max_count-_count]
            _count += len(re.sub(sub_space, "", _soup.string))
            _gap = _count - max_count
            next_soup = None
        else:
            for _soup_part in _soup.find_all(recursive=False):
                if not _is_skip:
                    _count += len(re.sub(sub_space, "", _soup_part.get_text()))
                    if _count >= max_count:
                        _gap = _count - max_count
                        if _gap <= max_gap:
                            # overshoot is tolerable: keep this part, drop the rest
                            _is_skip = True
                        else:
                            # overshoot too large: descend into this part next round
                            _is_skip = True
                            next_soup = _soup_part
                            _count -= len(re.sub(sub_space, "", _soup_part.get_text()))
                            continue
                else:
                    _soup_part.decompose()
        return _count, _gap, next_soup
    text_count = 0
    have_attachment = False
    attachment_part = None
    _attachment = soup.find("div", attrs={"class": "richTextFetch"})
    if _attachment is not None:
        # marker used below to split the flattened text into main/attachment halves
        _attachment.insert_before("##attachment##")
        attachment_part = _attachment
        have_attachment = True
    if not have_attachment:
        # no attachment block
        if len(re.sub(sub_space, "", soup.get_text())) > limit_words:
            text_count, gap, n_soup = soup_limit(soup, text_count, max_count=limit_words, max_gap=500)
            while n_soup:
                text_count, gap, n_soup = soup_limit(n_soup, text_count, max_count=limit_words, max_gap=500)
    else:
        # with attachment block: limit main text and attachment text separately
        _text = re.sub(sub_space, "", soup.get_text())
        _text_split = _text.split("##attachment##")
        if len(_text_split[0]) > limit_words:
            main_soup = attachment_part.parent
            main_text = main_soup.find_all(recursive=False)[0]
            text_count, gap, n_soup = soup_limit(main_text, text_count, max_count=limit_words, max_gap=500)
            while n_soup:
                text_count, gap, n_soup = soup_limit(n_soup, text_count, max_count=limit_words, max_gap=500)
        if len(_text_split[1]) > limit_words:
            # attachment html is plain text with no sub-structure
            if len(attachment_part.find_all(recursive=False)) == 0:
                attachment_part.string = str(attachment_part.get_text())[:limit_words]
            else:
                attachment_text_nums = 0
                attachment_skip = False
                for part in attachment_part.find_all(recursive=False):
                    if not attachment_skip:
                        if part.name == 'div' and 'filemd5' in part.attrs:
                            # per-file sub-blocks: truncate inside the file div
                            for p_part in part.find_all(recursive=False):
                                last_attachment_text_nums = attachment_text_nums
                                attachment_text_nums = attachment_text_nums + len(
                                    re.sub(sub_space, "", p_part.get_text()))
                                if not attachment_skip:
                                    if attachment_text_nums >= limit_words:
                                        p_part.string = str(p_part.get_text())[:limit_words - last_attachment_text_nums]
                                        attachment_skip = True
                                else:
                                    p_part.decompose()
                        else:
                            last_attachment_text_nums = attachment_text_nums
                            attachment_text_nums = attachment_text_nums + len(re.sub(sub_space, "", part.get_text()))
                            if attachment_text_nums >= limit_words and not attachment_skip:
                                part.string = str(part.get_text())[:limit_words - last_attachment_text_nums]
                                attachment_skip = True
                    else:
                        part.decompose()
    soup = str(soup).replace("##attachment##", "")
    return soup
def soup_limit(_soup, _count, max_count=30000, max_gap=500, sub_space=re.compile("\s+")):
    """
    Limit one level of an HTML soup to a character budget.

    NOTE(review): module-level duplicate of the helper nested inside
    article_limit; the compiled-regex default argument is evaluated once at
    definition time, which is intentional here (it is never mutated).

    :param _soup: soup
    :param _count: current accumulated character count
    :param max_count: maximum character count
    :param max_gap: allowed overshoot past the limit
    :param sub_space: pattern used to drop whitespace before counting
    :return: (updated count, overshoot, next subtree to descend into or None)
    """
    _gap = _count - max_count
    _is_skip = False
    next_soup = None
    # unwrap single-child wrappers whose text equals the child's text
    while len(_soup.find_all(recursive=False)) == 1 and \
            _soup.get_text(strip=True) == _soup.find_all(recursive=False)[0].get_text(strip=True):
        _soup = _soup.find_all(recursive=False)[0]
    if len(_soup.find_all(recursive=False)) == 0:
        # leaf node: cut its text to the remaining budget
        _soup.string = str(_soup.get_text())[:max_count-_count]
        _count += len(re.sub(sub_space, "", _soup.string))
        _gap = _count - max_count
        next_soup = None
    else:
        for _soup_part in _soup.find_all(recursive=False):
            if not _is_skip:
                _count += len(re.sub(sub_space, "", _soup_part.get_text()))
                if _count >= max_count:
                    _gap = _count - max_count
                    if _gap <= max_gap:
                        # overshoot is tolerable: keep this part, drop the rest
                        _is_skip = True
                    else:
                        # overshoot too large: descend into this part next round
                        _is_skip = True
                        next_soup = _soup_part
                        _count -= len(re.sub(sub_space, "", _soup_part.get_text()))
                        continue
            else:
                _soup_part.decompose()
    return _count, _gap, next_soup
  199. def cut_str(text_list, only_text_list, max_bytes_length=2000000):
  200. try:
  201. # 计算有格式总字节数
  202. bytes_length = 0
  203. for text in text_list:
  204. bytes_length += len(bytes(text, encoding='utf-8'))
  205. # 小于直接返回
  206. if bytes_length < max_bytes_length:
  207. return text_list
  208. # 全部文件连接,重新计算无格式字节数
  209. all_text = ""
  210. bytes_length = 0
  211. for text in only_text_list:
  212. bytes_length += len(bytes(text, encoding='utf-8'))
  213. all_text += text
  214. # 小于直接返回
  215. if bytes_length < max_bytes_length:
  216. return only_text_list
  217. # 截取字符
  218. all_text = all_text[:max_bytes_length//3]
  219. return [all_text]
  220. except Exception as e:
  221. logging.info("cut_str " + str(e))
  222. return text_list
  223. def getLegal_str(_str):
  224. if _str is not None:
  225. return ILLEGAL_CHARACTERS_RE.sub("",str(_str))
  226. return ""
  227. def getRow_ots_primary(row):
  228. _dict = dict()
  229. if row is None:
  230. return None
  231. for part in row.attribute_columns:
  232. _dict[part[0]] = part[1]
  233. for part in row.primary_key:
  234. _dict[part[0]] = part[1]
  235. return _dict
  236. def timeAdd(_time,days,format="%Y-%m-%d",minutes=0):
  237. a = time.mktime(time.strptime(_time,format))+86400*days+60*minutes
  238. _time1 = time.strftime(format,time.localtime(a))
  239. return _time1
  240. def getRow_ots(rows):
  241. list_dict = []
  242. for row in rows:
  243. _dict = dict()
  244. for part in row:
  245. for v in part:
  246. _dict[v[0]] = v[1]
  247. list_dict.append(_dict)
  248. return list_dict
  249. def getw2vfilepath():
  250. w2vfile = os.path.dirname(__file__)+"/../wiki_128_word_embedding_new.vector"
  251. if os.path.exists(w2vfile):
  252. return w2vfile
  253. return "wiki_128_word_embedding_new.vector"
  254. def getLazyLoad():
  255. global Lazy_load
  256. return Lazy_load
  257. def get_file_name(url, headers):
  258. filename = ''
  259. if 'Content-Disposition' in headers and headers['Content-Disposition']:
  260. disposition_split = headers['Content-Disposition'].split(';')
  261. if len(disposition_split) > 1:
  262. if disposition_split[1].strip().lower().startswith('filename='):
  263. file_name = disposition_split[1].split('=')
  264. if len(file_name) > 1:
  265. filename = file_name[1]
  266. if not filename and os.path.basename(url):
  267. filename = os.path.basename(url).split("?")[0]
  268. if not filename:
  269. return time.time()
  270. return filename
  271. model_word_file = os.path.dirname(__file__)+"/../singlew2v_model.vector"
  272. model_word = None
  273. lock_model_word = RLock()
  274. from decimal import Decimal
  275. import logging
  276. logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  277. logger = logging.getLogger(__name__)
  278. import pickle
  279. import os
  280. import json
  281. #自定义jsonEncoder
  282. class MyEncoder(json.JSONEncoder):
  283. def __init__(self):
  284. import numpy as np
  285. global np
  286. def default(self, obj):
  287. if isinstance(obj, np.ndarray):
  288. return obj.tolist()
  289. elif isinstance(obj, bytes):
  290. return str(obj, encoding='utf-8')
  291. elif isinstance(obj, (np.float_, np.float16, np.float32,
  292. np.float64)):
  293. return float(obj)
  294. elif isinstance(obj,(np.int64,np.int32)):
  295. return int(obj)
  296. return json.JSONEncoder.default(self, obj)
  297. vocab_word = None
  298. vocab_words = None
  299. file_vocab_word = "vocab_word.pk"
  300. file_vocab_words = "vocab_words.pk"
  301. selffool_authorization = "NjlhMWFjMjVmNWYyNzI0MjY1OGQ1M2Y0ZmY4ZGY0Mzg3Yjc2MTVjYg=="
  302. selffool_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_gpu"
  303. selffool_seg_authorization = "OWUwM2Q0ZmE3YjYxNzU4YzFiMjliNGVkMTA3MzJkNjQ2MzJiYzBhZg=="
  304. selffool_seg_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_seg_gpu"
  305. codename_authorization = "Y2M5MDUxMzU1MTU4OGM3ZDk2ZmEzYjkxYmYyYzJiZmUyYTgwYTg5NA=="
  306. codename_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codename_gpu"
  307. form_item_authorization = "ODdkZWY1YWY0NmNhNjU2OTI2NWY4YmUyM2ZlMDg1NTZjOWRkYTVjMw=="
  308. form_item_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/form"
  309. person_authorization = "N2I2MDU2N2Q2MGQ0ZWZlZGM3NDkyNTA1Nzc4YmM5OTlhY2MxZGU1Mw=="
  310. person_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/person"
  311. role_authorization = "OWM1ZDg5ZDEwYTEwYWI4OGNjYmRlMmQ1NzYwNWNlZGZkZmRmMjE4OQ=="
  312. role_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/role"
  313. money_authorization = "MDQyNjc2ZDczYjBhYmM4Yzc4ZGI4YjRmMjc3NGI5NTdlNzJiY2IwZA=="
  314. money_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/money"
  315. codeclasses_authorization = "MmUyNWIxZjQ2NjAzMWJlMGIzYzkxMjMzNWY5OWI3NzJlMWQ1ZjY4Yw=="
  316. codeclasses_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codeclasses"
  317. def viterbi_decode(score, transition_params):
  318. """Decode the highest scoring sequence of tags outside of TensorFlow.
  319. This should only be used at test time.
  320. Args:
  321. score: A [seq_len, num_tags] matrix of unary potentials.
  322. transition_params: A [num_tags, num_tags] matrix of binary potentials.
  323. Returns:
  324. viterbi: A [seq_len] list of integers containing the highest scoring tag
  325. indices.
  326. viterbi_score: A float containing the score for the Viterbi sequence.
  327. """
  328. trellis = np.zeros_like(score)
  329. backpointers = np.zeros_like(score, dtype=np.int32)
  330. trellis[0] = score[0]
  331. for t in range(1, score.shape[0]):
  332. v = np.expand_dims(trellis[t - 1], 1) + transition_params
  333. trellis[t] = score[t] + np.max(v, 0)
  334. backpointers[t] = np.argmax(v, 0)
  335. viterbi = [np.argmax(trellis[-1])]
  336. for bp in reversed(backpointers[1:]):
  337. viterbi.append(bp[viterbi[-1]])
  338. viterbi.reverse()
  339. viterbi_score = np.max(trellis[-1])
  340. return viterbi, viterbi_score
  341. import ctypes
  342. import inspect
def _async_raise(tid, exctype):
    """Asynchronously raise `exctype` inside the thread with id `tid`.

    Uses CPython's C API (PyThreadState_SetAsyncExc). The exception is only
    delivered the next time the target thread executes Python bytecode, so a
    thread blocked in C code will not terminate immediately.

    :raises ValueError: when no thread with that id exists
    :raises SystemError: when the C call unexpectedly affected != 1 thread
    """
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        # the C API expects an exception class, not an instance
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # more than one thread was affected: undo by sending NULL, then fail loudly
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")
def stop_thread(thread):
    # Forcibly stop a threading.Thread by injecting SystemExit into it
    # (see _async_raise: delivery waits until the thread runs Python bytecode).
    _async_raise(thread.ident, SystemExit)
  356. def limitRun(sess,list_output,feed_dict,MAX_BATCH=1024):
  357. len_sample = 0
  358. if len(feed_dict.keys())>0:
  359. len_sample = len(feed_dict[list(feed_dict.keys())[0]])
  360. if len_sample>MAX_BATCH:
  361. list_result = [[] for _ in range(len(list_output))]
  362. _begin = 0
  363. while(_begin<len_sample):
  364. new_dict = dict()
  365. for _key in feed_dict.keys():
  366. new_dict[_key] = feed_dict[_key][_begin:_begin+MAX_BATCH]
  367. _output = sess.run(list_output,feed_dict=new_dict)
  368. for _index in range(len(list_output)):
  369. list_result[_index].extend(_output[_index])
  370. _begin += MAX_BATCH
  371. else:
  372. list_result = sess.run(list_output,feed_dict=feed_dict)
  373. return list_result
def get_values(response, output_name):
    """
    Get the value of a specified output tensor
    :param response: a tf_predict_pb2.PredictResponse holding named outputs
    :param output_name: name of the output tensor
    :return: ndarray reshaped to the tensor's declared array_shape
    """
    # NOTE(review): tf_predict_pb2 is referenced here but its import at the top
    # of this file is commented out — calling this raises NameError unless the
    # import is restored; confirm before use.
    output = response.outputs[output_name]
    if output.dtype == tf_predict_pb2.DT_FLOAT:
        _value = output.float_val
    elif output.dtype == tf_predict_pb2.DT_INT8 or output.dtype == tf_predict_pb2.DT_INT16 or \
            output.dtype == tf_predict_pb2.DT_INT32:
        _value = output.int_val
    elif output.dtype == tf_predict_pb2.DT_INT64:
        _value = output.int64_val
    elif output.dtype == tf_predict_pb2.DT_DOUBLE:
        _value = output.double_val
    elif output.dtype == tf_predict_pb2.DT_STRING:
        _value = output.string_val
    elif output.dtype == tf_predict_pb2.DT_BOOL:
        _value = output.bool_val
    # NOTE(review): a dtype matching none of the branches above leaves _value
    # unbound, so the next line raises UnboundLocalError.
    return np.array(_value).reshape(response.outputs[output_name].array_shape.dim)
def vpc_requests(url, authorization, request_data, list_outputs):
    """POST a serialized predict request to a PAI-EAS VPC endpoint.

    :param url: endpoint url
    :param authorization: value of the Authorization header
    :param request_data: serialized request protobuf bytes
    :param list_outputs: names of output tensors to extract from the response
    :return: dict {output_name: ndarray}, or None for a non-200 HTTP status
    """
    headers = {"Authorization": authorization}
    dict_outputs = dict()
    # NOTE(review): this object is unused — it is re-created in the else branch;
    # also tf_predict_pb2's import is commented out at the top of the file,
    # so this raises NameError unless the import is restored.
    response = tf_predict_pb2.PredictResponse()
    resp = requests.post(url, data=request_data, headers=headers)
    if resp.status_code != 200:
        print(resp.status_code, resp.content)
        log("调用pai-eas接口出错,authorization:"+str(authorization))
        return None
    else:
        response = tf_predict_pb2.PredictResponse()
        response.ParseFromString(resp.content)
        for _output in list_outputs:
            dict_outputs[_output] = get_values(response, _output)
        return dict_outputs
  410. def encodeInput(data,word_len,word_flag=True,userFool=False):
  411. result = []
  412. out_index = 0
  413. for item in data:
  414. if out_index in [0]:
  415. list_word = item[-word_len:]
  416. else:
  417. list_word = item[:word_len]
  418. temp = []
  419. if word_flag:
  420. for word in list_word:
  421. if userFool:
  422. temp.append(getIndexOfWord_fool(word))
  423. else:
  424. temp.append(getIndexOfWord(word))
  425. list_append = []
  426. temp_len = len(temp)
  427. while(temp_len<word_len):
  428. if userFool:
  429. list_append.append(0)
  430. else:
  431. list_append.append(getIndexOfWord("<pad>"))
  432. temp_len += 1
  433. if out_index in [0]:
  434. temp = list_append+temp
  435. else:
  436. temp = temp+list_append
  437. else:
  438. for words in list_word:
  439. temp.append(getIndexOfWords(words))
  440. list_append = []
  441. temp_len = len(temp)
  442. while(temp_len<word_len):
  443. list_append.append(getIndexOfWords("<pad>"))
  444. temp_len += 1
  445. if out_index in [0,1]:
  446. temp = list_append+temp
  447. else:
  448. temp = temp+list_append
  449. result.append(temp)
  450. out_index += 1
  451. return result
  452. def encodeInput_form(input,MAX_LEN=30):
  453. x = np.zeros([MAX_LEN])
  454. for i in range(len(input)):
  455. if i>=MAX_LEN:
  456. break
  457. x[i] = getIndexOfWord(input[i])
  458. return x
  459. def getVocabAndMatrix(model,Embedding_size = 60):
  460. '''
  461. @summary:获取子向量的词典和子向量矩阵
  462. '''
  463. vocab = ["<pad>"]+model.index2word
  464. embedding_matrix = np.zeros((len(vocab),Embedding_size))
  465. for i in range(1,len(vocab)):
  466. embedding_matrix[i] = model[vocab[i]]
  467. return vocab,embedding_matrix
  468. def getIndexOfWord(word):
  469. global vocab_word,file_vocab_word
  470. if vocab_word is None:
  471. if os.path.exists(file_vocab_word):
  472. vocab = load(file_vocab_word)
  473. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  474. else:
  475. model = getModel_word()
  476. vocab,_ = getVocabAndMatrix(model, Embedding_size=60)
  477. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  478. save(vocab,file_vocab_word)
  479. if word in vocab_word.keys():
  480. return vocab_word[word]
  481. else:
  482. return vocab_word['<pad>']
  483. def getIndexOfWords(words):
  484. global vocab_words,file_vocab_words
  485. if vocab_words is None:
  486. if os.path.exists(file_vocab_words):
  487. vocab = load(file_vocab_words)
  488. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  489. else:
  490. model = getModel_w2v()
  491. vocab,_ = getVocabAndMatrix(model, Embedding_size=128)
  492. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  493. save(vocab,file_vocab_words)
  494. if words in vocab_words.keys():
  495. return vocab_words[words]
  496. else:
  497. return vocab_words["<pad>"]
  498. def isCellphone(phone):
  499. if phone is not None and re.search("^1\d{10}$",str(phone)) is not None:
  500. return True
  501. return False
  502. def popNoneFromDict(_dict):
  503. list_pop = []
  504. for k,v in _dict.items():
  505. if v is None or v=="":
  506. list_pop.append(k)
  507. for k in list_pop:
  508. _dict.pop(k)
  509. return _dict
  510. pattern_attachment = re.compile("\.(?P<attachment>jpg|jpeg|png|swf|tif|pdf|doc|docx|xls|xlsx|zip|rar|tar|7z|wim)$")
  511. def getAttachmentTypeFromUrl(url):
  512. _match = re.search(pattern_attachment,url)
  513. if _match is not None:
  514. return _match.groupdict().get("attachment")
  515. return None
  516. def getAttachmentUrls(sourceHtml):
  517. list_urls = []
  518. _soup = BeautifulSoup(sourceHtml,"lxml")
  519. set_types = set()
  520. list_a = _soup.find_all("a")
  521. for _a in list_a:
  522. _url = _a.attrs.get("href","")
  523. _type = getAttachmentTypeFromUrl(_url)
  524. if _type is not None:
  525. list_urls.append({"url":_url,"type":_type})
  526. list_img = _soup.find_all("img")
  527. for _img in list_img:
  528. _url = _img.attrs.get("src","")
  529. _type = getAttachmentTypeFromUrl(_url)
  530. if _type is not None:
  531. list_urls.append({"url":_url,"type":_type})
  532. return list_urls
  533. def getCurrent_date(format="%Y-%m-%d %H:%M:%S"):
  534. _time = time.strftime(format,time.localtime())
  535. return _time
  536. def log_tofile(filename):
  537. logging.basicConfig(filename=filename,level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  538. logger = logging.getLogger(__name__)
  539. def log(msg):
  540. '''
  541. @summary:打印信息
  542. '''
  543. logger.info(msg)
  544. def debug(msg):
  545. '''
  546. @summary:打印信息
  547. '''
  548. logger.debug(msg)
  549. def save(object_to_save, path):
  550. '''
  551. 保存对象
  552. @Arugs:
  553. object_to_save: 需要保存的对象
  554. @Return:
  555. 保存的路径
  556. '''
  557. with open(path, 'wb') as f:
  558. pickle.dump(object_to_save, f)
  559. def load(path):
  560. '''
  561. 读取对象
  562. @Arugs:
  563. path: 读取的路径
  564. @Return:
  565. 读取的对象
  566. '''
  567. with open(path, 'rb') as f:
  568. object1 = pickle.load(f)
  569. return object1
  570. def uniform_num(num):
  571. d1 = {'一': '1', '二': '2', '三': '3', '四': '4', '五': '5', '六': '6', '七': '7', '八': '8', '九': '9', '十': '10'}
  572. # d2 = {'A': '1', 'B': '2', 'C': '3', 'D': '4', 'E': '5', 'F': '6', 'G': '7', 'H': '8', 'I': '9', 'J': '10'}
  573. d3 = {'Ⅰ': '1', 'Ⅱ': '2', 'Ⅲ': '3', 'Ⅳ': '4', 'Ⅴ': '5', 'Ⅵ': '6', 'Ⅶ': '7'}
  574. if num.isdigit():
  575. if re.search('^0[\d]$', num):
  576. num = num[1:]
  577. return num
  578. elif re.search('^[一二三四五六七八九十]+$', num):
  579. _digit = re.search('^[一二三四五六七八九十]+$', num).group(0)
  580. if len(_digit) == 1:
  581. num = d1[_digit]
  582. elif len(_digit) == 2 and _digit[0] == '十':
  583. num = '1'+ d1[_digit[1]]
  584. elif len(_digit) == 2 and _digit[1] == '十':
  585. num = d1[_digit[0]] + '0'
  586. elif len(_digit) == 3 and _digit[1] == '十':
  587. num = d1[_digit[0]] + d1[_digit[2]]
  588. elif re.search('[ⅠⅡⅢⅣⅤⅥⅦ]', num):
  589. num = re.search('[ⅠⅡⅢⅣⅤⅥⅦ]', num).group(0)
  590. num = d3[num]
  591. return num
def uniform_package_name(package_name):
    '''
    Normalize a bid-package number: numerals become Arabic digits, letters are
    upper-cased, and a leading work-type keyword (施工/监理/...) is pulled to
    the front. e.g. "A包监理一标段" -> "监理A1"; "包Ⅱ" -> "2".
    :param package_name: package number string
    :return: normalized name, or the raw input when no pattern matches
    '''
    package_name_raw = package_name
    # strip noise: file extensions and 4-digit years
    package_name = re.sub('pdf|doc|docs|xlsx|rar|\d{4}年', ' ', package_name)
    package_name = package_name.replace('标段(包)', '标段').replace('№', '')
    package_name = re.sub('\[|【', '', package_name)
    # keep a leading work-type keyword, if any
    kw = re.search('(施工|监理|监测|勘察|设计|劳务)', package_name)
    name = ""
    if kw:
        name += kw.group(0)
    # the elif chain below is order-sensitive: most specific patterns first
    if re.search('^[a-zA-Z0-9-]{5,}$', package_name):  # identifier of five or more characters
        _digit = re.search('^[a-zA-Z0-9-]{5,}$', package_name).group(0).upper()
        # print('规范化包号1', _digit)
        name += _digit
    elif re.search('(?P<eng>[a-zA-Z])包[:)]?第?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))标段?', package_name):  # patterns like "A包2标段"
        ser = re.search('(?P<eng>[a-zA-Z])包[:)]?第?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))标段?', package_name)
        # print('规范化包号2', ser.group(0))
        _char = ser.groupdict().get('eng')
        if _char:
            _char = _char.upper()
        _digit = ser.groupdict().get('num')
        _digit = uniform_num(_digit)
        name += _char.upper() + _digit
    elif re.search('第?(?P<eng>[0-9a-zA-Z-]{1,4})?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))(标[段号的包项]?|合同[包段]|([分子]?[包标]))', package_name):  # patterns like "A包2标段"
        ser = re.search('第?(?P<eng>[0-9a-zA-Z-]{1,4})?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))(标[段号的包项]?|合同[包段]|([分子]?[包标]))', package_name)
        # print('规范化包号3', ser.group(0))
        _char = ser.groupdict().get('eng')
        if _char:
            _char = _char.upper()
        _digit = ser.groupdict().get('num')
        _digit = uniform_num(_digit)
        if _char:
            name += _char.upper()
        name += _digit
    elif re.search('(标[段号的包项]?|项目|子项目?|([分子]?包|包[组件号]))编?号?[::]?(?P<eng>[0-9a-zA-Z-]{1,4})?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))', package_name):  # unify numerals to Arabic digits
        ser = re.search('(标[段号的包项]?|项目|子项目?|([分子]?包|包[组件号]))编?号?[::]?(?P<eng>[0-9a-zA-Z-]{1,4})?(?P<num>([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4}))',package_name)
        # print('规范化包号4', ser.group(0))
        _char = ser.groupdict().get('eng')
        if _char:
            _char = _char.upper()
        _digit = ser.groupdict().get('num')
        _digit = uniform_num(_digit)
        if _char:
            name += _char.upper()
        name += _digit
    elif re.search('(标[段号的包项]|([分子]?包|包[组件号]))编?号?[::]?(?P<eng>[a-zA-Z-]{1,5})', package_name):  # letter-only identifier after a package marker
        _digit = re.search('(标[段号的包项]|([分子]?包|包[组件号]))编?号?[::]?(?P<eng>[a-zA-Z-]{1,5})', package_name).group('eng').upper()
        # print('规范化包号5', _digit)
        name += _digit
    elif re.search('(?P<eng>[a-zA-Z]{1,4})(标[段号的包项]|([分子]?[包标]|包[组件号]))', package_name):  # letters before a package marker
        _digit = re.search('(?P<eng>[a-zA-Z]{1,4})(标[段号的包项]|([分子]?[包标]|包[组件号]))', package_name).group('eng').upper()
        # print('规范化包号6', _digit)
        name += _digit
    elif re.search('^([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4})$', package_name):  # bare number
        _digit = re.search('^([0-9]{1,4}|[一二三四五六七八九十]{1,4}|[ⅠⅡⅢⅣⅤⅥⅦ]{1,4})$', package_name).group(0)
        # print('规范化包号7', _digit)
        _digit = uniform_num(_digit)
        name += _digit
    elif re.search('^[a-zA-Z0-9-]+$', package_name):
        _char = re.search('^[a-zA-Z0-9-]+$', package_name).group(0)
        # print('规范化包号8', _char)
        name += _char.upper()
    if name == "":
        return package_name_raw
    else:
        # drop leading zeros from purely numeric results
        if name.isdigit():
            name = str(int(name))
        # print('原始包号:%s, 处理后:%s'%(package_name, name))
        return name
  665. def getIndexOfWord_fool(word):
  666. if word in fool_char_to_id.keys():
  667. return fool_char_to_id[word]
  668. else:
  669. return fool_char_to_id["[UNK]"]
  670. def find_index(list_tofind,text):
  671. '''
  672. @summary: 查找所有词汇在字符串中第一次出现的位置
  673. @param:
  674. list_tofind:待查找词汇
  675. text:字符串
  676. @return: list,每个词汇第一次出现的位置
  677. '''
  678. result = []
  679. for item in list_tofind:
  680. index = text.find(item)
  681. if index>=0:
  682. result.append(index)
  683. else:
  684. result.append(-1)
  685. return result
  686. def combine(list1,list2):
  687. '''
  688. @summary:将两个list中的字符串两两拼接
  689. @param:
  690. list1:字符串list
  691. list2:字符串list
  692. @return:拼接结果list
  693. '''
  694. result = []
  695. for item1 in list1:
  696. for item2 in list2:
  697. result.append(str(item1)+str(item2))
  698. return result
  699. def getDigitsDic(unit):
  700. '''
  701. @summary:拿到中文对应的数字
  702. '''
  703. DigitsDic = {"零":0, "壹":1, "贰":2, "叁":3, "肆":4, "伍":5, "陆":6, "柒":7, "捌":8, "玖":9,
  704. "〇":0, "一":1, "二":2, "三":3, "四":4, "五":5, "六":6, "七":7, "八":8, "九":9}
  705. return DigitsDic.get(unit)
  706. def getMultipleFactor(unit):
  707. '''
  708. @summary:拿到单位对应的值
  709. '''
  710. MultipleFactor = {"兆":Decimal(1000000000000),"亿":Decimal(100000000),"万":Decimal(10000),"仟":Decimal(1000),"千":Decimal(1000),"佰":Decimal(100),"百":Decimal(100),"拾":Decimal(10),"十":Decimal(10),"元":Decimal(1),"圆":Decimal(1),"角":round(Decimal(0.1),1),"分":round(Decimal(0.01),2)}
  711. return MultipleFactor.get(unit)
def getUnifyMoney(money):
    '''
    @summary: Convert a Chinese monetary amount string (mixing Arabic digits,
        Chinese numerals and Chinese units) into a numeric amount.
    @param:
        money: the raw amount string
    @return: Decimal, the parsed amount (Decimal(0) on overflow or any parse error)
    '''
    # Upper bound on accepted amounts (one trillion); anything larger is
    # treated as a parse failure and yields Decimal(0).
    MAX_MONEY = 1000000000000
    MAX_NUM = 12  # NOTE(review): appears unused in this function
    # strip thousands separators (full-width and ASCII commas)
    money = re.sub("[,,]","",money)
    # drop every character that is not a digit, a dot, a Chinese numeral or a unit
    money = re.sub("[^0-9.零壹贰叁肆伍陆柒捌玖拾佰仟萬億圆十百千万亿元角分]","",money)
    result = Decimal(0)
    chnDigits = ["零", "壹", "贰", "叁", "肆", "伍", "陆", "柒", "捌", "玖"]
    # chnFactorUnits = ["兆", "亿", "万", "仟", "佰", "拾","圆","元","角","分"]
    # Units ordered largest-to-smallest. The simplified forms 千/百/十 were added
    # 20240611 to fix mis-parses such as '陆拾陆亿伍千柒佰零叁万肆千叁佰陆拾伍元'
    # previously yielding Decimal('11607430365').
    chnFactorUnits = ["兆", "亿", "万", "仟", '千', "佰", '百', "拾", '十',"圆", "元", "角", "分"]
    # a purely numeric amount, e.g. "1234.56"
    LowMoneypattern = re.compile("^[\d,]+(\.\d+)?$")
    # a single Chinese digit, optionally preceded by 零
    BigMoneypattern = re.compile("^零?(?P<BigMoney>[%s])$"%("".join(chnDigits)))
    try:
        if re.search(LowMoneypattern,money) is not None:
            return Decimal(money)
        elif re.search(BigMoneypattern,money) is not None:
            return getDigitsDic(re.search(BigMoneypattern,money).group("BigMoney"))
        # Walk the units from largest to smallest; split the string at the LAST
        # occurrence of the first unit found, handle both halves, then stop.
        for factorUnit in chnFactorUnits:
            if re.search(re.compile(".*%s.*"%(factorUnit)),money) is not None:
                # negative lookahead keeps only the last occurrence as the split point
                subMoneys = re.split(re.compile("%s(?!.*%s.*)"%(factorUnit,factorUnit)),money)
                if re.search(re.compile("^(\d+)(\.\d+)?$"),subMoneys[0]) is not None:
                    # left half is plain digits: reject amounts that would overflow MAX_MONEY
                    if MAX_MONEY/getMultipleFactor(factorUnit)<Decimal(subMoneys[0]):
                        return Decimal(0)
                    result += Decimal(subMoneys[0])*(getMultipleFactor(factorUnit))
                elif len(subMoneys[0])==1:
                    # left half is a single character: accept it only if it is a Chinese digit
                    if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[0]) is not None:
                        result += Decimal(getDigitsDic(subMoneys[0]))*(getMultipleFactor(factorUnit))
                elif subMoneys[0]=="":
                    # the amount started with the unit itself; nothing to add for the left half
                    result += 0
                elif re.search(re.compile("[%s]"%("".join(chnFactorUnits))),subMoneys[0]) is None:
                    # left half contains no further unit and cannot be split again
                    result += Decimal(getUnifyMoney(subMoneys[0])) * (getMultipleFactor(factorUnit))
                else:
                    # left half still contains smaller units: recurse
                    result += Decimal(getUnifyMoney(subMoneys[0]))*(getMultipleFactor(factorUnit))
                if len(subMoneys)>1:
                    # right half: plain number (optionally with a trailing unit/元),
                    # a single Chinese digit, or a smaller compound amount (recurse)
                    if re.search(re.compile("^(\d+(,)?)+(\.\d+)?[百千万亿]?\s?(元)?$"),subMoneys[1]) is not None:
                        result += Decimal(subMoneys[1])
                    elif len(subMoneys[1])==1:
                        if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[1]) is not None:
                            result += Decimal(getDigitsDic(subMoneys[1]))
                    else:
                        result += Decimal(getUnifyMoney(subMoneys[1]))
                break
    except Exception as e:
        # best-effort parser: any failure is reported as a zero amount
        # traceback.print_exc()
        return Decimal(0)
    return result
def getModel_w2v():
    '''
    @summary: Lazily load the shared word2vec word-vector model.
        The lock ensures concurrent callers initialise the global only once.
    @return: the gensim KeyedVectors model
    '''
    global model_w2v,lock_model_w2v
    with lock_model_w2v:
        if model_w2v is None:
            model_w2v = gensim.models.KeyedVectors.load_word2vec_format(getw2vfilepath(),binary=True)
    return model_w2v
  776. def getModel_word():
  777. '''
  778. @summary:加载字向量
  779. '''
  780. global model_word,lock_model_w2v
  781. with lock_model_word:
  782. if model_word is None:
  783. model_word = gensim.models.KeyedVectors.load_word2vec_format(model_word_file,binary=True)
  784. return model_word
  785. # getModel_w2v()
  786. # getModel_word()
  787. def formatArea(area):
  788. if area is not None and len(area)>=3:
  789. return re.sub("[省市区县]","",area)
  790. return area
  791. def findAllIndex(substr,wholestr):
  792. '''
  793. @summary: 找到字符串的子串的所有begin_index
  794. @param:
  795. substr:子字符串
  796. wholestr:子串所在完整字符串
  797. @return: list,字符串的子串的所有begin_index
  798. '''
  799. copystr = wholestr
  800. result = []
  801. indexappend = 0
  802. while(True):
  803. index = copystr.find(substr)
  804. if index<0:
  805. break
  806. else:
  807. result.append(indexappend+index)
  808. indexappend += index+len(substr)
  809. copystr = copystr[index+len(substr):]
  810. return result
  811. def spanWindow(tokens,begin_index,end_index,size,center_include=False,word_flag = False,use_text = False,text = None):
  812. '''
  813. @summary:取得某个实体的上下文词汇
  814. @param:
  815. tokens:句子分词list
  816. begin_index:实体的开始index
  817. end_index:实体的结束index
  818. size:左右两边各取多少个词
  819. center_include:是否包含实体
  820. word_flag:词/字,默认是词
  821. @return: list,实体的上下文词汇
  822. '''
  823. if use_text:
  824. assert text is not None
  825. length_tokens = len(tokens)
  826. if begin_index>size:
  827. begin = begin_index-size
  828. else:
  829. begin = 0
  830. if end_index+size<length_tokens:
  831. end = end_index+size+1
  832. else:
  833. end = length_tokens
  834. result = []
  835. if not word_flag:
  836. result.append(tokens[begin:begin_index])
  837. if center_include:
  838. if use_text:
  839. result.append(text)
  840. else:
  841. result.append(tokens[begin_index:end_index+1])
  842. result.append(tokens[end_index+1:end])
  843. else:
  844. result.append("".join(tokens[begin:begin_index]))
  845. if center_include:
  846. if use_text:
  847. result.append(text)
  848. else:
  849. result.append("".join(tokens[begin_index:end_index+1]))
  850. result.append("".join(tokens[end_index+1:end]))
  851. #print(result)
  852. return result
  853. #根据规则补全编号或名称两边的符号
  854. def fitDataByRule(data):
  855. symbol_dict = {"(":")",
  856. "(":")",
  857. "[":"]",
  858. "【":"】",
  859. ")":"(",
  860. ")":"(",
  861. "]":"[",
  862. "】":"【"}
  863. leftSymbol_pattern = re.compile("[\((\[【]")
  864. rightSymbol_pattern = re.compile("[\))\]】]")
  865. leftfinds = re.findall(leftSymbol_pattern,data)
  866. rightfinds = re.findall(rightSymbol_pattern,data)
  867. result = data
  868. if len(leftfinds)+len(rightfinds)==0:
  869. return data
  870. elif len(leftfinds)==len(rightfinds):
  871. return data
  872. elif abs(len(leftfinds)-len(rightfinds))==1:
  873. if len(leftfinds)>len(rightfinds):
  874. if symbol_dict.get(data[0]) is not None:
  875. result = data[1:]
  876. else:
  877. #print(symbol_dict.get(leftfinds[0]))
  878. result = data+symbol_dict.get(leftfinds[0])
  879. else:
  880. if symbol_dict.get(data[-1]) is not None:
  881. result = data[:-1]
  882. else:
  883. result = symbol_dict.get(rightfinds[0])+data
  884. result = re.sub("[。]","",result)
  885. return result
  886. def embedding(datas,shape):
  887. '''
  888. @summary:查找词汇对应的词向量
  889. @param:
  890. datas:词汇的list
  891. shape:结果的shape
  892. @return: array,返回对应shape的词嵌入
  893. '''
  894. model_w2v = getModel_w2v()
  895. embed = np.zeros(shape)
  896. length = shape[1]
  897. out_index = 0
  898. #print(datas)
  899. for data in datas:
  900. index = 0
  901. for item in data:
  902. item_not_space = re.sub("\s*","",item)
  903. if index>=length:
  904. break
  905. if item_not_space in model_w2v.vocab:
  906. embed[out_index][index] = model_w2v[item_not_space]
  907. index += 1
  908. else:
  909. #embed[out_index][index] = model_w2v['unk']
  910. index += 1
  911. out_index += 1
  912. return embed
  913. def embedding_word(datas,shape):
  914. '''
  915. @summary:查找词汇对应的词向量
  916. @param:
  917. datas:词汇的list
  918. shape:结果的shape
  919. @return: array,返回对应shape的词嵌入
  920. '''
  921. model_w2v = getModel_word()
  922. embed = np.zeros(shape)
  923. length = shape[1]
  924. out_index = 0
  925. #print(datas)
  926. for data in datas:
  927. index = 0
  928. for item in str(data)[-shape[1]:]:
  929. if index>=length:
  930. break
  931. if item in model_w2v.vocab:
  932. embed[out_index][index] = model_w2v[item]
  933. index += 1
  934. else:
  935. # embed[out_index][index] = model_w2v['unk']
  936. index += 1
  937. out_index += 1
  938. return embed
  939. def formEncoding(text,shape=(100,60),expand=False):
  940. embedding = np.zeros(shape)
  941. word_model = getModel_word()
  942. for i in range(len(text)):
  943. if i>=shape[0]:
  944. break
  945. if text[i] in word_model.vocab:
  946. embedding[i] = word_model[text[i]]
  947. if expand:
  948. embedding = np.expand_dims(embedding,0)
  949. return embedding
  950. def partMoney(entity_text,input2_shape = [7]):
  951. '''
  952. @summary:对金额分段
  953. @param:
  954. entity_text:数值金额
  955. input2_shape:分类数
  956. @return: array,分段之后的独热编码
  957. '''
  958. money = float(entity_text)
  959. parts = np.zeros(input2_shape)
  960. if money<100:
  961. parts[0] = 1
  962. elif money<1000:
  963. parts[1] = 1
  964. elif money<10000:
  965. parts[2] = 1
  966. elif money<100000:
  967. parts[3] = 1
  968. elif money<1000000:
  969. parts[4] = 1
  970. elif money<10000000:
  971. parts[5] = 1
  972. else:
  973. parts[6] = 1
  974. return parts
  975. def recall(y_true, y_pred):
  976. from keras import backend as K
  977. '''
  978. 计算召回率
  979. @Argus:
  980. y_true: 正确的标签
  981. y_pred: 模型预测的标签
  982. @Return
  983. 召回率
  984. '''
  985. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  986. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  987. if c3 == 0:
  988. return 0
  989. recall = c1 / c3
  990. return recall
  991. def f1_score(y_true, y_pred):
  992. from keras import backend as K
  993. '''
  994. 计算F1
  995. @Argus:
  996. y_true: 正确的标签
  997. y_pred: 模型预测的标签
  998. @Return
  999. F1值
  1000. '''
  1001. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  1002. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  1003. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  1004. precision = c1 / c2
  1005. if c3 == 0:
  1006. recall = 0
  1007. else:
  1008. recall = c1 / c3
  1009. f1_score = 2 * (precision * recall) / (precision + recall)
  1010. return f1_score
  1011. def precision(y_true, y_pred):
  1012. from keras import backend as K
  1013. '''
  1014. 计算精确率
  1015. @Argus:
  1016. y_true: 正确的标签
  1017. y_pred: 模型预测的标签
  1018. @Return
  1019. 精确率
  1020. '''
  1021. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  1022. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  1023. precision = c1 / c2
  1024. return precision
  1025. # def print_metrics(history):
  1026. # '''
  1027. # 制作每次迭代的各metrics变化图片
  1028. #
  1029. # @Arugs:
  1030. # history: 模型训练迭代的历史记录
  1031. # '''
  1032. # import matplotlib.pyplot as plt
  1033. #
  1034. # # loss图
  1035. # loss = history.history['loss']
  1036. # val_loss = history.history['val_loss']
  1037. # epochs = range(1, len(loss) + 1)
  1038. # plt.subplot(2, 2, 1)
  1039. # plt.plot(epochs, loss, 'bo', label='Training loss')
  1040. # plt.plot(epochs, val_loss, 'b', label='Validation loss')
  1041. # plt.title('Training and validation loss')
  1042. # plt.xlabel('Epochs')
  1043. # plt.ylabel('Loss')
  1044. # plt.legend()
  1045. #
  1046. # # f1图
  1047. # f1 = history.history['f1_score']
  1048. # val_f1 = history.history['val_f1_score']
  1049. # plt.subplot(2, 2, 2)
  1050. # plt.plot(epochs, f1, 'bo', label='Training f1')
  1051. # plt.plot(epochs, val_f1, 'b', label='Validation f1')
  1052. # plt.title('Training and validation f1')
  1053. # plt.xlabel('Epochs')
  1054. # plt.ylabel('F1')
  1055. # plt.legend()
  1056. #
  1057. # # precision图
  1058. # prec = history.history['precision']
  1059. # val_prec = history.history['val_precision']
  1060. # plt.subplot(2, 2, 3)
  1061. # plt.plot(epochs, prec, 'bo', label='Training precision')
  1062. # plt.plot(epochs, val_prec, 'b', label='Validation pecision')
  1063. # plt.title('Training and validation precision')
  1064. # plt.xlabel('Epochs')
  1065. # plt.ylabel('Precision')
  1066. # plt.legend()
  1067. #
  1068. # # recall图
  1069. # recall = history.history['recall']
  1070. # val_recall = history.history['val_recall']
  1071. # plt.subplot(2, 2, 4)
  1072. # plt.plot(epochs, recall, 'bo', label='Training recall')
  1073. # plt.plot(epochs, val_recall, 'b', label='Validation recall')
  1074. # plt.title('Training and validation recall')
  1075. # plt.xlabel('Epochs')
  1076. # plt.ylabel('Recall')
  1077. # plt.legend()
  1078. #
  1079. # plt.show()
if __name__=="__main__":
    # Smoke test: print the vocabulary id of ">" from the fool char vocab.
    print(fool_char_to_id[">"])
    # model = getModel_w2v()
    # vocab,matrix = getVocabAndMatrix(model, Embedding_size=128)
    # save([vocab,matrix],"vocabMatrix_words.pk")