function.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from __future__ import absolute_import
  2. import time
  3. import lib.utils.utils as utils
  4. import torch
  5. class AverageMeter(object):
  6. """Computes and stores the average and current value"""
  7. def __init__(self):
  8. self.val = 0
  9. self.avg = 0
  10. self.sum = 0
  11. self.count = 0
  12. self.reset()
  13. def reset(self):
  14. self.val = 0
  15. self.avg = 0
  16. self.sum = 0
  17. self.count = 0
  18. def update(self, val, n=1):
  19. self.val = val
  20. self.sum += val * n
  21. self.count += n
  22. self.avg = self.sum / self.count
  23. def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch, writer_dict=None, output_dict=None):
  24. batch_time = AverageMeter()
  25. data_time = AverageMeter()
  26. losses = AverageMeter()
  27. model.train()
  28. end = time.time()
  29. for i, (inp, idx) in enumerate(train_loader):
  30. # measure data time
  31. data_time.update(time.time() - end)
  32. labels = utils.get_batch_label(dataset, idx)
  33. inp = inp.to(device)
  34. # print("inp",inp.size())
  35. # 中间层输出
  36. # outputs = []
  37. # def hook(module, input, output):
  38. # outputs.append(output)
  39. # handles = []
  40. # 注册钩子函数
  41. # for layer in model.children():
  42. # handle = layer.register_forward_hook(hook)
  43. # handles.append(handle)
  44. # inference
  45. preds = model(inp)
  46. # 打印中间层的输出
  47. # for output in outputs:
  48. # print(output.,output[0])
  49. # pass
  50. # for output,named_children in zip(outputs,model.named_children()):
  51. # print(named_children[0])
  52. # print(output.size())
  53. # pass
  54. # print("===============================================")
  55. # 移除钩子函数
  56. # for handle in handles:
  57. # handle.remove()
  58. # compute loss
  59. batch_size = inp.size(0)
  60. text, length = converter.encode(labels) # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
  61. preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
  62. # print('===============')
  63. # print(batch_size)
  64. # print(inp.size())
  65. # print('text',text)
  66. # text = text.reshape(16,10)
  67. # print('text-size',text.size())
  68. # print(length)
  69. # print(preds_size)
  70. # print('preds',preds)
  71. # print('preds-size',preds.size())
  72. # print(preds.size())
  73. loss = criterion(preds, text, preds_size, length)
  74. optimizer.zero_grad()
  75. loss.backward()
  76. optimizer.step()
  77. losses.update(loss.item(), inp.size(0))
  78. batch_time.update(time.time()-end)
  79. if i % config.PRINT_FREQ == 0:
  80. msg = 'Epoch: [{0}][{1}/{2}]\t' \
  81. 'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
  82. 'Speed {speed:.1f} samples/s\t' \
  83. 'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
  84. 'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
  85. epoch, i, len(train_loader), batch_time=batch_time,
  86. speed=inp.size(0)/batch_time.val,
  87. data_time=data_time, loss=losses)
  88. print(msg)
  89. if writer_dict:
  90. writer = writer_dict['writer']
  91. global_steps = writer_dict['train_global_steps']
  92. writer.add_scalar('train_loss', losses.avg, global_steps)
  93. writer_dict['train_global_steps'] = global_steps + 1
  94. end = time.time()
  95. def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):
  96. losses = AverageMeter()
  97. model.eval()
  98. n_correct = 0
  99. with torch.no_grad():
  100. for i, (inp, idx) in enumerate(val_loader):
  101. labels = utils.get_batch_label(dataset, idx)
  102. inp = inp.to(device)
  103. # inference
  104. preds = model(inp)
  105. # compute loss
  106. batch_size = inp.size(0)
  107. text, length = converter.encode(labels)
  108. preds_size = torch.IntTensor([preds.size(0)] * batch_size)
  109. loss = criterion(preds, text, preds_size, length)
  110. losses.update(loss.item(), inp.size(0))
  111. _, preds = preds.max(2)
  112. preds = preds.transpose(1, 0).contiguous().view(-1)
  113. sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
  114. for pred, target in zip(sim_preds, labels):
  115. if pred == target:
  116. n_correct += 1
  117. if (i + 1) % config.PRINT_FREQ == 0:
  118. print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))
  119. if i == config.TEST.NUM_TEST_BATCH:
  120. break
  121. raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
  122. for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
  123. print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
  124. num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
  125. if num_test_sample > len(dataset):
  126. num_test_sample = len(dataset)
  127. print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
  128. accuracy = n_correct / float(num_test_sample)
  129. print('Test loss: {:.4f}, accuray: {:.4f}'.format(losses.avg, accuracy))
  130. if writer_dict:
  131. writer = writer_dict['writer']
  132. global_steps = writer_dict['valid_global_steps']
  133. writer.add_scalar('valid_acc', accuracy, global_steps)
  134. writer_dict['valid_global_steps'] = global_steps + 1
  135. return accuracy