import torch

from dev.grammar.THU_preprocess import get_legal_postag, get_legal_relation


class ParseDataset(torch.utils.data.Dataset):
    """Arc-standard oracle dataset built from a CoNLL dependency treebank.

    Each item is a dict holding the parser state ("stack", "buffer") before an
    action, plus the transition history ("transitions") whose last element is
    the gold action to take from that state.
    """

    def __init__(self, dataset_file):
        self.legal_postag = set(get_legal_postag())
        self.legal_relation = set(get_legal_relation())
        self.data = self.get_sentence_data(self.read_sentences(dataset_file))

    def read_sentences(self, dataset_file):
        """Read a CoNLL file into a list of sentences, one dict per token."""
        list_sentences = []
        with open(dataset_file, "r", encoding="utf8") as f:
            data = f.read()
        # Sentences are separated by blank lines; token fields are tab-separated.
        for sent in data.split("\n\n"):
            l_sent = []
            for word_line in sent.split("\n"):
                if word_line.strip() == "":
                    continue
                word_id, form, lemma, cpostag, postag, feats, head, deprel = \
                    word_line.strip().split("\t")[:8]
                word_dict = {"id": word_id, "form": form, "lemma": lemma,
                             "cpostag": cpostag, "postag": postag,
                             "feats": feats, "head": head, "deprel": deprel}
                l_sent.append(word_dict)
            if l_sent:
                list_sentences.append(l_sent)
        return list_sentences

    def get_sentence_data(self, list_sentences):
        """Run the arc-standard oracle over every legal sentence and collect
        (state, gold transition) training examples."""
        list_data = []
        not_legal = 0
        for sent_i, sent in enumerate(list_sentences):
            print(sent_i, len(sent))
            # Skip sentences containing a POS tag or dependency relation that
            # is outside the grammar's legal sets.
            is_legal = True
            for word in sent:
                if word["postag"] not in self.legal_postag:
                    is_legal = False
                    break
                if word["deprel"] not in self.legal_relation:
                    is_legal = False
                    break
            if not is_legal:
                continue

            buffer = list(sent)
            stack = []
            transitions = []
            while len(buffer) > 0 or len(stack) > 1:
                if len(stack) <= 1:
                    if len(buffer) > 0:
                        # Not enough items on the stack to reduce: SHIFT.
                        transitions.append("SHIFT")
                        list_data.append({"stack": stack.copy(),
                                          "buffer": buffer.copy(),
                                          "transitions": transitions.copy()})
                        # Execute the action.
                        stack.append(buffer.pop(0))
                elif stack[-1]["head"] == stack[-2]["id"] and not any(
                        w["head"] == stack[-1]["id"] for w in buffer):
                    # The stack top depends on the item below it and has no
                    # pending dependents left in the buffer, so it can safely
                    # be attached and popped (arc-standard oracle constraint).
                    transitions.append("LEFTARC_%s" % stack[-1]["deprel"])
                    list_data.append({"stack": stack.copy(),
                                      "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    # Execute the action.
                    stack.pop(-1)
                elif stack[-2]["head"] == stack[-1]["id"]:
                    # The second item depends on the stack top: attach and pop it.
                    transitions.append("RIGHTARC_%s" % stack[-2]["deprel"])
                    list_data.append({"stack": stack.copy(),
                                      "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    # Execute the action.
                    stack.pop(-2)
                elif len(buffer) > 0:
                    # No reduction is possible yet: SHIFT the next word.
                    transitions.append("SHIFT")
                    list_data.append({"stack": stack.copy(),
                                      "buffer": buffer.copy(),
                                      "transitions": transitions.copy()})
                    # Execute the action.
                    stack.append(buffer.pop(0))
                else:
                    # The buffer is empty but the remaining stack items cannot
                    # be reduced (e.g. a non-projective tree): give up on this
                    # sentence and count the leftover words.
                    not_legal += len(stack)
                    break
        print(len(list_data), not_legal)
        return list_data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


if __name__ == '__main__':
    filepath = "data/THU/train.conll"
    dataset = ParseDataset(filepath)
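    # Minimal usage sketch: the items are dicts of variable-length lists, so
    # the default DataLoader collate_fn cannot stack them into tensors. The
    # `keep_raw_collate` helper below is a hypothetical name introduced here
    # for illustration; it simply hands the raw parser states to the caller.
    def keep_raw_collate(batch):
        # batch: list of {"stack": [...], "buffer": [...], "transitions": [...]}
        return batch

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=32, shuffle=True, collate_fn=keep_raw_collate)
    print("items:", len(dataset), "batches per epoch:", len(loader))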