import torch

from dev.grammar.THU_preprocess import get_legal_postag, get_legal_relation


class ParseDataset(torch.utils.data.Dataset):
    """Builds transition-based parsing examples from a CoNLL-format treebank."""

    def __init__(self, dataset_file):
        self.legal_postag = set(get_legal_postag())
        self.legal_relation = set(get_legal_relation())
        self.data = self.get_sentence_data(self.read_sentences(dataset_file))

    def read_sentences(self, dataset_file):
        """Read a CoNLL file into a list of sentences, one dict per word."""
        list_sentences = []
        with open(dataset_file, "r", encoding="utf8") as f:
            data = f.read()
        for sent in data.split("\n\n"):
            l_sent = []
            for word_line in sent.split("\n"):
                if word_line.strip() == "":
                    continue
                # Only the first eight CoNLL columns are used.
                idx, form, lemma, cpostag, postag, feats, head, deprel = word_line.strip().split("\t")[:8]
                word_dict = {"id": idx, "form": form, "lemma": lemma, "cpostag": cpostag,
                             "postag": postag, "feats": feats, "head": head, "deprel": deprel}
                l_sent.append(word_dict)
            if l_sent:  # skip empty blocks, e.g. from a trailing newline
                list_sentences.append(l_sent)
        return list_sentences
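
    # Illustrative only (made-up CoNLL line; fields are tab-separated in the file):
    #   "3  saw  see  V  VV  _  0  ROOT"
    # would be read as:
    #   {"id": "3", "form": "saw", "lemma": "see", "cpostag": "V", "postag": "VV",
    #    "feats": "_", "head": "0", "deprel": "ROOT"}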

    def get_sentence_data(self, list_sentences):
        """Run an arc-standard style oracle over each gold tree and record a
        (stack, buffer, transition-history) snapshot for every transition."""
        list_data = []
        sent_i = -1
        not_legal = 0
        for sent in list_sentences:
            sent_i += 1
            print(sent_i, len(sent))
            # Skip sentences that use a POS tag or dependency relation
            # outside the legal sets.
            is_legal = True
            for word in sent:
                if word["postag"] not in self.legal_postag:
                    is_legal = False
                    break
                if word["deprel"] not in self.legal_relation:
                    is_legal = False
                    break
            if not is_legal:
                continue
            buffer = list(sent)
            stack = []
            transitions = []
            sent_data = []
            while len(buffer) > 0 or len(stack) > 1:
                if len(stack) <= 1:
                    if len(buffer) > 0:
                        transitions.append("SHIFT")
                        sent_data.append({"stack": stack.copy(), "buffer": buffer.copy(),
                                          "transitions": transitions.copy()})
                        # execute the action
                        stack.append(buffer.pop(0))
                elif (stack[-1]["head"] == stack[-2]["id"]
                      and not any(w["head"] == stack[-1]["id"] for w in buffer)):
                    # The stack top depends on the word below it and has no
                    # dependents left in the buffer, so it can be reduced.
                    transitions.append("LEFTARC_%s" % stack[-1]["deprel"])
                    sent_data.append({"stack": stack.copy(), "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    stack.pop(-1)
                elif stack[-2]["head"] == stack[-1]["id"]:
                    # The word below the top depends on the top, so it can be reduced.
                    transitions.append("RIGHTARC_%s" % stack[-2]["deprel"])
                    sent_data.append({"stack": stack.copy(), "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    stack.pop(-2)
                elif len(buffer) > 0:
                    transitions.append("SHIFT")
                    sent_data.append({"stack": stack.copy(), "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    # execute the action
                    stack.append(buffer.pop(0))
                else:
                    # The oracle is stuck (e.g. a non-projective tree);
                    # drop this sentence's partial examples and count it.
                    not_legal += len(stack)
                    sent_data = []
                    break
            list_data.extend(sent_data)
        print(len(list_data), not_legal)
        return list_data
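
    # Worked example (hypothetical two-word sentence: w1 has head "2" and
    # deprel "rel", w2 has head "0", i.e. w2 is the sentence root):
    #   start           stack=[]        buffer=[w1, w2]
    #   SHIFT        -> stack=[w1]      buffer=[w2]
    #   SHIFT        -> stack=[w1, w2]  buffer=[]
    #   RIGHTARC_rel -> stack=[w2]      buffer=[]   (w1 depends on w2, so w1 is popped)
    # The loop ends with only the root on the stack; each transition is recorded
    # together with the stack/buffer state it was decided from.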

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


if __name__ == '__main__':
    filepath = "data/THU/train.conll"
    dataset = ParseDataset(filepath)
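
    # A minimal usage sketch (assumption: batching is handled by the caller).
    # Each item is a dict of variable-length lists, so the default collate_fn
    # cannot stack a batch into tensors; the identity collate_fn below just
    # returns the list of parser-state dicts unchanged.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=8, shuffle=True, collate_fn=lambda batch: batch
    )
    for batch in loader:
        print(len(batch), batch[0]["transitions"][-1])
        break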