features.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import ast
  2. from string import punctuation
  3. import importlib
  4. import refo
  5. from featureforge.feature import output_schema
  6. from iepy.extraction.rules import generate_tokens_to_match, compile_rule
  7. from iepy.data.models import Relation
  8. punct_set = set(punctuation)
  9. def all_len_two(v):
  10. return all(len(x) == 2 for x in v)
  11. def all_len_two_inner_too(v):
  12. return all(len(x) == 2 and all(len(y) == 2 for y in x) for x in v)
  13. def binary_values(x):
  14. return x in (0, 1)
  15. def ge_than_zero(v):
  16. return v >= 0
  17. def ge_than_two(v):
  18. return v >= 2
  19. _loaded_modules = {}
  20. def load_module(module_name):
  21. module = _loaded_modules.get(module_name)
  22. if module is None:
  23. module = importlib.import_module(module_name)
  24. _loaded_modules[module_name] = module
  25. return module
  26. def rule_wrapper(rule_feature, relation):
  27. @output_schema(int, binary_values)
  28. def inner(evidence):
  29. regex = compile_rule(rule_feature, relation)
  30. tokens_to_match = generate_tokens_to_match(evidence)
  31. return int(bool(refo.match(regex, tokens_to_match)))
  32. return inner
  33. def parse_features(feature_names):
  34. features = []
  35. for line in feature_names:
  36. if not line or line != line.strip():
  37. raise ValueError("Garbage in feature set: {!r}".format(line))
  38. fname, _, args = line.partition(" ")
  39. if fname.count("."): # Is a module path
  40. feature_module, feature_name = fname.rsplit(".", 1)
  41. try:
  42. module = load_module(feature_module)
  43. except ImportError:
  44. raise KeyError("Couldn't load module {!r}".format(feature_module))
  45. try:
  46. feature = getattr(module, feature_name)
  47. except AttributeError:
  48. raise KeyError(
  49. "Feature {!r} not found in {!r} module".format(feature_name, feature_module)
  50. )
  51. if feature_module.endswith(".rules"):
  52. relation = Relation.objects.get(name=module.RELATION)
  53. feature = rule_wrapper(feature, relation)
  54. else:
  55. try:
  56. feature = globals()[fname]
  57. except KeyError:
  58. raise KeyError("There is not such feature: "
  59. "{!r}".format(fname))
  60. args = args.strip()
  61. if args:
  62. args = ast.literal_eval(args + ",")
  63. feature = feature(*args)
  64. features.append(feature)
  65. return features
  66. @output_schema({str})
  67. def bag_of_words(datapoint):
  68. return set(words(datapoint))
  69. @output_schema({str})
  70. def bag_of_pos(datapoint):
  71. return set(pos(datapoint))
  72. @output_schema({(str,)}, all_len_two)
  73. def bag_of_word_bigrams(datapoint):
  74. return set(bigrams(words(datapoint)))
  75. @output_schema({(str,)}, all_len_two)
  76. def bag_of_wordpos(datapoint):
  77. return set(zip(words(datapoint), pos(datapoint)))
  78. @output_schema({((str,),)}, all_len_two_inner_too)
  79. def bag_of_wordpos_bigrams(datapoint):
  80. xs = list(zip(words(datapoint), pos(datapoint)))
  81. return set(bigrams(xs))
  82. @output_schema({str})
  83. def bag_of_words_in_between(datapoint):
  84. i, j = in_between_offsets(datapoint)
  85. return set(words(datapoint)[i:j])
  86. @output_schema({str})
  87. def bag_of_pos_in_between(datapoint):
  88. i, j = in_between_offsets(datapoint)
  89. return set(pos(datapoint)[i:j])
  90. @output_schema({(str,)}, all_len_two)
  91. def bag_of_word_bigrams_in_between(datapoint):
  92. i, j = in_between_offsets(datapoint)
  93. return set(bigrams(words(datapoint)[i:j]))
  94. @output_schema({(str,)}, all_len_two)
  95. def bag_of_wordpos_in_between(datapoint):
  96. i, j = in_between_offsets(datapoint)
  97. return set(list(zip(words(datapoint), pos(datapoint)))[i:j])
  98. @output_schema({((str,),)}, all_len_two_inner_too)
  99. def bag_of_wordpos_bigrams_in_between(datapoint):
  100. i, j = in_between_offsets(datapoint)
  101. xs = list(zip(words(datapoint), pos(datapoint)))[i:j]
  102. return set(bigrams(xs))
  103. @output_schema(int, binary_values)
  104. def entity_order(datapoint):
  105. """
  106. Returns 1 if A occurs prior to B in the segment and 0 otherwise.
  107. """
  108. A, B = get_AB(datapoint)
  109. if A.segment_offset < B.segment_offset:
  110. return 1
  111. return 0
  112. @output_schema(int, ge_than_zero)
  113. def entity_distance(datapoint):
  114. """
  115. Returns the distance (in tokens) that separates the ocurrence of the
  116. entities.
  117. """
  118. i, j = in_between_offsets(datapoint)
  119. return j - i
  120. @output_schema(int, ge_than_zero)
  121. def other_entities_in_between(datapoint):
  122. """
  123. Returns the number of entity ocurrences in between the datapoint entities.
  124. """
  125. n = 0
  126. i, j = in_between_offsets(datapoint)
  127. for other in datapoint.all_eos:
  128. if other.segment_offset >= i and other.segment_offset < j:
  129. n += 1
  130. return n
  131. @output_schema(int, ge_than_two)
  132. def total_number_of_entities(datapoint):
  133. """
  134. Returns the number of entity in the text segment
  135. """
  136. return len(datapoint.all_eos)
  137. @output_schema(int, ge_than_zero)
  138. def verbs_count_in_between(datapoint):
  139. """
  140. Returns the number of Verb POS tags in between of the 2 entities.
  141. """
  142. i, j = in_between_offsets(datapoint)
  143. return len(verbs(datapoint, i, j))
  144. @output_schema(int, ge_than_zero)
  145. def verbs_count(datapoint):
  146. """
  147. Returns the number of Verb POS tags in the datapoint.
  148. """
  149. return len(verbs(datapoint))
  150. @output_schema(int, binary_values)
  151. def in_same_sentence(datapoint): # TODO: Test
  152. """
  153. Returns 1 if the datapoints entities are in the same sentence, 0 otherwise.
  154. """
  155. i, j = in_between_offsets(datapoint)
  156. for k in datapoint.segment.sentences:
  157. if i <= k and k < j:
  158. return 0
  159. return 1
  160. @output_schema(int, binary_values)
  161. def symbols_in_between(datapoint):
  162. """
  163. Returns 1 if there are symbols between the entities, 0 if not.
  164. """
  165. i, j = in_between_offsets(datapoint)
  166. tokens = datapoint.segment.tokens[i:j]
  167. for tkn in tokens:
  168. if punct_set.intersection(tkn):
  169. return 1
  170. return 0
  171. @output_schema(int, ge_than_zero)
  172. def number_of_tokens(datapoint):
  173. return len(datapoint.segment.tokens)
  174. ###
  175. # Aux functions
  176. ###
  177. def words(datapoint):
  178. return [word.lower() for word in datapoint.segment.tokens]
  179. def pos(datapoint):
  180. return list(map(str, datapoint.segment.postags))
  181. def verbs(datapoint, slice_i=0, slice_j=None):
  182. pairs = zip(datapoint.segment.tokens, datapoint.segment.postags)
  183. if slice_j is not None:
  184. pairs = list(pairs)[slice_i:slice_j]
  185. return [tkn for tkn, tag in pairs if tag.startswith(u'VB')]
  186. def bigrams(xs):
  187. return list(zip(xs, xs[1:]))
  188. def in_between_offsets(datapoint):
  189. A, B = get_AB(datapoint)
  190. if A.segment_offset_end < B.segment_offset:
  191. return A.segment_offset_end, B.segment_offset
  192. elif B.segment_offset_end < A.segment_offset:
  193. return B.segment_offset_end, A.segment_offset
  194. elif A.segment_offset_end < B.segment_offset_end:
  195. return A.segment_offset_end, A.segment_offset_end
  196. return B.segment_offset_end, B.segment_offset_end
  197. def get_AB(datapoint):
  198. a = datapoint.right_entity_occurrence
  199. b = datapoint.left_entity_occurrence
  200. return a, b