test_core_rules.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # -*- coding: utf-8 -*-
  2. from unittest import mock
  3. from refo.patterns import Pattern
  4. from refo import Question, Star, Any
  5. from iepy.data.db import CandidateEvidenceManager
  6. from iepy.extraction.rules import rule, Token
  7. from iepy.extraction.rules_core import RuleBasedCore
  8. from .factories import (
  9. EntityKindFactory, RelationFactory, TextSegmentFactory,
  10. IEDocFactory, EntityOccurrenceFactory, EntityFactory,
  11. )
  12. from .manager_case import ManagerTestCase
  13. class TestRuleBasedCore(ManagerTestCase):
  14. def setUp(self):
  15. super(TestRuleBasedCore, self).setUp()
  16. kind_person = EntityKindFactory(name="person")
  17. kind_date = EntityKindFactory(name="date")
  18. self.person_date_relation = RelationFactory(
  19. name="born in",
  20. left_entity_kind=kind_person,
  21. right_entity_kind=kind_date,
  22. )
  23. text = "John Soplete ( 15 august 1990 - 26 september 2058 ) " \
  24. "was a software developer"
  25. document = self._create_simple_document(text)
  26. segment = TextSegmentFactory(
  27. document=document,
  28. offset=0,
  29. offset_end=len(document.tokens)
  30. )
  31. self.segment = segment
  32. e_john = EntityFactory(key="John Soplete", kind=kind_person)
  33. eo1 = EntityOccurrenceFactory(
  34. entity=e_john, document=document,
  35. offset=0, offset_end=2,
  36. alias="j0hn",
  37. )
  38. eo1.segments.add(segment)
  39. e_date = EntityFactory(key="15 august 1990", kind=kind_date)
  40. eo2 = EntityOccurrenceFactory(
  41. entity=e_date, document=document,
  42. offset=3, offset_end=6,
  43. alias="1990-08-15",
  44. )
  45. eo2.segments.add(segment)
  46. self._candidates = self.get_candidates(self.person_date_relation)
  47. def get_candidates(self, relation):
  48. return list(CandidateEvidenceManager.candidates_for_relation(relation))
  49. def _create_simple_document(self, text):
  50. tokens = tuple(text.split())
  51. lemmas = [""] * len(tokens)
  52. postags = ["POSTAG"] * len(tokens)
  53. indexes = tuple(list(range(len(tokens))))
  54. document = IEDocFactory(text=text)
  55. document.set_tokenization_result(list(zip(indexes, tokens)))
  56. document.set_lemmatization_result(lemmas)
  57. document.set_tagging_result(postags)
  58. document.save()
  59. return document
  60. def test_rule_that_matches(self):
  61. @rule(True)
  62. def test_rule(Subject, Object):
  63. anything = Question(Star(Any()))
  64. return Subject + Token("(") + Object + Token("-") + anything
  65. pipeline = RuleBasedCore(self.person_date_relation, [test_rule])
  66. pipeline.start()
  67. pipeline.process()
  68. facts = pipeline.predict(self._candidates)
  69. candidate = self._candidates[0]
  70. self.assertTrue(facts[candidate])
  71. def test_rule_that_not_matches(self):
  72. @rule(True)
  73. def test_rule(Subject, Object):
  74. return Subject + Object + Token("something here")
  75. pipeline = RuleBasedCore(self.person_date_relation, [test_rule])
  76. pipeline.start()
  77. pipeline.process()
  78. facts = pipeline.predict(self._candidates)
  79. candidate = self._candidates[0]
  80. self.assertFalse(facts[candidate])
  81. def test_empty_rules(self):
  82. pipeline = RuleBasedCore(self.person_date_relation, [])
  83. pipeline.start()
  84. pipeline.process()
  85. facts = pipeline.predict(self._candidates)
  86. self.assertEqual(len([x for x in facts if facts[x]]), 0)
  87. def test_match_run_on_every_rule(self):
  88. mocked_rules = [
  89. rule(True)(mock.MagicMock(return_value=Token("asd")))
  90. ] * 10
  91. pipeline = RuleBasedCore(self.person_date_relation, mocked_rules)
  92. pipeline.start()
  93. pipeline.process()
  94. pipeline.predict(self._candidates)
  95. for mock_rule in mocked_rules:
  96. self.assertTrue(mock_rule.called)
  97. Subject, Object = mock_rule.call_args[0]
  98. self.assertIsInstance(Subject, Pattern)
  99. def test_rule_priority(self):
  100. matcher = lambda *args: True
  101. not_matcher = lambda *args: None
  102. rule_should_run = rule(True, priority=1)(mock.MagicMock(return_value=matcher))
  103. rule_should_not_run = rule(True, priority=0)(
  104. mock.MagicMock(return_value=not_matcher))
  105. pipeline = RuleBasedCore(self.person_date_relation,
  106. [rule_should_not_run, rule_should_run])
  107. pipeline.start()
  108. # All rules are compiled on start
  109. self.assertTrue(rule_should_run.called)
  110. self.assertTrue(rule_should_not_run.called)
  111. pipeline.process()
  112. import refo
  113. with mock.patch.object(refo, 'match') as fake_refo_match:
  114. fake_refo_match.side_effect = lambda regex, evidence: regex()
  115. pipeline.predict(self._candidates)
  116. self.assertEqual(fake_refo_match.call_count, len(self._candidates))
  117. # check that on every call, the called is rule_match
  118. for c_args in fake_refo_match.call_args_list:
  119. args, kwargs = c_args
  120. self.assertEqual(args[0], matcher)
  121. def test_rule_incorrect_answer(self):
  122. with self.assertRaises(ValueError):
  123. @rule("YE")
  124. def rule_match(Subject, Object):
  125. anything = Question(Star(Any()))
  126. return Subject + Token("(") + Object + Token("-") + anything
  127. def test_rule_with_negative_answer(self):
  128. @rule(False)
  129. def test_rule(Subject, Object):
  130. anything = Question(Star(Any()))
  131. return Subject + Token("(") + Object + Token("-") + anything
  132. pipeline = RuleBasedCore(self.person_date_relation, [test_rule])
  133. pipeline.start()
  134. pipeline.process()
  135. facts = pipeline.predict(self._candidates)
  136. candidate = self._candidates[0]
  137. self.assertFalse(facts[candidate])