compare.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. '''
  2. Created on 2019年6月13日
  3. @author: User
  4. '''
  5. import fool
  6. from bi_lstm_crf import *
  7. import pandas as pd
  8. import codecs
  9. import re
  10. ''''''
  11. def compare(text):
  12. print(fool.ner(text))
  13. '''
  14. bilstm.initVariables()
  15. '''
  16. # init_op = tf.global_variables_initializer()
  17. # sess.run(init_op)
  18. # summaryWriter = tf.summary.FileWriter('log/', tf.get_default_graph())
  19. print(bilstm.ner(text))
  20. _ner_fool = fool.ner(text)
  21. _ner_selffool = bilstm.ner(text)
  22. if len(set(_ner_fool[0]) & set(_ner_selffool[0])) == len(_ner_fool[0]):
  23. print(set(fool.ner(text)[0]) & set(bilstm.ner(text)[0]))
  24. def dealNotFoundEntity():
  25. '''
  26. @summary: 处理未识别数据
  27. '''
  28. df = pd.read_excel("C:\\Users\\User\\Desktop\\无法分离实体名称.xlsx")
  29. list_newname_fool = []
  30. list_newname_selffool = []
  31. count = 0
  32. for _name in df["name"]:
  33. count += 1
  34. print(_name)
  35. if str(_name) == "nan":
  36. list_newname_fool.append("")
  37. list_newname_selffool.append("")
  38. continue
  39. print(count, len(df["name"]))
  40. _newname_fool = ""
  41. _newname_selffool = ""
  42. for _ner in fool.ner(_name)[0]:
  43. _newname_fool += _ner[3] + "##"
  44. for _ner in bilstm.ner(_name)[0]:
  45. _newname_selffool += _ner[3] + "##"
  46. list_newname_fool.append(_newname_fool[:-2])
  47. list_newname_selffool.append(_newname_selffool[:-2])
  48. data = {"id": df["id"],
  49. "area": df["area"],
  50. "province": df["province"],
  51. "city": df["city"],
  52. "district": df["district"],
  53. "name": df["name"],
  54. "newname_fool": list_newname_fool,
  55. "newname_selffool": list_newname_selffool}
  56. _df = pd.DataFrame(data, columns=["id", "area", "province", "city", "district", "name", "newname_fool",
  57. "newname_selffool"])
  58. _df.to_excel("C:\\Users\\User\\Desktop\\无法分离实体名称_deal.xls")
  59. def nerEntity():
  60. file = "C:\\Users\\User\\Desktop\\select_company_name_from_bxkc_C_CONTACT_.tsv"
  61. file_found = "C:\\Users\\User\\Desktop\\company_found.tsv"
  62. file_notfound = "C:\\Users\\User\\Desktop\\company_notfound.tsv"
  63. with codecs.open(file, "r", encoding="utf8") as f:
  64. with codecs.open(file_found, "w", encoding="utf8") as f_found:
  65. with codecs.open(file_notfound, "w", encoding="utf8") as f_notfound:
  66. while (True):
  67. line = f.readline().strip()
  68. if not line:
  69. break
  70. entity = re.sub(")", ")", re.sub("(", "(", line))
  71. if re.search("公司$", entity):
  72. _ner = bilstm.ner(entity)[0]
  73. if len(_ner) == 1 and _ner[0][3] == entity:
  74. f_found.write(entity + "\n")
  75. else:
  76. f_notfound.write(entity + "\n")
  77. def cleanEntity():
  78. source_file = "C:\\Users\\User\\Desktop\\notcleanedEntity.tsv"
  79. temp_file = "C:\\Users\\User\\Desktop\\temp.tsv"
  80. set_cleanedEntity = set()
  81. set_notcleanedEntity = set()
  82. with codecs.open(source_file, "r", encoding="utf8") as f_nce:
  83. while (True):
  84. line = f_nce.readline().strip()
  85. if not line:
  86. break
  87. entity = re.sub('["\s]', "", line)
  88. f_1 = list(re.finditer("公司", entity))
  89. f_2 = list(re.finditer("[支分]公司", entity))
  90. # if len(f_1)==2 and len(f_2)==1 and re.search("[原;;.。、\|,,]",entity[f_1[0].span()[1]:f_1[1].span()[0]]) is None:
  91. if re.search("br|/", entity) is not None:
  92. # f_ce.write(entity+"\n")
  93. set_cleanedEntity.add(entity)
  94. else:
  95. set_notcleanedEntity.add(entity)
  96. list_cleanedEntity = list(set_cleanedEntity)
  97. list_cleanedEntity.sort(key=lambda x: len(x))
  98. list_notcleanedEntity = list(set_notcleanedEntity)
  99. list_notcleanedEntity.sort(key=lambda x: len(x))
  100. with codecs.open(temp_file, "w", encoding="utf8") as f_ce:
  101. with codecs.open(source_file, "w", encoding="utf8") as f_nce:
  102. for item in list_cleanedEntity:
  103. f_ce.write(item + "\n");
  104. for item in list_notcleanedEntity:
  105. f_nce.write(item + "\n")
  106. if __name__ == "__main__":
  107. '''
  108. path_add = "0-12/"
  109. path = 'model/'+path_add+'model.ckpt'
  110. bilstm = BiLSTM().restore(path)
  111. '''
  112. bertCrf = BertCRF().restore()
  113. text = '小册子一批采购计划一、采购人:广州市比地数据科技有限公司,二、采购项目编号:'
  114. print(bertCrf.ner(text))
  115. # dealNotFoundEntity()
  116. pass
  117. '''
  118. cleanEntity()
  119. '''