123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import torch
- import torch.nn as nn
- import torch.optim as optim
# Parser state (configuration) class.
class DependencyParserState:
    """Parsing configuration for a single sentence.

    Words are referred to by their position in ``sentence``; the buffer,
    stack and head table all hold these integer positions.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        # Positions of the words that have not yet been moved onto the stack.
        self.buffer = list(range(len(sentence)))
        # Stack on which the dependency tree is built.
        self.stack = []
        # heads[i] is the position of word i's head; -1 means "not assigned yet".
        self.heads = [-1] * len(sentence)
        # Transition actions recorded by the caller for this sentence.
        self.transitions = []

    def is_terminal(self):
        """Return True when parsing of this sentence is finished.

        Parsing is finished once the buffer is empty and at most one word is
        left on the stack. ``<= 1`` (instead of ``== 1``) matters only for an
        empty sentence: its stack can never hold a word, so with ``== 1`` it
        would never be considered finished.
        """
        return not self.buffer and len(self.stack) <= 1
# Transition-action inventory and its legality rules.
class TransitionSystem:
    """The three parser transitions and when each one may be applied."""

    def __init__(self):
        # Index in this list is the action id used by the model's output.
        self.actions = ['SHIFT', 'LEFT', 'RIGHT']

    def legal_actions(self, state):
        """List the transitions applicable to ``state``.

        SHIFT needs a non-empty buffer; both arc actions (LEFT and RIGHT) need
        at least two items on the stack. Results are always in the order
        SHIFT, LEFT, RIGHT.
        """
        buffer_has_words = len(state.buffer) > 0
        stack_has_pair = len(state.stack) >= 2
        permitted = (
            ('SHIFT', buffer_has_words),
            ('LEFT', stack_has_pair),
            ('RIGHT', stack_has_pair),
        )
        return [name for name, allowed in permitted if allowed]
import torch.nn.functional as F


class DependencyParser(nn.Module):
    """Score parser transitions from the top two items of the stack.

    Args:
        input_dim: number of rows in the embedding table. The stack holds
            word *positions*, so this bounds the sentence length the model
            can handle (a position >= input_dim raises IndexError).
        hidden_dim: embedding size of each stack item.
        output_dim: number of transition actions scored.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(DependencyParser, self).__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, state):
        """Return a probability distribution over the transition actions.

        Args:
            state: parser configuration; only ``state.stack`` (a list of
                integer word positions) is read.

        Returns:
            1-D float tensor of length ``output_dim``.

        NOTE(review): the script in this file pairs these probabilities with
        nn.CrossEntropyLoss, which applies log-softmax itself, so softmax is
        effectively applied twice; returning raw logits would be more standard.
        """
        if len(state.stack) < 2:
            # Neither arc action can apply with fewer than two stack items, so
            # put all the mass on index 0 ('SHIFT' in TransitionSystem.actions).
            # A float one-hot keeps argmax on a possibly-legal action and
            # matches the dtype/device of the softmax branch below.
            fallback = self.linear.weight.new_zeros(self.linear.out_features)
            fallback[0] = 1.0
            return fallback
        device = self.embedding.weight.device
        head_word = state.stack[-1]
        modifier_word = state.stack[-2]
        head_embed = self.embedding(torch.tensor(head_word, device=device))
        modifier_embed = self.embedding(torch.tensor(modifier_word, device=device))
        combined = torch.cat((head_embed, modifier_embed), dim=0)
        output = self.linear(combined)
        return F.softmax(output, dim=0)
def _apply_transition(state, action):
    """Apply one named transition to ``state`` in place."""
    if action == 'SHIFT':
        state.stack.append(state.buffer.pop(0))
    elif action == 'LEFT':
        # stack[-1] becomes the head of stack[-2], which is then removed.
        state.heads[state.stack[-2]] = state.stack[-1]
        state.stack.pop(-2)
    elif action == 'RIGHT':
        # stack[-2] becomes the head of stack[-1], which is then removed.
        state.heads[state.stack[-1]] = state.stack[-2]
        state.stack.pop()
    else:
        # Without this, an unknown action would leave the state unchanged and
        # the caller's parsing loop would never terminate.
        raise ValueError(f'unknown transition: {action!r}')


