test_segments.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import unittest
  2. from iepy.data.models import TextSegment, EntityOccurrence
  3. from iepy.preprocess.segmenter import RawSegment, SyntacticSegmenterRunner
  4. from .factories import IEDocFactory, EntityFactory, EntityOccurrenceFactory, TextSegmentFactory
  5. from .manager_case import ManagerTestCase
  6. def RSF(offset=0, offset_end=0, entity_occurrences=None):
  7. # hand made RawSegmentFactory
  8. return RawSegment(offset, offset_end, entity_occurrences)
  9. class TextSegmentCreationTest(ManagerTestCase):
  10. def setUp(self):
  11. self.d = IEDocFactory()
  12. def build_and_get_segment_from_raw(self, raw):
  13. self.d.set_segmentation_result([raw], override=True)
  14. return list(self.d.get_text_segments())[0].hydrate()
  15. def test_empty(self):
  16. s = self.build_and_get_segment_from_raw(RSF(0, 0))
  17. self.assertEqual(s.document, self.d)
  18. self.assertEqual(s.text, "")
  19. self.assertEqual(s.offset, 0)
  20. self.assertEqual(s.tokens, [])
  21. self.assertEqual(s.postags, [])
  22. self.assertEqual(list(s.get_entity_occurrences()), [])
  23. def test_data_copied_simple(self):
  24. d = self.d
  25. d.offsets_to_text = list(range(7))
  26. d.tokens = list("ABCDEFG")
  27. d.postags = list("NNVANVA")
  28. d.text = "ABCDEFG"
  29. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  30. self.assertEqual(s.offset, 2)
  31. self.assertEqual(s.tokens, ["C", "D", "E"])
  32. self.assertEqual(s.postags, ["V", "A", "N"])
  33. self.assertEqual(list(s.get_entity_occurrences()), [])
  34. self.assertEqual(s.text, "CDE")
  35. def hack_document(self, text):
  36. # monkey patch the iedocument with some valid attributes
  37. # if text has repeated tokens, we are dead.
  38. # veeery naive punctuation symbols handling: "," and "." only.
  39. self.d.text = text
  40. pre_tokens = text.split()
  41. # punctuation handling
  42. self.d.tokens = []
  43. for tk in pre_tokens:
  44. for symbol in [',', '.']:
  45. if tk.endswith(symbol):
  46. self.d.tokens.append(tk.strip(symbol))
  47. self.d.tokens.append(symbol)
  48. break
  49. else:
  50. self.d.tokens.append(tk)
  51. assert len(set(self.d.tokens)) == len(self.d.tokens) # repeated tks check
  52. self.d.offsets_to_text = [self.d.text.index(t) for t in self.d.tokens]
  53. self.d.postags = ["NN" for t in self.d.tokens]
  54. def test_data_copied_complex(self):
  55. self.hack_document("The people around the world is crazy.")
  56. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  57. self.assertEqual(s.tokens, ["around", "the", "world"])
  58. self.assertEqual(s.text, "around the world")
  59. s = self.build_and_get_segment_from_raw(RSF(4, 8))
  60. self.assertEqual(s.tokens, ["world", "is", "crazy", "."])
  61. self.assertEqual(s.text, "world is crazy.")
  62. def test_entities_capture_simple(self):
  63. self.hack_document("The people around the world is crazy.")
  64. self.d.save()
  65. eo = EntityOccurrenceFactory(document=self.d, offset=4, offset_end=5)
  66. s = self.build_and_get_segment_from_raw(RSF(2, 6))
  67. self.assertEqual(list(s.get_entity_occurrences()), [eo])
  68. def test_hydrated_entity_occurrences_from_segment(self):
  69. # verify the segment offsets: ie, the goal is that obtaining tokens from
  70. # document or from segment shall be equivalent
  71. self.hack_document("The people around the world is crazy.")
  72. eo = EntityOccurrenceFactory(document=self.d, offset=3, offset_end=5)
  73. expected = ["the", "world"]
  74. assert self.d.tokens[eo.offset:eo.offset_end] == expected
  75. segm = self.build_and_get_segment_from_raw(RSF(2, 6))
  76. s_eo = list(segm.get_entity_occurrences())[0]
  77. self.assertEqual(
  78. segm.tokens[s_eo.segment_offset:s_eo.segment_offset_end],
  79. expected)
  80. def test_entities_capture_end_border(self):
  81. # check that occurrence ending on the last token is correcly captured
  82. self.hack_document("The people around the world is crazy.")
  83. eo = EntityOccurrenceFactory(document=self.d, offset=4, offset_end=5)
  84. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  85. self.assertEqual(list(s.get_entity_occurrences()), [eo])
  86. def test_entities_capture_start_border(self):
  87. # check that occurrence starting on the first token is correcly captured
  88. self.hack_document("The people around the world is crazy.")
  89. eo = EntityOccurrenceFactory(document=self.d, offset=2, offset_end=3)
  90. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  91. self.assertEqual(list(s.get_entity_occurrences()), [eo])
  92. def test_entities_capture_ending_outside_are_not_included(self):
  93. self.hack_document("The people around the world is crazy.")
  94. EntityOccurrenceFactory(document=self.d, offset=4, offset_end=6)
  95. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  96. self.assertEqual(list(s.get_entity_occurrences()), [])
  97. def test_entities_capture_starting_before_are_not_included(self):
  98. self.hack_document("The people around the world is crazy.")
  99. EntityOccurrenceFactory(document=self.d, offset=1, offset_end=3)
  100. s = self.build_and_get_segment_from_raw(RSF(2, 5))
  101. self.assertEqual(list(s.get_entity_occurrences()), [])
  102. def test_sentence_information(self):
  103. d = self.d
  104. L = 100
  105. d.offsets = list(range(L))
  106. d.tokens = ["X"]*L
  107. d.postags = ["N"]*L
  108. d.sentences = [0, 5, 35, 36, 41, 90]
  109. d.syntactic_sentences = [""] * max(d.sentences)
  110. s = self.build_and_get_segment_from_raw(RSF(30, 60))
  111. self.assertEqual(s.sentences, [5, 6, 11])
  112. s = self.build_and_get_segment_from_raw(RSF(0, 60))
  113. self.assertEqual(s.sentences, [0, 5, 35, 36, 41])
  114. class TestGetOccurrencesPairsFromSegment(ManagerTestCase):
  115. def setUp(self):
  116. self.s = TextSegmentFactory()
  117. self.d = self.s.document
  118. def create_occurrence(self, e, offset, end):
  119. return EntityOccurrenceFactory(document=self.d, entity=e,
  120. offset=offset, offset_end=end)
  121. def test_entity_occurrence_pairs(self):
  122. e1 = EntityFactory()
  123. e2 = EntityFactory()
  124. e3 = EntityFactory()
  125. # eo2_1 means: occurrence of entity "e2", first of them
  126. eo1_1 = self.create_occurrence(e1, 0, 1)
  127. eo2_1 = self.create_occurrence(e2, 1, 2)
  128. eo1_2 = self.create_occurrence(e1, 2, 3)
  129. eo2_2 = self.create_occurrence(e2, 3, 4)
  130. eo3_1 = self.create_occurrence(e3, 4, 5)
  131. self.s.entity_occurrences = [eo1_1, eo1_2, eo2_1, eo2_2, eo3_1]
  132. ps = self.s.entity_occurrence_pairs(e1, e2)
  133. self.assertEqual(ps, [(eo1_1, eo2_1), (eo1_1, eo2_2),
  134. (eo1_2, eo2_1), (eo1_2, eo2_2)])
  135. def test_entity_occurrence_pairs_does_not_repeat(self):
  136. e1 = EntityFactory()
  137. o11 = self.create_occurrence(e1, 0, 1)
  138. o12 = self.create_occurrence(e1, 1, 2)
  139. self.s.entity_occurrences = [o11, o12]
  140. ps = self.s.entity_occurrence_pairs(e1, e1)
  141. self.assertEqual(ps, [(o11, o12), (o12, o11)])
  142. def test_kind_occurrence_pairs(self):
  143. e1 = EntityFactory(kind__name='person')
  144. e2 = EntityFactory(kind__name='location')
  145. e3 = EntityFactory(kind=e2.kind)
  146. eo1_1 = self.create_occurrence(e1, 0, 1)
  147. eo2_1 = self.create_occurrence(e2, 1, 2)
  148. eo1_2 = self.create_occurrence(e1, 2, 3)
  149. eo3_1 = self.create_occurrence(e3, 3, 4)
  150. self.s.entity_occurrences = [eo1_1, eo2_1, eo1_2, eo3_1]
  151. ps = self.s.kind_occurrence_pairs(e1.kind, e2.kind)
  152. self.assertEqual(ps, [(eo1_1, eo2_1), (eo1_1, eo3_1),
  153. (eo1_2, eo2_1), (eo1_2, eo3_1)])
  154. class TestDocumentSegmenter(ManagerTestCase):
  155. ManagerClass = TextSegment
  156. def setUp(self):
  157. self.doc = IEDocFactory()
  158. super(TestDocumentSegmenter, self).setUp()
  159. self.segmenter = SyntacticSegmenterRunner()
  160. def set_doc_length(self, n):
  161. self.doc.tokens = ["x"] * n
  162. self.doc.offsets = list(range(n))
  163. self.doc.postags = ["tag"] * n
  164. self.doc.sentences = [0]
  165. def add_entities(self, positions):
  166. e1 = EntityFactory()
  167. for p in positions:
  168. if isinstance(p, tuple):
  169. start, length = p
  170. else:
  171. start, length = p, 1
  172. EntityOccurrenceFactory(
  173. document=self.doc,
  174. entity=e1, offset=start,
  175. offset_end=start + length,
  176. alias="AB")
  177. def test_no_entities(self):
  178. self.set_doc_length(100)
  179. raws = self.segmenter.build_syntactic_segments(self.doc)
  180. self.assertEqual(raws, [])
  181. def test_sentence_segmenter_limits(self):
  182. self.set_doc_length(100)
  183. self.add_entities([1, 2, 22, 23, 35, 61, 80])
  184. self.doc.sentences = [0, 20, 50]
  185. raws = self.segmenter.build_syntactic_segments(self.doc)
  186. self.assertEqual(len(raws), 3)
  187. s = raws[0]
  188. self.assertEqual(s.offset, 0)
  189. self.assertEqual(s.offset_end, 20)
  190. self.assertEqual(len(s.entity_occurrences), 2)
  191. s = raws[1]
  192. self.assertEqual(s.offset, 20)
  193. self.assertEqual(s.offset_end, 50)
  194. self.assertEqual(len(s.entity_occurrences), 3)
  195. s = raws[2]
  196. self.assertEqual(s.offset, 50)
  197. self.assertEqual(s.offset_end, len(self.doc.tokens))
  198. self.assertEqual(len(s.entity_occurrences), 2)
  199. def test_sentence_segmenter_requires_2_entities(self):
  200. self.set_doc_length(100)
  201. self.add_entities([1, 2, 22])
  202. self.doc.sentences = [0, 20, 50]
  203. raws = self.segmenter.build_syntactic_segments(self.doc)
  204. self.assertEqual(len(raws), 1)
  205. s = raws[0]
  206. self.assertEqual(s.offset, 0)
  207. self.assertEqual(s.offset_end, 20)
  208. self.assertEqual(len(s.entity_occurrences), 2)