__init__.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import logging
from collections import defaultdict

from BiddingKG.dl.foolnltk.selffool import lexical
from BiddingKG.dl.foolnltk.selffool import dictionary
from BiddingKG.dl.foolnltk.selffool import model
from BiddingKG.dl.foolnltk.selffool import selffool_ner
# from BiddingKG.dl.BertNer.BertCRF import BertCRF

# Module-level singletons: the lexical analyzer (segmentation/POS/NER) and
# the user dictionary that can override the model's segmentation.
LEXICAL_ANALYSER = lexical.LexicalAnalyzer()
_DICTIONARY = dictionary.Dictionary()

# Default logger: DEBUG level, writing to stderr.
__log_console = logging.StreamHandler(sys.stderr)
DEFAULT_LOGGER = logging.getLogger(__name__)
DEFAULT_LOGGER.setLevel(logging.DEBUG)
DEFAULT_LOGGER.addHandler(__log_console)

'''
from BiddingKG.dl.foolnltk.bi_lstm_crf import BiLSTM
bilstm = BiLSTM()
bilstm.restore()
'''

selfNer = selffool_ner.SelfNer()
# bertCRF = BertCRF().restore()

__all__ = ["load_model", "cut", "pos_cut", "ner", "analysis",
           "load_userdict", "delete_userdict"]


def load_model(map_file, model_file):
    # Load a serialized tagging model together with its vocabulary/label maps.
    m = model.Model(map_file=map_file, model_file=model_file)
    return m


def _check_input(text, ignore=False):
    # Normalize input to a list of sentences. Empty items raise unless
    # ignore=True, in which case they are passed through for callers to filter.
    if not text:
        return []
    if not isinstance(text, list):
        text = [text]
    null_index = [i for i, t in enumerate(text) if not t]
    if null_index and not ignore:
        raise Exception("null text in input")
    return text
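

# Behavior of _check_input, read directly off the code above:
#   _check_input("abc")                  -> ["abc"]
#   _check_input(["a", "b"])             -> ["a", "b"]
#   _check_input("")                     -> []
#   _check_input(["a", ""])              raises Exception("null text in input")
#   _check_input(["a", ""], ignore=True) -> ["a", ""]   (cut() drops the empties)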


def ner(text, ignore=False):
    # Named-entity recognition with the default lexical analyzer.
    text = _check_input(text, ignore)
    if not text:
        return [[]]
    res = LEXICAL_ANALYSER.ner(text)
    return res


def self_ner(text, ignore=False):
    # Named-entity recognition with the self-trained model (SelfNer).
    text = _check_input(text, ignore)
    if not text:
        return [[]]
    res = selfNer.ner(text)
    # res = bilstm.ner(text)
    # res = bertCRF.ner(text)
    return res
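

# Usage sketch (hedged): ner() and self_ner() return one entity list per input
# sentence. The exact tuple layout comes from LexicalAnalyzer / SelfNer, which
# are not shown in this file; upstream foolnltk uses (start, end, type, text)
# tuples, and the example output below is illustrative only.
#
#   import BiddingKG.dl.foolnltk.selffool as selffool
#   selffool.ner("招标人:北京市政府")
#   # e.g. [[(4, 10, "org", "北京市政府")]]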


def analysis(text, ignore=False):
    # Run segmentation and NER in one pass; note the ([[]], [[]]) shape
    # returned for empty input.
    text = _check_input(text, ignore)
    if not text:
        return [[]], [[]]
    res = LEXICAL_ANALYSER.analysis(text)
    return res
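

# Usage sketch (hedged, inferred from the empty-input return shape): analysis()
# yields two parallel per-sentence results, presumably words and entities.
#
#   words, ents = selffool.analysis("北京欢迎你")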


def cut(text, ignore=False):
    # Segment each sentence with the model, then, if a user dictionary is
    # loaded, re-merge the segmentation so heavy user entries win.
    text = _check_input(text, ignore)
    if not text:
        return [[]]
    text = [t for t in text if t]
    all_words = LEXICAL_ANALYSER.cut(text)
    new_words = []
    if _DICTIONARY.sizes != 0:
        for sent, words in zip(text, all_words):
            words = _mearge_user_words(sent, words)
            new_words.append(words)
    else:
        new_words = all_words
    return new_words
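

# Usage sketch (assumes the bundled segmentation model loads successfully;
# the segmentation shown is illustrative only):
#
#   selffool.cut("北京欢迎你")
#   # e.g. [["北京", "欢迎", "你"]]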


def pos_cut(text):
    # Segment, then POS-tag the words; returns [(word, tag), ...] per sentence.
    words = cut(text)
    pos_labels = LEXICAL_ANALYSER.pos(words)
    word_inf = [list(zip(ws, ps)) for ws, ps in zip(words, pos_labels)]
    return word_inf
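

# Usage sketch (the tag set depends on the underlying model; tags below are
# illustrative only):
#
#   selffool.pos_cut("北京欢迎你")
#   # e.g. [[("北京", "ns"), ("欢迎", "v"), ("你", "r")]]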


def load_userdict(path):
    _DICTIONARY.add_dict(path)


def delete_userdict():
    _DICTIONARY.delete_dict()
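

# Usage sketch. The dictionary file format is an assumption mirroring upstream
# foolnltk: one "word weight" entry per line, whitespace-separated, e.g.
#
#   招标代理机构 10
#   中标候选人 8
#
#   selffool.load_userdict("/path/to/userdict.txt")   # hypothetical path
#   selffool.delete_userdict()   # drop user entries, back to pure model output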


def _mearge_user_words(text, seg_results):
    # Merge user-dictionary matches into the model segmentation by building a
    # weighted DAG over character positions and picking the heaviest path.
    if not _DICTIONARY:
        return seg_results
    matchs = _DICTIONARY.parse_words(text)
    graph = defaultdict(dict)
    text_len = len(text)
    # Fallback edges: every single character can stand alone with weight 1.0.
    for i in range(text_len):
        graph[i][i + 1] = 1.0
    # Edges for the model's segmentation, weighted by dictionary weight + length.
    index = 0
    for w in seg_results:
        w_len = len(w)
        graph[index][index + w_len] = _DICTIONARY.get_weight(w) + w_len
        index += w_len
    # Edges for user-dictionary matches, weighted by weight * length.
    for m in matchs:
        graph[m.start][m.end] = _DICTIONARY.get_weight(m.keyword) * len(m.keyword)
    # Backward pass: route[idx] = (best accumulated weight from idx to the end,
    # position of the next cut on that best path).
    route = {}
    route[text_len] = (0, 0)
    for idx in range(text_len - 1, -1, -1):
        m = [((graph.get(idx).get(k) + route[k][0]), k) for k in graph.get(idx).keys()]
        mm = max(m)
        route[idx] = (mm[0], mm[1])
    # Forward pass: follow the recorded cuts to emit the merged words.
    index = 0
    path = [index]
    words = []
    while index < text_len:
        ind_y = route[index][1]
        path.append(ind_y)
        words.append(text[index:ind_y])
        index = ind_y
    return words
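

# Worked trace of _mearge_user_words, grounded in the code above: for
# text = "ABCD", model segmentation ["AB", "CD"], and a user entry "BCD" with
# weight w, the graph holds the per-character fallback edges 0->1, 1->2, 2->3,
# 3->4 (weight 1.0 each), the model edges 0->2 and 2->4, and the user edge
# 1->4 with weight 3*w. The backward pass keeps, at every position, the
# continuation with the highest accumulated weight, so a sufficiently heavy
# user word flips the result from ["AB", "CD"] to ["A", "BCD"].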