def _best_legal_index(action_scores, actions, legal_actions):
    """Return the index of the highest-scoring action that is legal.

    Ties go to the lowest index, matching torch.argmax.
    """
    legal = [i for i, name in enumerate(actions) if name in legal_actions]
    return max(legal, key=lambda i: action_scores[i].item())


def train_model(train_data, model, transition_system, optimizer, criterion, num_epochs):
    """Train ``model`` by greedily parsing each sentence and fitting its choices.

    For every sentence a fresh DependencyParserState is built, and the
    highest-scoring *legal* transition is applied until the state is terminal.
    The per-step scores are stacked into a (num_steps, num_actions) tensor,
    and ``criterion`` is evaluated against the indices of the actions taken.

    NOTE(review): the example ``train_data`` in this file carries only
    (word, tag) pairs and no gold heads, so the targets are the model's own
    decisions (self-training) rather than a supervised oracle.

    Args:
        train_data: sized collection of sentences.
        model: callable mapping a parser state to a 1-D score tensor indexed
            like ``transition_system.actions``.
        transition_system: provides ``actions`` and ``legal_actions(state)``.
        optimizer: torch optimizer over ``model``'s parameters.
        criterion: loss taking (N, C) scores and (N,) integer targets.
        num_epochs: number of passes over ``train_data``.

    Returns:
        List with the average loss per sentence for each epoch.
    """
    actions = transition_system.actions
    n_sentences = len(train_data)
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for sentence in train_data:
            state = DependencyParserState(sentence)
            step_scores = []
            while not state.is_terminal():
                legal_actions = transition_system.legal_actions(state)
                if not legal_actions:
                    # Nothing applicable (e.g. an empty sentence): stop here.
                    break
                action_scores = model(state)
                action_index = _best_legal_index(action_scores, actions, legal_actions)
                _apply_transition(state, actions[action_index])
                state.transitions.append(action_index)
                # .float() lets constant fallback scores stack with softmax outputs.
                step_scores.append(action_scores.float())
            if not step_scores:
                continue
            # One row of scores per executed transition, aligned with the targets.
            scores = torch.stack(step_scores)
            targets = torch.tensor(state.transitions, device=scores.device)
            loss = criterion(scores, targets)
            # If every step used a constant fallback, there is nothing to backprop.
            if loss.requires_grad:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / n_sentences if n_sentences else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}')
        epoch_losses.append(avg_loss)
    return epoch_losses
# Toy corpus: each sentence is a list of (word, part-of-speech tag) pairs.
# No gold head indices are included.
train_data = [
    [('I', 'PRP'), ('love', 'VBP'), ('natural', 'JJ'),
     ('language', 'NN'), ('processing', 'NN')],
    [('Dependency', 'NN'), ('syntax', 'NN'), ('analysis', 'NN'),
     ('is', 'VBZ'), ('important', 'JJ')],
    [('The', 'DT'), ('cat', 'NN'), ('chased', 'VBD'),
     ('the', 'DT'), ('mouse', 'NN')],
]

# Hyperparameters.
# input_dim sizes the embedding table. DependencyParser embeds word
# *positions*, so it must exceed the longest sentence's last index.
input_dim = 10
hidden_dim = 5       # embedding size per stack item
output_dim = 3       # number of transitions: SHIFT, LEFT, RIGHT
num_epochs = 10
learning_rate = 0.01

# Model, optimizer and loss function.
model = DependencyParser(input_dim, hidden_dim, output_dim)
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Run training.
transition_system = TransitionSystem()
train_model(
    train_data=train_data,
    model=model,
    transition_system=transition_system,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=num_epochs,
)
|