# ParseDataset.py
import torch
from dev.grammar.THU_preprocess import get_legal_postag, get_legal_relation
  3. class ParseDataset(torch.utils.data.Dataset):
  4. def __init__(self, dataset_file):
  5. self.legal_postag = set(get_legal_postag())
  6. self.legal_relation = set(get_legal_relation())
  7. self.data = self.get_sentence_data(self.read_sentences(dataset_file))
  8. def read_sentences(self,dataset_file):
  9. list_sentences = []
  10. with open(dataset_file,"r",encoding="utf8") as f:
  11. data = f.read()
  12. list_sent = data.split("\n\n")
  13. for sent in list_sent:
  14. l_sent = []
  15. for word_line in sent.split("\n"):
  16. if word_line.strip()=="":
  17. continue
  18. id,form,lemma,cpostag,postag,feats,head,deprel = word_line.strip().split("\t")[:8]
  19. word_dict = {"id": id, "form": form, "lemma": lemma, "cpostag": cpostag, "postag": postag,"feats": feats, "head": head, "deprel": deprel}
  20. l_sent.append(word_dict)
  21. list_sentences.append(l_sent)
  22. return list_sentences
  23. def get_sentence_data(self,list_sentences):
  24. list_data = []
  25. sent_i = -1
  26. not_legal = 0
  27. for sent in list_sentences:
  28. sent_i += 1
  29. print(sent_i,len(sent))
  30. is_legal = True
  31. for word in sent:
  32. if word["postag"] not in self.legal_postag:
  33. is_legal = False
  34. break
  35. if word["deprel"] not in self.legal_relation:
  36. is_legal = False
  37. break
  38. if is_legal:
  39. buffer = []
  40. buffer.extend(sent)
  41. stack = []
  42. transitions = []
  43. word_idx = 0
  44. while len(buffer)>0 or len(stack)>1:
  45. if len(stack)<=1:
  46. if len(buffer)>0:
  47. transitions.append("SHIFT")
  48. data = {"stack":stack.copy(),"buffer":buffer.copy(),"transitions":transitions.copy()}
  49. list_data.append(data)
  50. #excute the action
  51. stack.append(buffer.pop(0))
  52. else:
  53. if stack[-1]["head"]==stack[-2]["id"]:
  54. transitions.append("LEFTARC_%s"%(stack[-1]["deprel"]))
  55. data = {"stack":stack.copy(),"buffer":buffer.copy(),"transitions":transitions.copy()}
  56. list_data.append(data)
  57. stack.pop(-1)
  58. elif stack[-2]["head"]==stack[-1]["id"]:
  59. transitions.append("RIGHTARC_%s"%(stack[-2]["deprel"]))
  60. data = {"stack":stack.copy(),"buffer":buffer.copy(),"transitions":transitions.copy()}
  61. list_data.append(data)
  62. stack.pop(-2)
  63. else:
  64. if len(buffer)>0:
  65. transitions.append("SHIFT")
  66. data = {"stack":stack.copy(),"buffer":buffer.copy(),"transitions":transitions.copy()}
  67. list_data.append(data)
  68. #excute the action
  69. stack.append(buffer.pop(0))
  70. else:
  71. not_legal += len(stack)
  72. break
  73. print(len(list_data),not_legal)
  74. return list_data
  75. def __getitem__(self, index):
  76. return self.data[index]
  77. def __len__(self):
  78. return len(self.data)
  79. if __name__ == '__main__':
  80. filepath = "data/THU/train.conll"
  81. ParseDataset(filepath)