iepy_runner.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """
  2. Run IEPY active-learning extractor
  3. Usage:
  4. iepy_runner.py [options] <relation_name> <output>
  5. iepy_runner.py [options] --db-store <relation_name>
  6. iepy_runner.py -h | --help | --version
  7. Options:
  8. --store-extractor=<extractor_output> Stores the trained classifier
  9. --trained-extractor=<extractor_path> Load an already trained extractor
  10. --db-store Stores the predictions on the database
  11. --no-questions Won't generate questions to answer. Will predict
  12. as is. Should be used with --trained-extractor
  13. --tune-for=<tune-for> Predictions tuning. Options are high-prec
  14. or high-recall [default: high-prec]
  15. --extractor-config=<config.json> Sets the extractor config
  16. --version Version number
  17. -h --help Show this screen
  18. """
  19. import os
  20. import json
  21. import logging
  22. from docopt import docopt
  23. from sys import exit
  24. import iepy
  25. INSTANCE_PATH = iepy.setup(__file__)
  26. from iepy.extraction.active_learning_core import ActiveLearningCore, HIPREC, HIREC
  27. from iepy.data.db import CandidateEvidenceManager
  28. from iepy.data.models import Relation
  29. from iepy.extraction.terminal import TerminalAdministration
  30. from iepy.data import output
  31. def print_all_relations():
  32. print("All available relations:")
  33. for relation in Relation.objects.all():
  34. print(" {}".format(relation))
  35. def load_labeled_evidences(relation, evidences):
  36. CEM = CandidateEvidenceManager # shorcut
  37. return CEM.labels_for(relation, evidences, CEM.conflict_resolution_newest_wins)
  38. def _get_tuning_mode(opts):
  39. if opts['--tune-for'] == 'high-prec':
  40. tuning_mode = HIPREC
  41. elif opts['--tune-for'] == 'high-recall':
  42. tuning_mode = HIREC
  43. else:
  44. print ('Invalid tuning mode')
  45. print (__doc__)
  46. exit(1)
  47. return tuning_mode
  48. def _get_relation(opts):
  49. relation_name = opts['<relation_name>']
  50. try:
  51. relation = Relation.objects.get(name=relation_name)
  52. except Relation.DoesNotExist:
  53. print("Relation {!r} non existent".format(relation_name))
  54. print_all_relations()
  55. exit(1)
  56. return relation
  57. def _load_extractor(opts, relation, labeled_evidences):
  58. extractor_path = opts.get('--trained-extractor')
  59. try:
  60. iextractor = ActiveLearningCore.load(extractor_path,
  61. labeled_evidences=labeled_evidences)
  62. except ValueError:
  63. print("Error: unable to load extractor, invalid file")
  64. exit(1)
  65. if iextractor.relation != relation:
  66. print('The loaded extractor is not for the requested relation'
  67. ' but for relation {} instead'.format(iextractor.relation))
  68. exit(1)
  69. print('Extractor successfully loaded')
  70. return iextractor
  71. def _construct_extractor(opts, relation, labeled_evidences, tuning_mode):
  72. config_filepath = opts.get("--extractor-config")
  73. if not config_filepath:
  74. config_filepath = os.path.join(INSTANCE_PATH, "extractor_config.json")
  75. if not os.path.exists(config_filepath):
  76. print("Error: extractor config does not exists, please create the "
  77. "file extractor_config.json or use the --extractor-config")
  78. exit(1)
  79. with open(config_filepath) as filehandler:
  80. try:
  81. extractor_config = json.load(filehandler)
  82. except Exception as error:
  83. print("Error: unable to load extractor config: {}".format(error))
  84. exit(1)
  85. iextractor = ActiveLearningCore(
  86. relation, labeled_evidences, extractor_config, tradeoff=tuning_mode
  87. )
  88. return iextractor
  89. def run_from_command_line():
  90. opts = docopt(__doc__, version=iepy.__version__)
  91. logging.basicConfig(level=logging.INFO, format='%(message)s')
  92. logging.getLogger("featureforge").setLevel(logging.WARN)
  93. tuning_mode = _get_tuning_mode(opts)
  94. relation = _get_relation(opts)
  95. candidates = CandidateEvidenceManager.candidates_for_relation(relation)
  96. labeled_evidences = load_labeled_evidences(relation, candidates)
  97. if opts.get('--trained-extractor'):
  98. iextractor = _load_extractor(opts, relation, labeled_evidences)
  99. was_ever_trained = True
  100. opts["--no-questions"] = True
  101. else:
  102. iextractor = _construct_extractor(opts, relation, labeled_evidences, tuning_mode)
  103. iextractor.start()
  104. was_ever_trained = False
  105. if not opts.get("--no-questions", False):
  106. questions_loop(iextractor, relation, was_ever_trained)
  107. # Candidates generator was consumed when generating labeled_evidences, so we'll
  108. # define it fresh again
  109. candidates = CandidateEvidenceManager.candidates_for_relation(relation)
  110. # Predict and store output
  111. predictions = iextractor.predict(candidates) # asking predictions for EVERYTHING
  112. if not predictions:
  113. print("Nothing was predicted")
  114. exit(1)
  115. if opts.get("--db-store"):
  116. output.dump_predictions_to_database(relation, predictions)
  117. output_file = opts.get("<output>")
  118. if output_file:
  119. output.dump_runner_output_to_csv(predictions, output_file)
  120. classifier_output = opts.get("--store-extractor")
  121. if classifier_output:
  122. iextractor.save(classifier_output)
  123. def questions_loop(iextractor, relation, was_ever_trained):
  124. STOP = u'STOP'
  125. term = TerminalAdministration(
  126. relation,
  127. extra_options=[(STOP, u'Stop execution')]
  128. )
  129. while iextractor.questions:
  130. questions = list(iextractor.questions) # copying the list
  131. term.update_candidate_evidences_to_label(questions)
  132. result = term()
  133. i = 0
  134. for c, label_value in load_labeled_evidences(relation, questions).items():
  135. if label_value is not None:
  136. iextractor.add_answer(c, label_value)
  137. i += 1
  138. print ('Added %s new human labels to the extractor core' % i)
  139. iextractor.process()
  140. was_ever_trained = True
  141. if result == STOP:
  142. break
  143. if not was_ever_trained:
  144. # It's needed to run some process before asking for predictions
  145. iextractor.process()
  146. if __name__ == u'__main__':
  147. run_from_command_line()