active_learning_core.py

from copy import copy
import inspect
import logging
import pickle
import random
import os.path

import numpy
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_curve

from iepy import defaults
from iepy.extraction.relation_extraction_classifier import RelationExtractionClassifier

logger = logging.getLogger(__name__)

HIPREC = (10, 1)  # Precision is 10x more important than recall
HIREC = (1, 2)  # Recall is 2x more important than precision
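# These tuples are meant as values for the `tradeoff` argument of
# ActiveLearningCore: they weight (precision, recall) when picking a
# decision threshold in `estimate_threshold` below.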


class ActiveLearningCore:
    """
    IEPY's main class. Implements an active learning information extraction
    pipeline.

    From the user's point of view this class is meant to be used like this::

        extractor = ActiveLearningCore(relation, lbl_evidences)
        extractor.start()  # blocking
        while UserIsNotTired and extractor.questions:
            question = extractor.questions[0]
            answer = ask_user(question)
            extractor.add_answer(question, answer)
            extractor.process()
        predictions = extractor.predict()  # profit
    """

    #
    # IEPY User API
    #

    def __init__(self, relation, labeled_evidences, extractor_config=None,
                 tradeoff=None, extractor=None, classifier=None):
        if extractor is None:
            extractor = RelationExtractionClassifier
        self.extractor = extractor
        self.relation = relation
        self.classifier = classifier
        self._setup_labeled_evidences(labeled_evidences)
        self._questions = list(self.candidate_evidence)
        if extractor_config is None:
            extractor_config = defaults.extractor_config
        self.extractor_config = extractor_config
        self.tradeoff = tradeoff
        self.aimed_tradeoff = None
        self.threshold = None

    _DUMPED_ATTRS = ['relation', 'extractor', 'extractor_config', 'classifier',
                     'tradeoff', 'aimed_tradeoff', 'threshold']

    def save(self, file_path):
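        """
        Pickles this core's configuration and trained state (the attributes
        listed in `_DUMPED_ATTRS`) to `file_path`. Refuses to overwrite an
        existing file.
        """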
        if os.path.exists(file_path):
            raise ValueError("Output file path already exists")
        to_dump = [getattr(self, attr) for attr in self._DUMPED_ATTRS]
        with open(file_path, 'wb') as filehandler:
            pickle.dump(to_dump, filehandler)

    @classmethod
    def load(cls, file_path, **kwargs):
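        """
        Inverse of `save`: unpickles the dumped attributes and rebuilds a core.
        Attributes that are `__init__` arguments are passed at construction
        time; the rest are set afterwards. Extra keyword arguments are
        forwarded to `__init__` (an empty `labeled_evidences` dict is used if
        none is given).
        """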
        if not os.path.exists(file_path):
            raise ValueError("File does not exist")
        with open(file_path, 'rb') as filehandler:
            data = pickle.load(filehandler)
        loading_kwargs = copy(kwargs)
        if 'labeled_evidences' not in kwargs:
            loading_kwargs['labeled_evidences'] = {}
        after = {}
        specs = inspect.getfullargspec(cls.__init__)
        for attr, value in zip(cls._DUMPED_ATTRS, data):
            if attr in specs.args:
                loading_kwargs[attr] = value
            else:
                after[attr] = value
        self = cls(**loading_kwargs)
        for after_attr, value in after.items():
            logger.debug("Setting %s", after_attr)
            setattr(self, after_attr, value)
        return self

    def start(self):
        """
        Organizes the internal information, and prepares the first "questions"
        that need to be answered.
        """
        # API compliance. Nothing is done in the current implementation.
        pass

    @property
    def questions(self):
        """Returns a list of candidate evidences that would be good to have
        labels for.

        Order is important: labels for evidences listed first are more valuable.
        """
        return self._questions

    def add_answer(self, evidence, answer):
        """
        Non-blocking.
        Informs the core of the evidence label (True or False) decided
        from the outside.
        """
        assert answer in (True, False)
        self.labeled_evidence[evidence] = answer
        for list_ in (self._questions, self.candidate_evidence):  # TODO: check performance; should this be a set?
            list_.remove(evidence)

    def process(self):
        """
        Blocking.
        With all the labeled evidence, new questions are generated, optimizing
        the future gain of having those evidences labeled.
        After calling this method the values returned by `questions`
        and `predict` will change.
        """
        yesno = set(self.labeled_evidence.values())
        if len(yesno) > 2:
            msg = "Evidence is not binary! Can't proceed."
            logger.error(msg)
            raise ValueError(msg)
        if len(yesno) < 2:
            logger.debug("Not enough labels to train.")
            return
        if self.tradeoff:
            self.estimate_threshold()
        self.train_relation_classifier()
        self.rank_candidate_evidence()
        self.choose_questions()

    def predict(self, candidates):
        """
        Using the internal trained classifier, all candidate evidence is
        automatically labeled.
        Returns a dict {evidence: True/False}, where the boolean label
        indicates whether the relation is present in that evidence or not.
        """
        if not self.classifier:
            logger.info("There is no trained classifier. Can't predict")
            return {}
        # For every already labeled candidate, use the actual label instead
        # of asking the classifier.
        knowns = copy(self.labeled_evidence)
        to_predict = [c for c in candidates if c not in knowns]
        if self.threshold is None:
            labels = self.classifier.predict(to_predict)
        else:
            scores = self.classifier.decision_function(to_predict)
            labels = scores >= self.threshold
        prediction = dict(zip(to_predict, map(bool, labels)))
        prediction.update(knowns)
        return prediction

    def estimate_threshold(self):
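        """
        Estimates a decision-function threshold that optimizes the configured
        precision/recall tradeoff, using scores obtained via k-fold cross
        validation. Stores the chosen threshold in `self.threshold` and the
        (precision, recall) pair it aims at in `self.aimed_tradeoff`.
        """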
        scores, y_true = self.get_kfold_data()
        if scores is None:
            return
        prec, rec, thres = precision_recall_curve(y_true, scores)
        prec[-1] = 0.0  # To avoid choosing the last phony value
        c_prec, c_rec = self.tradeoff
        # Below is a linear function of precision and recall, expressed in
        # numpy notation; we're maximizing it.
        i = (prec * c_prec + rec * c_rec).argmax()  # Index of the maximum
        assert i < len(thres)  # Because prec[-1] is 0.0
        self.aimed_tradeoff = (prec[i], rec[i])
        self.threshold = thres[i]
        s = "Using {} samples, threshold aiming at precision={:.4f} and recall={:.4f}"
        logger.debug(s.format(len(scores), prec[i], rec[i]))

    # Instance attributes:
    # questions: A list of evidence
    # ranked_candidate_evidence: A dict candidate_evidence -> float
    # aimed_tradeoff: A (prec, rec) tuple with the precision/recall tradeoff
    #                 that the threshold aims to achieve.

    #
    # Private methods
    #

    def _setup_labeled_evidences(self, labeled_evidences):
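        """
        Splits the given {evidence: label} mapping into `candidate_evidence`
        (entries whose label is None) and `labeled_evidence` (entries with a
        boolean label), binding each evidence to this core's relation.
        """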
        self.candidate_evidence = []
        self.labeled_evidence = {}
        for e, lbl in labeled_evidences.items():
            e.relation = self.relation
            if lbl is None:
                self.candidate_evidence.append(e)
            else:
                self.labeled_evidence[e] = bool(lbl)
        if not self.candidate_evidence:
            raise ValueError("Cannot start core without candidate evidence")
        logger.info("Loaded {} candidate evidence and {} labeled evidence".format(
            len(self.candidate_evidence), len(self.labeled_evidence)))

    def train_relation_classifier(self):
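        """
        Fits a fresh extractor instance on all the labeled evidence gathered
        so far, replacing any previously trained classifier.
        """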
        X = []
        y = []
        for evidence, label in self.labeled_evidence.items():
            X.append(evidence)
            y.append(int(label))
            assert y[-1] in (True, False)
        self.classifier = self.extractor(**self.extractor_config)
        self.classifier.fit(X, y)

    def rank_candidate_evidence(self):
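        """
        Scores a random sample of the candidate evidence with the classifier's
        decision function, storing the results in `ranked_candidate_evidence`.
        The sample is capped at 10 candidates per labeled evidence.
        """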
        if not self.candidate_evidence:
            self.ranked_candidate_evidence = {}
            logger.info("No evidence left to rank.")
            return
        N = min(10 * len(self.labeled_evidence), len(self.candidate_evidence))
        logger.info("Ranking a sample of {} candidate evidence".format(N))
        sample = random.sample(self.candidate_evidence, N)
        ranks = self.classifier.decision_function(sample)
        self.ranked_candidate_evidence = dict(zip(sample, ranks))
        ranks = [abs(x) for x in ranks]
        logger.info("Ranking completed, lowest absolute rank={}, "
                    "highest absolute rank={}".format(min(ranks), max(ranks)))

    def choose_questions(self):
        # Criterion: answer first the candidates whose decision function is
        # nearest to 0, because they are the most uncertain for the classifier.
        self._questions = sorted(self.ranked_candidate_evidence,
                                 key=lambda x: abs(self.ranked_candidate_evidence[x]))

    def get_kfold_data(self):
        """
        Performs k-fold cross validation and returns (scores, y_true), where
        scores is a numpy array with decision function scores and y_true
        is a numpy array with the true label for each evidence.
        """
        allX = []
        ally = []
        for evidence, label in self.labeled_evidence.items():
            allX.append(evidence)
            ally.append(int(label))
            assert ally[-1] in (True, False)
        allX = numpy.array(allX)
        ally = numpy.array(ally)
        if numpy.bincount(ally).min() < 5:
            return None, None  # Too little data to do 5-fold cross validation
        logger.debug("Performing 5-fold cross validation")
        scores = []
        y_true = []
        for train_index, test_index in StratifiedKFold(n_splits=5).split(allX, ally):
            X = allX[train_index]
            y = ally[train_index]
            c = self.extractor(**self.extractor_config)
            c.fit(X, y)
            y_true.append(ally[test_index])
            scores.append(c.decision_function(allX[test_index]))
        return numpy.hstack(scores), numpy.hstack(y_true)
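

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module): assumes
# `my_relation` and `my_labeled_evidences` ({evidence: True/False/None}) were
# obtained elsewhere, e.g. from the IEPY models, and that `ask_human` is a
# hypothetical labeling UI.
#
#     core = ActiveLearningCore(my_relation, my_labeled_evidences,
#                               tradeoff=HIPREC)
#     core.start()
#     while core.questions:
#         evidence = core.questions[0]
#         core.add_answer(evidence, ask_human(evidence))
#         core.process()
#     predictions = core.predict(list(my_labeled_evidences))
# ---------------------------------------------------------------------------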