utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from getpass import getuser
  2. import csv
  3. import gzip
  4. import logging
  5. import os
  6. import sys
  7. import tarfile
  8. import zipfile
  9. import wget
  10. from appdirs import AppDirs
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Per-user application directories for the "iepy" application (appdirs.AppDirs);
# the current login name from getpass.getuser() is passed as the second
# positional argument (appdirs' appauthor).
DIRS = AppDirs('iepy', getuser())
# NOTE(review): the code below is disabled, so importing this module does NOT
# create DIRS.user_data_dir; any caller needing it must ensure it exists.
# if not os.path.exists(DIRS.user_data_dir):
#     # making sure that user_data_dir exists
#     os.mkdir(DIRS.user_data_dir)
  16. def unzip(zipped_list, n):
  17. """returns n lists with the elems of zipped_list unsplitted.
  18. The general case could be solved with zip(*zipped_list), but here we
  19. are also dealing with:
  20. - un-zipping empy list to n empty lists
  21. - ensuring that all zipped items in zipped_list have lenght n, raising
  22. ValueError if not.
  23. """
  24. if not zipped_list:
  25. return tuple([[]] * n)
  26. else:
  27. if not all(isinstance(x, tuple) and len(x) == n for x in zipped_list):
  28. raise ValueError
  29. return zip(*zipped_list)
  30. def unzip_from_url(zip_url, extraction_base_path):
  31. got_zipfile = None
  32. try:
  33. got_zipfile = wget.download(zip_url)
  34. print('') # just because wget progress-bar finishes a line with no EOL
  35. unzip_file(got_zipfile, extraction_base_path)
  36. finally:
  37. if zipfile:
  38. os.remove(got_zipfile)
  39. def unzip_file(zip_path, extraction_base_path):
  40. if zip_path.endswith('.tar.gz'):
  41. with tarfile.open(zip_path, mode='r:gz') as tfile:
  42. tfile.extractall(extraction_base_path)
  43. else:
  44. zfile = zipfile.ZipFile(zip_path)
  45. zfile.extractall(extraction_base_path)
  46. def make_feature_list(text):
  47. return [x.strip() for x in text.split("\n") if x.strip()]
  48. def evaluate(predicted_knowledge, gold_knowledge):
  49. """Computes evaluation metrics for a predicted knowledge with respect to a
  50. gold (or reference) knowledge. Returns a dictionary with the results.
  51. """
  52. # ignore predicted facts with no evidence:
  53. predicted_positives = set([p for p in predicted_knowledge.keys() if p.segment])
  54. gold_positives = set([p for p, b in gold_knowledge.items() if b])
  55. correct_positives = predicted_positives & gold_positives
  56. result = {}
  57. result['correct'] = correct = len(correct_positives)
  58. result['predicted'] = predicted = len(predicted_positives)
  59. result['gold'] = gold = len(gold_positives)
  60. if predicted > 0:
  61. result['precision'] = precision = float(correct) / predicted
  62. else:
  63. result['precision'] = precision = 1.0
  64. if gold > 0:
  65. result['recall'] = recall = float(correct) / gold
  66. else:
  67. result['recall'] = recall = 1.0
  68. if precision + recall > 0.0:
  69. result['f1'] = 2 * precision * recall / (precision + recall)
  70. else:
  71. result['f1'] = 0.0
  72. return result
  73. def csv_to_iepy(filepath):
  74. print ('Importing Documents to IEPY from {}'.format(filepath))
  75. from iepy.data.db import DocumentManager
  76. if filepath.endswith(".gz"):
  77. fin = gzip.open(filepath, "rt")
  78. else:
  79. fin = open(filepath, "rt")
  80. reader = csv.DictReader(fin)
  81. expected_fnames = ['document_id', 'document_text']
  82. if not set(reader.fieldnames).issuperset(expected_fnames):
  83. msg = "Couldn't find the expected field names on the provided csv {}"
  84. sys.exit(msg.format(expected_fnames))
  85. name = os.path.basename(filepath)
  86. docdb = DocumentManager()
  87. seen = set()
  88. i = 0
  89. while True:
  90. try:
  91. d = next(reader)
  92. except StopIteration:
  93. break
  94. except csv.Error as error:
  95. logger.warn("Couldn't load document: {}".format(error))
  96. continue
  97. i += 1
  98. doc_id = d["document_id"]
  99. if doc_id in seen:
  100. continue
  101. seen.add(doc_id)
  102. docdb.create_document(
  103. identifier=doc_id,
  104. text=d["document_text"],
  105. metadata={"input_filename": name},
  106. update_mode=True
  107. )
  108. print ('Added {} documents'.format(i))