combiner.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from iepy.preprocess.ner.base import BaseNERRunner
  2. class CombinedNERRunner(BaseNERRunner):
  3. """A NER runner that is the combination of different NER runners
  4. (therefore, different NERs).
  5. The entities returned by each NER are combined by the method merge_entities
  6. without any check, possibly leading to duplicate or overlapping entities;
  7. but subclassing this combiner you may define something different.
  8. """
  9. def __init__(self, ners, override=False):
  10. """The NER runners should be instances of BasePreProcessStepRunner.
  11. Notes:
  12. - Each of the sub-ners will be configured to run with override-mode
  13. "on", no matter what is the global override value.
  14. The global override, will be used for determining wether to start
  15. or not the global-combined process.
  16. - Overriding only some NERs and not others is not allowed.
  17. """
  18. super(CombinedNERRunner, self).__init__(override=override)
  19. if not ners:
  20. raise ValueError(u'Empty NERs to combine')
  21. self.ners = ners
  22. for sub_ner in self.ners:
  23. sub_ner.override = True
  24. def merge_entities(self, sub_results):
  25. # Default merger does nothing but merging & sorting by offset
  26. all_entities = []
  27. for ner, sub_entities in sub_results:
  28. all_entities.extend(sub_entities)
  29. return sorted(all_entities, key=lambda x: x.offset)
  30. def run_ner(self, doc):
  31. sub_results = []
  32. for sub_ner in self.ners:
  33. sub_results.append(
  34. (sub_ner,
  35. sub_ner.run_ner(doc)
  36. )
  37. )
  38. return self.merge_entities(sub_results)
  39. class NoOverlapCombinedNERRunner(CombinedNERRunner):
  40. """
  41. Similar to the CombinedNERRunner, but when merging results from different
  42. taggers avoids overlapping by discarding those entities that were provided
  43. by later subners.
  44. It's assumed that each sub NER provides non overlapped entities.
  45. """
  46. def overlapped_entities(self, e1, e2):
  47. min1, max1 = e1.offset, e1.offset_end
  48. min2, max2 = e2.offset, e2.offset_end
  49. return bool(max(0, min(max1, max2) - max(min1, min2)))
  50. def merge_entities(self, sub_results):
  51. result = []
  52. for ner, sub_res in sub_results:
  53. if not result:
  54. # first ner returning something. all in.
  55. result.extend(sub_res)
  56. else:
  57. for ent in sub_res:
  58. if any(self.overlapped_entities(ent, e_i) for e_i in result):
  59. continue
  60. result.append(ent)
  61. return sorted(result, key=lambda x: x.offset)
  62. class KindPreferenceCombinedNERRunner(CombinedNERRunner):
  63. """
  64. Similar to the CombinedNERRunner, but when merging results from different
  65. taggers avoids overlapping by discarding those entities whose kind was worst
  66. ranked on the Combiner creation.
  67. If a given entity kind is not ranked on Combiner, will be treated worst than
  68. the worst ranked.
  69. If conflict remains, following rules apply:
  70. - shorter occurrences are preferred over larger
  71. - occurrences of former sub NERs are preferred.
  72. """
  73. def __init__(self, ners, override=False, rank=tuple()):
  74. """
  75. """
  76. # the lower the rank, the more important
  77. if not isinstance(rank, (tuple, list)):
  78. raise ValueError(u'rank can only be a list or tuple')
  79. self.kinds_rank = dict((k, i) for i, k in enumerate(rank))
  80. self.worst_rank = len(self.kinds_rank)
  81. super(KindPreferenceCombinedNERRunner, self).__init__(ners, override)
  82. def get_rank(self, found_entity):
  83. return self.kinds_rank.setdefault(found_entity.kind_name, self.worst_rank)
  84. def merge_entities(self, sub_results):
  85. sorted_occurrences = super(KindPreferenceCombinedNERRunner,
  86. self).merge_entities(sub_results)
  87. if not sorted_occurrences:
  88. return sorted_occurrences
  89. prev = sorted_occurrences[0]
  90. to_remove = set()
  91. # given that entities are sorted, cannot be the case that one entity
  92. # has offset lower than the previous one
  93. for eo in sorted_occurrences[1:]:
  94. if eo.offset < prev.offset_end:
  95. # there's an overlap. One of these 2 must be removed
  96. prev_criteria = (
  97. self.get_rank(prev), # kind rank
  98. -1 * (prev.offset_end - prev.offset) # inversed length
  99. )
  100. eo_criteria = (
  101. self.get_rank(eo),
  102. -1 * (eo.offset_end - eo.offset)
  103. )
  104. if prev_criteria <= eo_criteria:
  105. to_remove.add(eo)
  106. else:
  107. to_remove.add(prev)
  108. prev = eo
  109. else:
  110. prev = eo
  111. return [eo_i for eo_i in sorted_occurrences if eo_i not in to_remove]