tagger.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. import os.path
  3. import logging
  4. from nltk.tag.stanford import StanfordPOSTagger
  5. import wget
  6. from iepy.preprocess.pipeline import BasePreProcessStepRunner, PreProcessSteps
  7. from iepy.utils import DIRS, unzip_file
  # Name of the Stanford POS tagger release; used both as the directory name
  # under the user data dir and as the base name of the zip archive.
  stanford_postagger_name = 'stanford-postagger-2014-01-04'
  # URL prefix the archive file name is appended to when downloading.
  download_url_base = 'http://nlp.stanford.edu/software/'

  # Module-level logger shared by the runners and download() below.
  logger = logging.getLogger(__name__)
  class TaggerRunner(BasePreProcessStepRunner):
      """Wrapper to insert a generic callable sentence POS tagger into the pipeline.

      In order to run, requires documents with sentence splitting already done.
      """
      step = PreProcessSteps.tagging

      def __init__(self, postagger, override=False):
          """Store the tagger callable and the re-tagging policy.

          postagger: callable invoked with ``doc.get_sentences()``. It must
              return an iterable with one entry per sentence, each entry
              being an iterable of ``(token, tag)`` pairs.
          override: when false (the default), a document whose tagging step
              is already done is left untouched; when true it is tagged
              again.
          """
          self.postagger = postagger
          self.override = override

      def __call__(self, doc):
          """POS-tag ``doc``, store the tags on it and save it.

          Returns ``None`` without touching the document when it has not been
          sentence-split yet, or when it is already tagged and ``override``
          is false.

          Raises AssertionError when the tagger yields a number of tags
          different from ``len(doc.tokens)``.
          """
          if not doc.was_preprocess_step_done(PreProcessSteps.sentencer):
              # cannot proceed if the document wasn't split in sentences
              return
          already_tagged = doc.was_preprocess_step_done(PreProcessSteps.tagging)
          if not self.override and already_tagged:
              return

          tagged_doc = []
          for ts in self.postagger(doc.get_sentences()):
              # Only the tag of every (token, tag) pair is kept.
              tagged_doc.extend(tag for _token, tag in ts)

          # Explicit check instead of an `assert` statement: asserts are
          # removed under `python -O`, which would let a mismatched tag list
          # be stored silently. The exception type is kept as AssertionError
          # so existing callers observe the same failure.
          if len(tagged_doc) != len(doc.tokens):
              raise AssertionError(
                  "POS tagger returned %d tags for a document with %d tokens"
                  % (len(tagged_doc), len(doc.tokens)))

          doc.set_tagging_result(tagged_doc)
          doc.save()
          logger.debug("POS tagged a document")
  class StanfordTaggerRunner(TaggerRunner):
      """TaggerRunner whose tagger is NLTK's StanfordPOSTagger.

      The tagger files are looked up in the ``stanford_postagger_name``
      directory inside ``DIRS.user_data_dir``.
      """

      def __init__(self, override=False):
          """Build the Stanford tagger and register its ``tag_sents`` method.

          override: forwarded unchanged to ``TaggerRunner.__init__``.

          Raises LookupError when the tagger directory does not exist.
          """
          install_dir = os.path.join(DIRS.user_data_dir, stanford_postagger_name)
          if not os.path.exists(install_dir):
              raise LookupError("Stanford POS tagger not found. Try running the "
                                "command download_third_party_data.py")

          # Model and jar locations inside the unpacked distribution.
          model_file = os.path.join(
              install_dir, 'models', 'english-bidirectional-distsim.tagger')
          jar_file = os.path.join(install_dir, 'stanford-postagger.jar')

          stanford_tagger = StanfordPOSTagger(model_file, jar_file, encoding='utf8')
          super(StanfordTaggerRunner, self).__init__(
              stanford_tagger.tag_sents, override)
  def download():
      """Fetch and unpack the Stanford POS tagger unless it is already usable.

      A ``StanfordTaggerRunner`` is instantiated as a probe: if it raises
      LookupError the archive is downloaded into ``DIRS.user_data_dir`` and
      unzipped there; otherwise only an informational message is logged.

      The process working directory is changed only for the duration of the
      download and is restored afterwards, even if the download fails.
      """
      logger.info("Downloading Stanford POS tagger...")
      try:
          StanfordTaggerRunner()
      except LookupError:
          # makedirs rather than mkdir: parent directories of the user data
          # dir may be missing as well.
          if not os.path.exists(DIRS.user_data_dir):
              os.makedirs(DIRS.user_data_dir)

          package_filename = '{0}.zip'.format(stanford_postagger_name)
          zip_path = os.path.join(DIRS.user_data_dir, package_filename)

          # wget.download is called without a target path and zip_path
          # expects the archive inside the data dir, so the download runs
          # from there. The previous directory is restored so that callers
          # do not end up with a different cwd as a side effect.
          previous_cwd = os.getcwd()
          os.chdir(DIRS.user_data_dir)
          try:
              wget.download(download_url_base + package_filename)
          finally:
              os.chdir(previous_cwd)

          unzip_file(zip_path, DIRS.user_data_dir)
      else:
          logger.info("Stanford POS tagger is already downloaded and functional.")