test_combined_ner.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from unittest import mock
  2. from operator import attrgetter
  3. from unittest import TestCase
  4. from iepy.preprocess.ner.base import BaseNERRunner
  5. from iepy.preprocess.ner.combiner import (
  6. CombinedNERRunner, NoOverlapCombinedNERRunner,
  7. KindPreferenceCombinedNERRunner)
  8. from iepy.preprocess.pipeline import PreProcessSteps
  class BaseTestCombined(TestCase):
      """Common fixture for the combined-NER tests.

      Provides two mocked sub-runners (``runner1``, ``runner2``) and a mocked
      document (``doc``) whose ``was_preprocess_step_done`` answers True for
      every step other than ``PreProcessSteps.ner``.
      """

      def setUp(self):
          """Build fresh mocks before each test."""
          self.runner1, self.runner2 = mock.MagicMock(), mock.MagicMock()
          document = mock.MagicMock()
          # Only the NER step is reported as still pending.
          document.was_preprocess_step_done.side_effect = (
              lambda step: step != PreProcessSteps.ner)
          self.doc = document
  class TestCombinedNERRunner(BaseTestCombined):
      """Checks when CombinedNERRunner calls its sub-runners and what it stores."""

      @staticmethod
      def _always_done(x):
          """Side effect for ``was_preprocess_step_done``: every step is done."""
          return True

      @staticmethod
      def _emitting(entity):
          """Return a ``run_ner`` side effect producing a fresh ``[entity]``."""
          return lambda doc: [entity]

      def test_runners_called_when_not_done_before(self):
          """With NER pending, each sub-runner is called once with the doc."""
          CombinedNERRunner([self.runner1, self.runner2])(self.doc)
          for sub_runner in (self.runner1, self.runner2):
              sub_runner.run_ner.assert_called_once_with(self.doc)

      def test_runners_called_when_override(self):
          """With ``override=True``, sub-runners run even if NER is done."""
          self.doc.was_preprocess_step_done.side_effect = self._always_done
          combined = CombinedNERRunner(
              [self.runner1, self.runner2], override=True)
          combined(self.doc)
          for sub_runner in (self.runner1, self.runner2):
              sub_runner.run_ner.assert_called_once_with(self.doc)

      def test_runners_not_called_when_done_before(self):
          """Without override, a doc with NER done triggers no sub-runner."""
          self.doc.was_preprocess_step_done.side_effect = self._always_done
          CombinedNERRunner([self.runner1, self.runner2])(self.doc)
          for sub_runner in (self.runner1, self.runner2):
              self.assertFalse(sub_runner.run_ner.called)

      def test_no_entities_are_lost(self):
          """Entities from both sub-runners reach ``set_ner_result``."""
          first = mock.MagicMock(offset=1)
          second = mock.MagicMock(offset=2)
          self.runner1.run_ner.side_effect = self._emitting(first)
          self.runner2.run_ner.side_effect = self._emitting(second)
          CombinedNERRunner([self.runner1, self.runner2])(self.doc)
          self.doc.set_ner_result.assert_called_once_with([first, second])

      def test_can_define_combiner_for_only_one_ner(self):
          """A combiner built from a single sub-runner still stores a result."""
          CombinedNERRunner([self.runner1])(self.doc)
          self.assertTrue(self.doc.set_ner_result.called)

      def test_can_define_combiner_for_lots_of_ners(self):
          """Four sub-runners, one entity each: all four entities are stored."""
          all_runners = [
              self.runner1, self.runner2, mock.MagicMock(), mock.MagicMock()]
          entities = []
          # Entity offsets are 1..4, following the order of the runners.
          for position, sub_runner in enumerate(all_runners, start=1):
              entity = mock.MagicMock(offset=position)
              sub_runner.run_ner.side_effect = self._emitting(entity)
              entities.append(entity)
          CombinedNERRunner(all_runners)(self.doc)
          self.doc.set_ner_result.assert_called_once_with(entities)
  class TestNEROverlappingHandling(BaseTestCombined):
      """Checks how each combiner variant treats overlapping occurrences."""

      def setUp(self):
          """Give each mocked runner a fixed list of occurrences to return."""
          super().setUp()
          first_spans = [
              (1, 3, u'X'), (6, 8, u'W'), (8, 9, u'X'), (11, 12, u'W')]
          self.result1 = self.construct_occurrences(first_spans)
          second_spans = [
              (2, 4, u'Y'), (5, 7, u'Z'), (8, 9, u'Y'), (9, 13, u'Z')]
          self.result2 = self.construct_occurrences(second_spans)

          # The attributes are read at call time, not captured here.
          def first_result(doc):
              return self.result1

          def second_result(doc):
              return self.result2

          self.runner1.run_ner.side_effect = first_result
          self.runner2.run_ner.side_effect = second_result

      def construct_occurrences(self, data):
          """Build one occurrence per ``(offset, offset_end, kind)`` triple.

          Each occurrence is created with ``BaseNERRunner.build_occurrence``
          using 'blah' as both key and alias. The list follows the order of
          ``data``.
          """
          builder = BaseNERRunner()
          return [
              builder.build_occurrence(
                  key='blah',
                  kind_name=kind_name,
                  alias='blah',
                  offset=start,
                  offset_end=end,
              )
              for start, end, kind_name in data
          ]

      def _kind_preference(self, rank):
          """Build a KindPreferenceCombinedNERRunner over both mocked runners."""
          return KindPreferenceCombinedNERRunner(
              [self.runner1, self.runner2],
              rank=rank
          )

      def test_overlapped_are_stored_like_that_on_default_combiner(self):
          """The default combiner stores every occurrence, sorted."""
          CombinedNERRunner([self.runner1, self.runner2])(self.doc)
          expected = self.result1 + self.result2
          expected.sort(key=attrgetter('offset', 'offset_end', 'kind_name'))
          self.doc.set_ner_result.assert_called_once_with(expected)

      def test_simple_overlap_solver_prefers_from_former_subners(self):
          """The no-overlap combiner stores the first listed runner's result."""
          forward = NoOverlapCombinedNERRunner([self.runner1, self.runner2])
          forward(self.doc)
          self.doc.set_ner_result.assert_called_once_with(self.result1)
          # Same check with the runners listed in the opposite order.
          backward = NoOverlapCombinedNERRunner([self.runner2, self.runner1])
          self.doc.reset_mock()
          backward(self.doc)
          self.doc.set_ner_result.assert_called_once_with(self.result2)

      def test_overlaps_is_solved_prefering_some_kind_over_other(self):
          """The kind-preference combiner stores the result its rank favours."""
          scenarios = (
              ([u'X', u'W', u'Y', u'Z'], self.result1),
              # Kinds absent from the rank lose against ranked ones.
              ([u'X', u'W'], self.result1),
              ([u'Z', u'Y'], self.result2),
          )
          for rank, expected in scenarios:
              self._kind_preference(rank)(self.doc)
              self.assertEqual(
                  self.doc.set_ner_result.call_args_list[-1],
                  mock.call(expected))

      def test_kindpreference_must_be_instantiated_with_tuple_or_list(self):
          """A rank that is not a tuple or list raises ValueError."""
          for bad_rank in ('something', None, 1):
              self.assertRaises(ValueError, self._kind_preference, bad_rank)
          # Tuples and lists must be accepted without raising.
          for good_rank in (('some', 'thing'), ['some', 'thing']):
              self._kind_preference(good_rank)
  134. combiner(['some', 'thing'])