test_ner.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from iepy.data.models import IEDocument
  2. from iepy.preprocess.ner.stanford import NERRunner, StanfordNERRunner
  3. from iepy.preprocess.pipeline import PreProcessSteps
  4. from .factories import SentencedIEDocFactory, IEDocFactory
  5. from .manager_case import ManagerTestCase
  class NERTestMixin(object):
      """Shared helpers for tests that run a NER step over a document.

      The assertion helpers rely on ``assertTrue`` / ``assertEqual`` being
      supplied by the test-case class this mixin is combined with.
      """

      # Token text -> entity kind used by the lookup tagger in ``check_ner``.
      # Tokens that are not listed here are tagged 'O'.
      entity_map = {
          'Rami': 'PERSON',
          'Eid': 'PERSON',
          'Stony': 'ORGANIZATION',
          'Brook': 'ORGANIZATION',
          'University': 'ORGANIZATION',
      }

      def check_ner(self, doc, entities_triples):
          """Run a ``NERRunner`` over ``doc`` and verify what it stored.

          The runner is built around a tagger that labels every token by
          looking it up in ``entity_map``. ``entities_triples`` is the
          expected sequence of ``(offset, offset_end, kind)`` triples.
          """
          def lookup_tagger(sentences):
              # One list of (token, label) pairs per input sentence.
              tagged = []
              for sentence in sentences:
                  tagged.append(
                      [(token, self.entity_map.get(token, 'O'))
                       for token in sentence]
                  )
              return tagged

          runner = NERRunner(lookup_tagger)
          runner(doc)
          self.check_ner_result(doc, entities_triples)

      def check_ner_result(self, doc, entities_triples):
          """Assert that ``doc`` holds exactly the expected entity occurrences.

          Checks, in order: the ``ner`` preprocess step is reported as done,
          the number of stored occurrences equals the number of expected
          triples, and each occurrence matches its triple position by
          position on ``offset``, ``offset_end`` and ``entity.kind.name``.
          """
          ner_done = doc.was_preprocess_step_done(PreProcessSteps.ner)
          self.assertTrue(ner_done)

          found = self.get_ner_result(doc)
          self.assertEqual(len(found), len(entities_triples))

          for occurrence, expected in zip(found, entities_triples):
              start, end, kind_name = expected
              self.assertEqual(occurrence.offset, start)
              self.assertEqual(occurrence.offset_end, end)
              self.assertEqual(occurrence.entity.kind.name, kind_name)

      def get_ner_result(self, doc):
          """Return the entity occurrences attached to ``doc`` as a list."""
          # hacked ORM detail
          occurrences = doc.entity_occurrences.all()
          return list(occurrences)
  31. class TestNERRunner(ManagerTestCase, NERTestMixin):
  32. ManagerClass = IEDocument
  def test_ner_runner_is_calling_ner(self):
      # Two sentences: the person name sits in the first one and the
      # organization name in the second one.
      text = 'Rami Eid is studying . At Stony Brook University in NY'
      doc = SentencedIEDocFactory(text=text)
      # Expected (offset, offset_end, kind) triples; the offsets look like
      # token indexes with an exclusive end -- TODO confirm.
      expected = [(0, 2, 'PERSON'), (6, 9, 'ORGANIZATION')]
      self.check_ner(doc, expected)
  def test_ner_runner_finds_consecutive_entities(self):
      # The person name is immediately followed by the organization name,
      # with no untagged token between them.
      text = 'The student Rami Eid Stony Brook University in NY'
      doc = SentencedIEDocFactory(text=text)
      # The first triple ends at 4 and the second starts at 4, so the two
      # occurrences are expected to be adjacent but kept separate.
      expected = [(2, 4, 'PERSON'), (4, 7, 'ORGANIZATION')]
      self.check_ner(doc, expected)