test_core_active_learning.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # -*- coding: utf-8 -*-
  2. from unittest import mock
  3. from iepy.data.models import EvidenceCandidate
  4. from iepy.extraction.active_learning_core import ActiveLearningCore
  5. from .factories import EvidenceCandidateFactory, RelationFactory
  6. from .manager_case import ManagerTestCase
  7. class ActiveLearningTestMixin:
  8. def setUp(self):
  9. self.relation = RelationFactory(left_entity_kind__name='person',
  10. right_entity_kind__name='location')
  11. self.ev1 = EvidenceCandidateFactory()
  12. self.ev2 = EvidenceCandidateFactory()
  13. self.ev3 = EvidenceCandidateFactory()
  14. def lbl_evs(self, values):
  15. return dict(zip(EvidenceCandidate.objects.all().order_by('id'), values))
  16. class TestQuestions(ActiveLearningTestMixin, ManagerTestCase):
  17. def test_cant_start_core_with_all_evidence_labeled(self):
  18. # why not? Well, simply because is overkill to ask IEPY something
  19. # that's already available
  20. evidences = self.lbl_evs([True, False, True])
  21. self.assertRaises(ValueError, ActiveLearningCore,
  22. self.relation, evidences
  23. )
  24. def test_every_evidence_without_label_is_a_question(self):
  25. c = ActiveLearningCore(self.relation, self.lbl_evs([None]*3))
  26. self.assertEqual(len(c.questions), 3)
  27. c = ActiveLearningCore(self.relation, self.lbl_evs([False, True, None]))
  28. self.assertEqual(len(c.questions), 1)
  29. c = ActiveLearningCore(self.relation, self.lbl_evs([None, True, None]))
  30. self.assertEqual(len(c.questions), 2)
  31. def test_every_question_answered_is_not_a_question_any_more(self):
  32. c = ActiveLearningCore(self.relation, self.lbl_evs([None]*3))
  33. c.add_answer(self.ev1, False)
  34. self.assertEqual(len(c.questions), 2)
  35. self.assertNotIn(self.ev1, c.questions)
  36. class TestProcess(ActiveLearningTestMixin, ManagerTestCase):
  37. def setUp(self):
  38. super().setUp()
  39. self.c = ActiveLearningCore(self.relation, self.lbl_evs([None]*3))
  40. patcher = mock.patch.object(self.c, 'train_relation_classifier')
  41. self.mock_train_classifier = patcher.start()
  42. self.addCleanup(patcher.stop)
  43. def test_process_with_no_available_labels_does_nothing(self):
  44. self.c.process()
  45. self.assertFalse(self.mock_train_classifier.called)
  46. def test_process_with_not_both_labels_does_nothing(self):
  47. # by "both", we mean True and False
  48. self.c.add_answer(self.ev1, True)
  49. self.c.process()
  50. self.assertFalse(self.mock_train_classifier.called)
  51. self.c.add_answer(self.ev2, True)
  52. self.c.process()
  53. self.assertFalse(self.mock_train_classifier.called)
  54. self.c.add_answer(self.ev3, False)
  55. self.c.process()
  56. self.assertTrue(self.mock_train_classifier.called)
  57. def test_more_than_binary_labels_is_raise(self):
  58. self.c.add_answer(self.ev1, True)
  59. self.c.add_answer(self.ev2, False)
  60. self.c.add_answer(self.ev3, False)
  61. self.c.labeled_evidence[self.ev3] = 'weird thing'
  62. self.assertRaises(ValueError, self.c.process)
  63. self.assertFalse(self.mock_train_classifier.called)