Utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. '''
  2. Created on 2018年12月20日
  3. @author: User
  4. '''
  5. import numpy as np
  6. import re
  7. import gensim
  8. from keras import backend as K
  9. import os
  10. from threading import RLock
  11. from pai_tf_predict_proto import tf_predict_pb2
  12. import requests
  13. import time
  14. import smtplib
  15. from email.mime.application import MIMEApplication
  16. from email.mime.multipart import MIMEMultipart
  17. from email.utils import formataddr
  18. model_w2v = None
  19. lock_model_w2v = RLock()
  20. USE_PAI_EAS = False
  21. Lazy_load = False
  22. ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
  23. import traceback
  24. def sendEmail(host,username,password,receivers,attachs=[]):
  25. try:
  26. #处理附件
  27. msg = MIMEMultipart()
  28. msg["From"] = formataddr(["广州比地数据科技有限公司",username])
  29. msg["To"] = formataddr(["客户",receivers[0]])
  30. msg["Subject"] = "数据导出服务"
  31. for at in attachs:
  32. xlsfile = MIMEApplication(open(at,"rb").read())
  33. xlsfile.add_header("Content-Disposition","attachment",filename=('gbk', '', at.split("/")[-1]))
  34. log(at.split("/")[-1])
  35. msg.attach(xlsfile)
  36. server = smtplib.SMTP()
  37. server.connect(host,25)
  38. server.login(username,password)
  39. server.sendmail(username,receivers,msg.as_string())
  40. log("发送邮件成功%s"%str(attachs))
  41. server.close()
  42. except Exception as e:
  43. traceback.print_exc()
  44. log("发送邮件错误%s"%str(e))
  45. def getLegal_str(_str):
  46. if _str is not None:
  47. return ILLEGAL_CHARACTERS_RE.sub("",str(_str))
  48. def getRow_ots_primary(row):
  49. _dict = dict()
  50. if row is None:
  51. return _dict
  52. for part in row.attribute_columns:
  53. _dict[part[0]] = part[1]
  54. for part in row.primary_key:
  55. _dict[part[0]] = part[1]
  56. return _dict
  57. def getRow_ots(rows):
  58. list_dict = []
  59. for row in rows:
  60. _dict = dict()
  61. for part in row:
  62. for v in part:
  63. _dict[v[0]] = v[1]
  64. list_dict.append(_dict)
  65. return list_dict
  66. def getw2vfilepath():
  67. w2vfile = os.path.dirname(__file__)+"/../wiki_128_word_embedding_new.vector"
  68. if os.path.exists(w2vfile):
  69. return w2vfile
  70. return "wiki_128_word_embedding_new.vector"
  71. def getLazyLoad():
  72. global Lazy_load
  73. return Lazy_load
  74. def get_file_name(url, headers):
  75. filename = ''
  76. if 'Content-Disposition' in headers and headers['Content-Disposition']:
  77. disposition_split = headers['Content-Disposition'].split(';')
  78. if len(disposition_split) > 1:
  79. if disposition_split[1].strip().lower().startswith('filename='):
  80. file_name = disposition_split[1].split('=')
  81. if len(file_name) > 1:
  82. filename = file_name[1]
  83. if not filename and os.path.basename(url):
  84. filename = os.path.basename(url).split("?")[0]
  85. if not filename:
  86. return time.time()
  87. return filename
  88. model_word_file = os.path.dirname(__file__)+"/../singlew2v_model.vector"
  89. model_word = None
  90. lock_model_word = RLock()
  91. from decimal import Decimal
  92. import logging
  93. logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  94. logger = logging.getLogger(__name__)
  95. import pickle
  96. import os
  97. import json
  98. #自定义jsonEncoder
  99. class MyEncoder(json.JSONEncoder):
  100. def __init__(self):
  101. import numpy as np
  102. global np
  103. def default(self, obj):
  104. if isinstance(obj, np.ndarray):
  105. return obj.tolist()
  106. elif isinstance(obj, bytes):
  107. return str(obj, encoding='utf-8')
  108. elif isinstance(obj, (np.float_, np.float16, np.float32,
  109. np.float64)):
  110. return float(obj)
  111. elif isinstance(obj,(np.int64,np.int32)):
  112. return int(obj)
  113. return json.JSONEncoder.default(self, obj)
  114. vocab_word = None
  115. vocab_words = None
  116. file_vocab_word = "vocab_word.pk"
  117. file_vocab_words = "vocab_words.pk"
  118. selffool_authorization = "NjlhMWFjMjVmNWYyNzI0MjY1OGQ1M2Y0ZmY4ZGY0Mzg3Yjc2MTVjYg=="
  119. selffool_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_gpu"
  120. selffool_seg_authorization = "OWUwM2Q0ZmE3YjYxNzU4YzFiMjliNGVkMTA3MzJkNjQ2MzJiYzBhZg=="
  121. selffool_seg_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/selffool_seg_gpu"
  122. codename_authorization = "Y2M5MDUxMzU1MTU4OGM3ZDk2ZmEzYjkxYmYyYzJiZmUyYTgwYTg5NA=="
  123. codename_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codename_gpu"
  124. form_item_authorization = "ODdkZWY1YWY0NmNhNjU2OTI2NWY4YmUyM2ZlMDg1NTZjOWRkYTVjMw=="
  125. form_item_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/form"
  126. person_authorization = "N2I2MDU2N2Q2MGQ0ZWZlZGM3NDkyNTA1Nzc4YmM5OTlhY2MxZGU1Mw=="
  127. person_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/person"
  128. role_authorization = "OWM1ZDg5ZDEwYTEwYWI4OGNjYmRlMmQ1NzYwNWNlZGZkZmRmMjE4OQ=="
  129. role_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/role"
  130. money_authorization = "MDQyNjc2ZDczYjBhYmM4Yzc4ZGI4YjRmMjc3NGI5NTdlNzJiY2IwZA=="
  131. money_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/money"
  132. codeclasses_authorization = "MmUyNWIxZjQ2NjAzMWJlMGIzYzkxMjMzNWY5OWI3NzJlMWQ1ZjY4Yw=="
  133. codeclasses_url = "http://pai-eas-vpc.cn-beijing.aliyuncs.com/api/predict/codeclasses"
  134. def viterbi_decode(score, transition_params):
  135. """Decode the highest scoring sequence of tags outside of TensorFlow.
  136. This should only be used at test time.
  137. Args:
  138. score: A [seq_len, num_tags] matrix of unary potentials.
  139. transition_params: A [num_tags, num_tags] matrix of binary potentials.
  140. Returns:
  141. viterbi: A [seq_len] list of integers containing the highest scoring tag
  142. indices.
  143. viterbi_score: A float containing the score for the Viterbi sequence.
  144. """
  145. trellis = np.zeros_like(score)
  146. backpointers = np.zeros_like(score, dtype=np.int32)
  147. trellis[0] = score[0]
  148. for t in range(1, score.shape[0]):
  149. v = np.expand_dims(trellis[t - 1], 1) + transition_params
  150. trellis[t] = score[t] + np.max(v, 0)
  151. backpointers[t] = np.argmax(v, 0)
  152. viterbi = [np.argmax(trellis[-1])]
  153. for bp in reversed(backpointers[1:]):
  154. viterbi.append(bp[viterbi[-1]])
  155. viterbi.reverse()
  156. viterbi_score = np.max(trellis[-1])
  157. return viterbi, viterbi_score
  158. import ctypes
  159. import inspect
  160. def _async_raise(tid, exctype):
  161. """raises the exception, performs cleanup if needed"""
  162. tid = ctypes.c_long(tid)
  163. if not inspect.isclass(exctype):
  164. exctype = type(exctype)
  165. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  166. if res == 0:
  167. raise ValueError("invalid thread id")
  168. elif res != 1:
  169. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  170. raise SystemError("PyThreadState_SetAsyncExc failed")
  171. def stop_thread(thread):
  172. _async_raise(thread.ident, SystemExit)
  173. def limitRun(sess,list_output,feed_dict,MAX_BATCH=1024):
  174. len_sample = 0
  175. if len(feed_dict.keys())>0:
  176. len_sample = len(feed_dict[list(feed_dict.keys())[0]])
  177. if len_sample>MAX_BATCH:
  178. list_result = [[] for _ in range(len(list_output))]
  179. _begin = 0
  180. while(_begin<len_sample):
  181. new_dict = dict()
  182. for _key in feed_dict.keys():
  183. new_dict[_key] = feed_dict[_key][_begin:_begin+MAX_BATCH]
  184. _output = sess.run(list_output,feed_dict=new_dict)
  185. for _index in range(len(list_output)):
  186. list_result[_index].extend(_output[_index])
  187. _begin += MAX_BATCH
  188. else:
  189. list_result = sess.run(list_output,feed_dict=feed_dict)
  190. return list_result
  191. def get_values(response,output_name):
  192. """
  193. Get the value of a specified output tensor
  194. :param output_name: name of the output tensor
  195. :return: the content of the output tensor
  196. """
  197. output = response.outputs[output_name]
  198. if output.dtype == tf_predict_pb2.DT_FLOAT:
  199. _value = output.float_val
  200. elif output.dtype == tf_predict_pb2.DT_INT8 or output.dtype == tf_predict_pb2.DT_INT16 or \
  201. output.dtype == tf_predict_pb2.DT_INT32:
  202. _value = output.int_val
  203. elif output.dtype == tf_predict_pb2.DT_INT64:
  204. _value = output.int64_val
  205. elif output.dtype == tf_predict_pb2.DT_DOUBLE:
  206. _value = output.double_val
  207. elif output.dtype == tf_predict_pb2.DT_STRING:
  208. _value = output.string_val
  209. elif output.dtype == tf_predict_pb2.DT_BOOL:
  210. _value = output.bool_val
  211. return np.array(_value).reshape(response.outputs[output_name].array_shape.dim)
  212. def vpc_requests(url,authorization,request_data,list_outputs):
  213. headers = {"Authorization": authorization}
  214. dict_outputs = dict()
  215. response = tf_predict_pb2.PredictResponse()
  216. resp = requests.post(url, data=request_data, headers=headers)
  217. if resp.status_code != 200:
  218. print(resp.status_code,resp.content)
  219. log("调用pai-eas接口出错,authorization:"+str(authorization))
  220. return None
  221. else:
  222. response = tf_predict_pb2.PredictResponse()
  223. response.ParseFromString(resp.content)
  224. for _output in list_outputs:
  225. dict_outputs[_output] = get_values(response, _output)
  226. return dict_outputs
  227. def encodeInput(data,word_len,word_flag=True,userFool=False):
  228. result = []
  229. out_index = 0
  230. for item in data:
  231. if out_index in [0]:
  232. list_word = item[-word_len:]
  233. else:
  234. list_word = item[:word_len]
  235. temp = []
  236. if word_flag:
  237. for word in list_word:
  238. if userFool:
  239. temp.append(getIndexOfWord_fool(word))
  240. else:
  241. temp.append(getIndexOfWord(word))
  242. list_append = []
  243. temp_len = len(temp)
  244. while(temp_len<word_len):
  245. if userFool:
  246. list_append.append(0)
  247. else:
  248. list_append.append(getIndexOfWord("<pad>"))
  249. temp_len += 1
  250. if out_index in [0]:
  251. temp = list_append+temp
  252. else:
  253. temp = temp+list_append
  254. else:
  255. for words in list_word:
  256. temp.append(getIndexOfWords(words))
  257. list_append = []
  258. temp_len = len(temp)
  259. while(temp_len<word_len):
  260. list_append.append(getIndexOfWords("<pad>"))
  261. temp_len += 1
  262. if out_index in [0,1]:
  263. temp = list_append+temp
  264. else:
  265. temp = temp+list_append
  266. result.append(temp)
  267. out_index += 1
  268. return result
  269. def encodeInput_form(input,MAX_LEN=30):
  270. x = np.zeros([MAX_LEN])
  271. for i in range(len(input)):
  272. if i>=MAX_LEN:
  273. break
  274. x[i] = getIndexOfWord(input[i])
  275. return x
  276. def getVocabAndMatrix(model,Embedding_size = 60):
  277. '''
  278. @summary:获取子向量的词典和子向量矩阵
  279. '''
  280. vocab = ["<pad>"]+model.index2word
  281. embedding_matrix = np.zeros((len(vocab),Embedding_size))
  282. for i in range(1,len(vocab)):
  283. embedding_matrix[i] = model[vocab[i]]
  284. return vocab,embedding_matrix
  285. def getIndexOfWord(word):
  286. global vocab_word,file_vocab_word
  287. if vocab_word is None:
  288. if os.path.exists(file_vocab_word):
  289. vocab = load(file_vocab_word)
  290. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  291. else:
  292. model = getModel_word()
  293. vocab,_ = getVocabAndMatrix(model, Embedding_size=60)
  294. vocab_word = dict((w, i) for i, w in enumerate(np.array(vocab)))
  295. save(vocab,file_vocab_word)
  296. if word in vocab_word.keys():
  297. return vocab_word[word]
  298. else:
  299. return vocab_word['<pad>']
  300. def getIndexOfWords(words):
  301. global vocab_words,file_vocab_words
  302. if vocab_words is None:
  303. if os.path.exists(file_vocab_words):
  304. vocab = load(file_vocab_words)
  305. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  306. else:
  307. model = getModel_w2v()
  308. vocab,_ = getVocabAndMatrix(model, Embedding_size=128)
  309. vocab_words = dict((w, i) for i, w in enumerate(np.array(vocab)))
  310. save(vocab,file_vocab_words)
  311. if words in vocab_words.keys():
  312. return vocab_words[words]
  313. else:
  314. return vocab_words["<pad>"]
  315. def log_tofile(filename):
  316. logging.basicConfig(filename=filename,level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  317. logger = logging.getLogger(__name__)
  318. def log(msg):
  319. '''
  320. @summary:打印信息
  321. '''
  322. logger.info(msg)
  323. def debug(msg):
  324. '''
  325. @summary:打印信息
  326. '''
  327. logger.debug(msg)
  328. def save(object_to_save, path):
  329. '''
  330. 保存对象
  331. @Arugs:
  332. object_to_save: 需要保存的对象
  333. @Return:
  334. 保存的路径
  335. '''
  336. with open(path, 'wb') as f:
  337. pickle.dump(object_to_save, f)
  338. def load(path):
  339. '''
  340. 读取对象
  341. @Arugs:
  342. path: 读取的路径
  343. @Return:
  344. 读取的对象
  345. '''
  346. with open(path, 'rb') as f:
  347. object1 = pickle.load(f)
  348. return object1
  349. def getIndexOfWord_fool(word):
  350. if word in fool_char_to_id.keys():
  351. return fool_char_to_id[word]
  352. else:
  353. return fool_char_to_id["[UNK]"]
  354. def find_index(list_tofind,text):
  355. '''
  356. @summary: 查找所有词汇在字符串中第一次出现的位置
  357. @param:
  358. list_tofind:待查找词汇
  359. text:字符串
  360. @return: list,每个词汇第一次出现的位置
  361. '''
  362. result = []
  363. for item in list_tofind:
  364. index = text.find(item)
  365. if index>=0:
  366. result.append(index)
  367. else:
  368. result.append(-1)
  369. return result
  370. def combine(list1,list2):
  371. '''
  372. @summary:将两个list中的字符串两两拼接
  373. @param:
  374. list1:字符串list
  375. list2:字符串list
  376. @return:拼接结果list
  377. '''
  378. result = []
  379. for item1 in list1:
  380. for item2 in list2:
  381. result.append(str(item1)+str(item2))
  382. return result
  383. def getDigitsDic(unit):
  384. '''
  385. @summary:拿到中文对应的数字
  386. '''
  387. DigitsDic = {"零":0, "壹":1, "贰":2, "叁":3, "肆":4, "伍":5, "陆":6, "柒":7, "捌":8, "玖":9,
  388. "〇":0, "一":1, "二":2, "三":3, "四":4, "五":5, "六":6, "七":7, "八":8, "九":9}
  389. return DigitsDic.get(unit)
  390. def getMultipleFactor(unit):
  391. '''
  392. @summary:拿到单位对应的值
  393. '''
  394. MultipleFactor = {"兆":Decimal(1000000000000),"亿":Decimal(100000000),"万":Decimal(10000),"仟":Decimal(1000),"千":Decimal(1000),"佰":Decimal(100),"百":Decimal(100),"拾":Decimal(10),"十":Decimal(10),"元":Decimal(1),"角":round(Decimal(0.1),1),"分":round(Decimal(0.01),2)}
  395. return MultipleFactor.get(unit)
  396. def getUnifyMoney(money):
  397. '''
  398. @summary:将中文金额字符串转换为数字金额
  399. @param:
  400. money:中文金额字符串
  401. @return: decimal,数据金额
  402. '''
  403. MAX_NUM = 12
  404. #去掉逗号
  405. money = re.sub("[,,]","",money)
  406. money = re.sub("[^0-9.零壹贰叁肆伍陆柒捌玖拾佰仟萬億〇一二三四五六七八九十百千万亿元角分]","",money)
  407. result = Decimal(0)
  408. chnDigits = ["零", "壹", "贰", "叁", "肆", "伍", "陆", "柒", "捌", "玖"]
  409. chnFactorUnits = ["兆", "亿", "万", "仟", "佰", "拾","元","角","分"]
  410. LowMoneypattern = re.compile("^[\d,]+(\.\d+)?$")
  411. BigMoneypattern = re.compile("^零?(?P<BigMoney>[%s])$"%("".join(chnDigits)))
  412. if re.search(LowMoneypattern,money) is not None:
  413. return Decimal(money)
  414. elif re.search(BigMoneypattern,money) is not None:
  415. return getDigitsDic(re.search(BigMoneypattern,money).group("BigMoney"))
  416. for factorUnit in chnFactorUnits:
  417. if re.search(re.compile(".*%s.*"%(factorUnit)),money) is not None:
  418. subMoneys = re.split(re.compile("%s(?!.*%s.*)"%(factorUnit,factorUnit)),money)
  419. if re.search(re.compile("^(\d+(,)?)+(\.\d+)?$"),subMoneys[0]) is not None:
  420. result += Decimal(subMoneys[0])*(getMultipleFactor(factorUnit))
  421. elif len(subMoneys[0])==1:
  422. if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[0]) is not None:
  423. result += Decimal(getDigitsDic(subMoneys[0]))*(getMultipleFactor(factorUnit))
  424. else:
  425. result += Decimal(getUnifyMoney(subMoneys[0]))*(getMultipleFactor(factorUnit))
  426. if len(subMoneys)>1:
  427. if re.search(re.compile("^(\d+(,)?)+(\.\d+)?[百千万亿]?\s?(元)?$"),subMoneys[1]) is not None:
  428. result += Decimal(subMoneys[1])
  429. elif len(subMoneys[1])==1:
  430. if re.search(re.compile("^[%s]$"%("".join(chnDigits))),subMoneys[1]) is not None:
  431. result += Decimal(getDigitsDic(subMoneys[1]))
  432. else:
  433. result += Decimal(getUnifyMoney(subMoneys[1]))
  434. break
  435. return result
  436. def getModel_w2v():
  437. '''
  438. @summary:加载词向量
  439. '''
  440. global model_w2v,lock_model_w2v
  441. with lock_model_w2v:
  442. if model_w2v is None:
  443. model_w2v = gensim.models.KeyedVectors.load_word2vec_format(getw2vfilepath(),binary=True)
  444. return model_w2v
  445. def getModel_word():
  446. '''
  447. @summary:加载字向量
  448. '''
  449. global model_word,lock_model_w2v
  450. with lock_model_word:
  451. if model_word is None:
  452. model_word = gensim.models.KeyedVectors.load_word2vec_format(model_word_file,binary=True)
  453. return model_word
  454. # getModel_w2v()
  455. # getModel_word()
  456. def findAllIndex(substr,wholestr):
  457. '''
  458. @summary: 找到字符串的子串的所有begin_index
  459. @param:
  460. substr:子字符串
  461. wholestr:子串所在完整字符串
  462. @return: list,字符串的子串的所有begin_index
  463. '''
  464. copystr = wholestr
  465. result = []
  466. indexappend = 0
  467. while(True):
  468. index = copystr.find(substr)
  469. if index<0:
  470. break
  471. else:
  472. result.append(indexappend+index)
  473. indexappend += index+len(substr)
  474. copystr = copystr[index+len(substr):]
  475. return result
  476. def spanWindow(tokens,begin_index,end_index,size,center_include=False,word_flag = False,use_text = False,text = None):
  477. '''
  478. @summary:取得某个实体的上下文词汇
  479. @param:
  480. tokens:句子分词list
  481. begin_index:实体的开始index
  482. end_index:实体的结束index
  483. size:左右两边各取多少个词
  484. center_include:是否包含实体
  485. word_flag:词/字,默认是词
  486. @return: list,实体的上下文词汇
  487. '''
  488. if use_text:
  489. assert text is not None
  490. length_tokens = len(tokens)
  491. if begin_index>size:
  492. begin = begin_index-size
  493. else:
  494. begin = 0
  495. if end_index+size<length_tokens:
  496. end = end_index+size+1
  497. else:
  498. end = length_tokens
  499. result = []
  500. if not word_flag:
  501. result.append(tokens[begin:begin_index])
  502. if center_include:
  503. if use_text:
  504. result.append(text)
  505. else:
  506. result.append(tokens[begin_index:end_index+1])
  507. result.append(tokens[end_index+1:end])
  508. else:
  509. result.append("".join(tokens[begin:begin_index]))
  510. if center_include:
  511. if use_text:
  512. result.append(text)
  513. else:
  514. result.append("".join(tokens[begin_index:end_index+1]))
  515. result.append("".join(tokens[end_index+1:end]))
  516. #print(result)
  517. return result
  518. #根据规则补全编号或名称两边的符号
  519. def fitDataByRule(data):
  520. symbol_dict = {"(":")",
  521. "(":")",
  522. "[":"]",
  523. "【":"】",
  524. ")":"(",
  525. ")":"(",
  526. "]":"[",
  527. "】":"【"}
  528. leftSymbol_pattern = re.compile("[\((\[【]")
  529. rightSymbol_pattern = re.compile("[\))\]】]")
  530. leftfinds = re.findall(leftSymbol_pattern,data)
  531. rightfinds = re.findall(rightSymbol_pattern,data)
  532. result = data
  533. if len(leftfinds)+len(rightfinds)==0:
  534. return data
  535. elif len(leftfinds)==len(rightfinds):
  536. return data
  537. elif abs(len(leftfinds)-len(rightfinds))==1:
  538. if len(leftfinds)>len(rightfinds):
  539. if symbol_dict.get(data[0]) is not None:
  540. result = data[1:]
  541. else:
  542. #print(symbol_dict.get(leftfinds[0]))
  543. result = data+symbol_dict.get(leftfinds[0])
  544. else:
  545. if symbol_dict.get(data[-1]) is not None:
  546. result = data[:-1]
  547. else:
  548. result = symbol_dict.get(rightfinds[0])+data
  549. result = re.sub("[。]","",result)
  550. return result
  551. def embedding(datas,shape):
  552. '''
  553. @summary:查找词汇对应的词向量
  554. @param:
  555. datas:词汇的list
  556. shape:结果的shape
  557. @return: array,返回对应shape的词嵌入
  558. '''
  559. model_w2v = getModel_w2v()
  560. embed = np.zeros(shape)
  561. length = shape[1]
  562. out_index = 0
  563. #print(datas)
  564. for data in datas:
  565. index = 0
  566. for item in data:
  567. item_not_space = re.sub("\s*","",item)
  568. if index>=length:
  569. break
  570. if item_not_space in model_w2v.vocab:
  571. embed[out_index][index] = model_w2v[item_not_space]
  572. index += 1
  573. else:
  574. #embed[out_index][index] = model_w2v['unk']
  575. index += 1
  576. out_index += 1
  577. return embed
  578. def embedding_word(datas,shape):
  579. '''
  580. @summary:查找词汇对应的词向量
  581. @param:
  582. datas:词汇的list
  583. shape:结果的shape
  584. @return: array,返回对应shape的词嵌入
  585. '''
  586. model_w2v = getModel_word()
  587. embed = np.zeros(shape)
  588. length = shape[1]
  589. out_index = 0
  590. #print(datas)
  591. for data in datas:
  592. index = 0
  593. for item in str(data)[-shape[1]:]:
  594. if index>=length:
  595. break
  596. if item in model_w2v.vocab:
  597. embed[out_index][index] = model_w2v[item]
  598. index += 1
  599. else:
  600. # embed[out_index][index] = model_w2v['unk']
  601. index += 1
  602. out_index += 1
  603. return embed
  604. def formEncoding(text,shape=(100,60),expand=False):
  605. embedding = np.zeros(shape)
  606. word_model = getModel_word()
  607. for i in range(len(text)):
  608. if i>=shape[0]:
  609. break
  610. if text[i] in word_model.vocab:
  611. embedding[i] = word_model[text[i]]
  612. if expand:
  613. embedding = np.expand_dims(embedding,0)
  614. return embedding
  615. def partMoney(entity_text,input2_shape = [7]):
  616. '''
  617. @summary:对金额分段
  618. @param:
  619. entity_text:数值金额
  620. input2_shape:分类数
  621. @return: array,分段之后的独热编码
  622. '''
  623. money = float(entity_text)
  624. parts = np.zeros(input2_shape)
  625. if money<100:
  626. parts[0] = 1
  627. elif money<1000:
  628. parts[1] = 1
  629. elif money<10000:
  630. parts[2] = 1
  631. elif money<100000:
  632. parts[3] = 1
  633. elif money<1000000:
  634. parts[4] = 1
  635. elif money<10000000:
  636. parts[5] = 1
  637. else:
  638. parts[6] = 1
  639. return parts
  640. def recall(y_true, y_pred):
  641. '''
  642. 计算召回率
  643. @Argus:
  644. y_true: 正确的标签
  645. y_pred: 模型预测的标签
  646. @Return
  647. 召回率
  648. '''
  649. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  650. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  651. if c3 == 0:
  652. return 0
  653. recall = c1 / c3
  654. return recall
  655. def f1_score(y_true, y_pred):
  656. '''
  657. 计算F1
  658. @Argus:
  659. y_true: 正确的标签
  660. y_pred: 模型预测的标签
  661. @Return
  662. F1值
  663. '''
  664. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  665. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  666. c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
  667. precision = c1 / c2
  668. if c3 == 0:
  669. recall = 0
  670. else:
  671. recall = c1 / c3
  672. f1_score = 2 * (precision * recall) / (precision + recall)
  673. return f1_score
  674. def precision(y_true, y_pred):
  675. '''
  676. 计算精确率
  677. @Argus:
  678. y_true: 正确的标签
  679. y_pred: 模型预测的标签
  680. @Return
  681. 精确率
  682. '''
  683. c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
  684. c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
  685. precision = c1 / c2
  686. return precision
  687. # def print_metrics(history):
  688. # '''
  689. # 制作每次迭代的各metrics变化图片
  690. #
  691. # @Arugs:
  692. # history: 模型训练迭代的历史记录
  693. # '''
  694. # import matplotlib.pyplot as plt
  695. #
  696. # # loss图
  697. # loss = history.history['loss']
  698. # val_loss = history.history['val_loss']
  699. # epochs = range(1, len(loss) + 1)
  700. # plt.subplot(2, 2, 1)
  701. # plt.plot(epochs, loss, 'bo', label='Training loss')
  702. # plt.plot(epochs, val_loss, 'b', label='Validation loss')
  703. # plt.title('Training and validation loss')
  704. # plt.xlabel('Epochs')
  705. # plt.ylabel('Loss')
  706. # plt.legend()
  707. #
  708. # # f1图
  709. # f1 = history.history['f1_score']
  710. # val_f1 = history.history['val_f1_score']
  711. # plt.subplot(2, 2, 2)
  712. # plt.plot(epochs, f1, 'bo', label='Training f1')
  713. # plt.plot(epochs, val_f1, 'b', label='Validation f1')
  714. # plt.title('Training and validation f1')
  715. # plt.xlabel('Epochs')
  716. # plt.ylabel('F1')
  717. # plt.legend()
  718. #
  719. # # precision图
  720. # prec = history.history['precision']
  721. # val_prec = history.history['val_precision']
  722. # plt.subplot(2, 2, 3)
  723. # plt.plot(epochs, prec, 'bo', label='Training precision')
  724. # plt.plot(epochs, val_prec, 'b', label='Validation pecision')
  725. # plt.title('Training and validation precision')
  726. # plt.xlabel('Epochs')
  727. # plt.ylabel('Precision')
  728. # plt.legend()
  729. #
  730. # # recall图
  731. # recall = history.history['recall']
  732. # val_recall = history.history['val_recall']
  733. # plt.subplot(2, 2, 4)
  734. # plt.plot(epochs, recall, 'bo', label='Training recall')
  735. # plt.plot(epochs, val_recall, 'b', label='Validation recall')
  736. # plt.title('Training and validation recall')
  737. # plt.xlabel('Epochs')
  738. # plt.ylabel('Recall')
  739. # plt.legend()
  740. #
  741. # plt.show()
  742. import pandas as pd
  743. dict_name_locations = {}
  744. dict_id_location = {}
  745. def getLocationDict():
  746. global dict_name_locations,dict_id_location
  747. df = pd.read_excel(os.path.dirname(__file__)+"/省份信息.xlsx")
  748. for _id,_cname,_parentid,_ctype in zip(df["id"],df["cname"],df["parentid"],df["ctype"]):
  749. _dict = {"id":_id,"cname":_cname,"parentid":_parentid,"ctype":_ctype}
  750. dict_id_location[_id] =_dict
  751. if _cname not in dict_name_locations:
  752. dict_name_locations[_cname] = []
  753. dict_name_locations[_cname].append(_dict)
  754. getLocationDict()
  755. def getProvinceCityDistrict(loc):
  756. list_loc = dict_name_locations.get(loc,[])
  757. list_result = []
  758. for _loc in list_loc:
  759. dict_loc_parents = {}
  760. _current_loc = _loc
  761. while(True):
  762. if _current_loc is None:
  763. break
  764. if _current_loc.get("ctype")>=20:
  765. dict_loc_parents[_current_loc.get("ctype")] = _current_loc
  766. _current_loc = dict_id_location.get(_current_loc.get("parentid"))
  767. if len(dict_loc_parents.keys())>0:
  768. list_result.append(dict_loc_parents)
  769. return list_result
  770. def chooseLocation(list_result):
  771. province = ""
  772. city = ""
  773. district = ""
  774. dict_province = {}
  775. for _dict in list_result:
  776. province = _dict.get(20,{}).get("cname","")
  777. if province!="":
  778. if province not in dict_province:
  779. dict_province[province] = 0
  780. dict_province[province] += 1
  781. max_province = ""
  782. max_province_count = 0
  783. for k,v in dict_province.items():
  784. if v>max_province_count:
  785. max_province = k
  786. max_province_count = v
  787. if len(list_result)>0:
  788. list_result.sort(key=lambda x:len(list(x.keys())),reverse=True)
  789. for _dict in list_result:
  790. _province = _dict.get(20,{}).get("cname","")
  791. if _province!=max_province:
  792. continue
  793. province = _province
  794. city = _dict.get(30,{}).get("cname","")
  795. district = _dict.get(40,{}).get("cname","")
  796. break
  797. return province,city,district
  798. def getCurrent_date(format="%Y-%m-%d %H:%M:%S"):
  799. _time = time.strftime(format,time.localtime())
  800. return _time
  801. def getLocation(_str):
  802. list_names = list(dict_name_locations.keys())
  803. name_pattern = "(?P<locations>%s)"%"|".join(list_names)
  804. name_pattern_subcompany = "(%s)分"%name_pattern
  805. list_result = []
  806. for _iter in re.finditer(name_pattern_subcompany,_str):
  807. _loc = _iter.groupdict().get("locations")
  808. list_result.extend(getProvinceCityDistrict(_loc))
  809. if len(list_result)>0:
  810. return chooseLocation(list_result)
  811. for _iter in re.finditer(name_pattern,_str):
  812. _loc = _iter.groupdict().get("locations")
  813. list_result.extend(getProvinceCityDistrict(_loc))
  814. return chooseLocation(list_result)
  815. if __name__=="__main__":
  816. print(getLocation("佛山市顺德区顺控路桥投资有限公司"))
  817. # print(fool_char_to_id[">"])
  818. # model = getModel_w2v()
  819. # vocab,matrix = getVocabAndMatrix(model, Embedding_size=128)
  820. # save([vocab,matrix],"vocabMatrix_words.pk")
  821. pass