# dependencyParsing.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
  4. # 定义状态类
  5. class DependencyParserState:
  6. def __init__(self, sentence):
  7. self.sentence = sentence
  8. self.buffer = list(range(len(sentence))) # 待分析的词索引列表
  9. self.stack = [] # 用于构建依存树的栈
  10. self.heads = [-1] * len(sentence) # 词的依存头词索引列表
  11. self.transitions = [] # 转移动作序列
  12. def is_terminal(self):
  13. return len(self.buffer) == 0 and len(self.stack) == 1
  14. # 定义转移动作类
  15. class TransitionSystem:
  16. def __init__(self):
  17. self.actions = ['SHIFT', 'LEFT', 'RIGHT']
  18. def legal_actions(self, state):
  19. actions = []
  20. if len(state.buffer) > 0:
  21. actions.append('SHIFT')
  22. if len(state.stack) >= 2:
  23. actions.append('LEFT')
  24. if len(state.stack) >= 2:
  25. actions.append('RIGHT')
  26. return actions
  27. import torch.nn.functional as F
  28. class DependencyParser(nn.Module):
  29. def __init__(self, input_dim, hidden_dim, output_dim):
  30. super(DependencyParser, self).__init__()
  31. self.embedding = nn.Embedding(input_dim, hidden_dim)
  32. self.linear = nn.Linear(hidden_dim * 2, output_dim)
  33. def forward(self, state):
  34. if len(state.stack) < 2:
  35. return torch.tensor([0, 0, 1]) # 不执行转移
  36. head_word = state.stack[-1]
  37. modifier_word = state.stack[-2]
  38. head_embed = self.embedding(torch.tensor(head_word))
  39. modifier_embed = self.embedding(torch.tensor(modifier_word))
  40. combined = torch.cat((head_embed, modifier_embed), dim=0)
  41. output = self.linear(combined)
  42. return F.softmax(output, dim=0)
  43. def train_model(train_data, model, transition_system, optimizer, criterion, num_epochs):
  44. for epoch in range(num_epochs):
  45. total_loss = 0.0
  46. for sentence in train_data:
  47. state = DependencyParserState(sentence)
  48. while not state.is_terminal():
  49. legal_actions = transition_system.legal_actions(state)
  50. action_scores = model(state)
  51. action_index = torch.argmax(action_scores).item()
  52. action = transition_system.actions[action_index]
  53. print(action,",",legal_actions)
  54. if action in legal_actions:
  55. if action == 'SHIFT':
  56. state.stack.append(state.buffer.pop(0))
  57. elif action == 'LEFT':
  58. state.heads[state.stack[-2]] = state.stack[-1]
  59. state.stack.pop(-2)
  60. elif action == 'RIGHT':
  61. state.heads[state.stack[-1]] = state.stack[-2]
  62. state.stack.pop()
  63. state.transitions.append(action_index)
  64. else:
  65. state.transitions.append(action_index)
  66. print("====")
  67. break
  68. if len(state.transitions)>0:
  69. optimizer.zero_grad()
  70. print(action_scores.view(1, -1))
  71. print("===")
  72. print(torch.tensor(state.transitions))
  73. loss = criterion(action_scores.view(1, -1), torch.tensor(state.transitions))
  74. loss.backward()
  75. optimizer.step()
  76. total_loss += loss.item()
  77. print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_data)}')
  78. # 示例数据
  79. train_data = [[('I', 'PRP'), ('love', 'VBP'), ('natural', 'JJ'), ('language', 'NN'), ('processing', 'NN')],
  80. [('Dependency', 'NN'), ('syntax', 'NN'), ('analysis', 'NN'), ('is', 'VBZ'), ('important', 'JJ')],
  81. [('The', 'DT'), ('cat', 'NN'), ('chased', 'VBD'), ('the', 'DT'), ('mouse', 'NN')]]
  82. # 超参数
  83. input_dim = 10 # 输入维度(词汇量大小)
  84. hidden_dim = 5 # 隐藏层维度
  85. output_dim = 3 # 输出维度(转移动作数量)
  86. num_epochs = 10
  87. learning_rate = 0.01
  88. # 初始化模型、优化器和损失函数
  89. model = DependencyParser(input_dim, hidden_dim, output_dim)
  90. optimizer = optim.SGD(model.parameters(), lr=learning_rate)
  91. criterion = nn.CrossEntropyLoss()
  92. # 训练模型
  93. transition_system = TransitionSystem()
  94. train_model(train_data, model, transition_system, optimizer, criterion, num_epochs)