rules.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # -*- coding: utf-8 -*-
  2. from functools import lru_cache
  3. from collections import namedtuple
  4. import refo
  5. import iepy
  6. TokenToMatch = namedtuple("TokenToMatch", "token lemma pos kinds, is_subj is_obj")
  7. def rule(answer, priority=0):
  8. if answer not in [False, True]:
  9. message = "Rule has invalid answer, it has to be either False or True"
  10. raise ValueError(message)
  11. def inner(f):
  12. f.priority = priority
  13. f.is_rule = True
  14. f.answer = answer
  15. return f
  16. return inner
  17. def is_rule(fun):
  18. """ Returns whether something is a rule or not """
  19. is_callable = hasattr(fun, '__call__')
  20. return is_callable and hasattr(fun, "is_rule") and fun.is_rule
  21. def load_rules():
  22. result = []
  23. for attr_name in dir(iepy.instance.rules):
  24. attr = getattr(iepy.instance.rules, attr_name)
  25. if is_rule(attr):
  26. result.append(attr)
  27. return result
  28. class ObjectAttrPredicate(refo.Predicate):
  29. def __init__(self, attr_name, attr_value):
  30. self.attr_name = attr_name
  31. self.attr_value = attr_value
  32. super().__init__(self._predicate)
  33. self.arg = attr_value
  34. def _predicate(self, obj):
  35. return getattr(obj, self.attr_name) == self.attr_value
  36. def obj_attr_predicate_factory(attr_values, attr_name):
  37. attr_values = attr_values.split()
  38. result = None
  39. for attr_value in attr_values:
  40. if result is None:
  41. result = ObjectAttrPredicate(attr_name, attr_value)
  42. else:
  43. result += ObjectAttrPredicate(attr_name, attr_value)
  44. return result
  45. def Token(string):
  46. return obj_attr_predicate_factory(string, "token")
  47. def Lemma(string):
  48. return obj_attr_predicate_factory(string, "lemma")
  49. def Pos(string):
  50. return obj_attr_predicate_factory(string, "pos")
  51. class Kind(refo.Predicate):
  52. def __init__(self, kind):
  53. self.kind = kind
  54. super().__init__(self._predicate)
  55. self.arg = kind
  56. def _predicate(self, obj):
  57. if hasattr(obj, "kinds"):
  58. return self.kind in obj.kinds
  59. return False
  60. class ConditionPredicate(refo.Predicate):
  61. def __init__(self, **kwargs):
  62. self.conditions = kwargs
  63. super().__init__(self._predicate)
  64. self.arg = str(kwargs)
  65. def _predicate(self, obj):
  66. for attr_name, attr_value in self.conditions.items():
  67. check_inclusion = False
  68. if attr_name.endswith("__has"):
  69. attr_name = attr_name[:-5]
  70. check_inclusion = True
  71. if hasattr(obj, attr_name):
  72. if check_inclusion:
  73. if not attr_value in getattr(obj, attr_name):
  74. return False
  75. else:
  76. if not getattr(obj, attr_name) == attr_value:
  77. return False
  78. else:
  79. return False
  80. return True
  81. @lru_cache(maxsize=None)
  82. def compile_rule(rule, relation):
  83. s, o = generate_subject_and_object(relation)
  84. return rule(s, o)
  85. @lru_cache(maxsize=8)
  86. def generate_subject_and_object(relation):
  87. subject_kind = relation.left_entity_kind.name
  88. object_kind = relation.right_entity_kind.name
  89. Subject = refo.Plus(ConditionPredicate(is_subj=True, kinds__has=subject_kind))
  90. Object = refo.Plus(ConditionPredicate(is_obj=True, kinds__has=object_kind))
  91. return Subject, Object
  92. @lru_cache(maxsize=8)
  93. def cached_segment_enriched_tokens(segment):
  94. return list(segment.get_enriched_tokens())
  95. @lru_cache(maxsize=8)
  96. def generate_tokens_to_match(evidence):
  97. tokens_to_match = []
  98. l_eo_id = evidence.left_entity_occurrence_id
  99. r_eo_id = evidence.right_entity_occurrence_id
  100. segment = evidence.segment
  101. for rich_token in cached_segment_enriched_tokens(segment):
  102. is_subj = False
  103. is_obj = False
  104. if l_eo_id in rich_token.eo_ids:
  105. is_subj = True
  106. if r_eo_id in rich_token.eo_ids:
  107. is_obj = True
  108. tokens_to_match.append(TokenToMatch(
  109. token=rich_token.token,
  110. pos=rich_token.pos,
  111. lemma=rich_token.lemma,
  112. kinds=set([x.name for x in rich_token.eo_kinds]),
  113. is_subj=is_subj,
  114. is_obj=is_obj,
  115. ))
  116. return tokens_to_match