function.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from __future__ import absolute_import
  2. import time
  3. import lib.utils.utils as utils
  4. import torch
  5. class AverageMeter(object):
  6. """Computes and stores the average and current value"""
  7. def __init__(self):
  8. self.val = 0
  9. self.avg = 0
  10. self.sum = 0
  11. self.count = 0
  12. self.reset()
  13. def reset(self):
  14. self.val = 0
  15. self.avg = 0
  16. self.sum = 0
  17. self.count = 0
  18. def update(self, val, n=1):
  19. self.val = val
  20. self.sum += val * n
  21. self.count += n
  22. self.avg = self.sum / self.count
  23. def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch, writer_dict=None, output_dict=None):
  24. batch_time = AverageMeter()
  25. data_time = AverageMeter()
  26. losses = AverageMeter()
  27. model.train()
  28. end = time.time()
  29. for i, (inp, idx) in enumerate(train_loader):
  30. # measure data time
  31. data_time.update(time.time() - end)
  32. labels = utils.get_batch_label(dataset, idx)
  33. inp = inp.to(device)
  34. # inference
  35. preds = model(inp).cpu()
  36. # compute loss
  37. batch_size = inp.size(0)
  38. text, length = converter.encode(labels) # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
  39. preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
  40. loss = criterion(preds, text, preds_size, length)
  41. optimizer.zero_grad()
  42. loss.backward()
  43. optimizer.step()
  44. losses.update(loss.item(), inp.size(0))
  45. batch_time.update(time.time()-end)
  46. if i % config.PRINT_FREQ == 0:
  47. msg = 'Epoch: [{0}][{1}/{2}]\t' \
  48. 'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
  49. 'Speed {speed:.1f} samples/s\t' \
  50. 'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
  51. 'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
  52. epoch, i, len(train_loader), batch_time=batch_time,
  53. speed=inp.size(0)/batch_time.val,
  54. data_time=data_time, loss=losses)
  55. print(msg)
  56. if writer_dict:
  57. writer = writer_dict['writer']
  58. global_steps = writer_dict['train_global_steps']
  59. writer.add_scalar('train_loss', losses.avg, global_steps)
  60. writer_dict['train_global_steps'] = global_steps + 1
  61. end = time.time()
  62. def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):
  63. losses = AverageMeter()
  64. model.eval()
  65. n_correct = 0
  66. with torch.no_grad():
  67. for i, (inp, idx) in enumerate(val_loader):
  68. labels = utils.get_batch_label(dataset, idx)
  69. inp = inp.to(device)
  70. # inference
  71. preds = model(inp).cpu()
  72. # compute loss
  73. batch_size = inp.size(0)
  74. text, length = converter.encode(labels)
  75. preds_size = torch.IntTensor([preds.size(0)] * batch_size)
  76. loss = criterion(preds, text, preds_size, length)
  77. losses.update(loss.item(), inp.size(0))
  78. _, preds = preds.max(2)
  79. preds = preds.transpose(1, 0).contiguous().view(-1)
  80. sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
  81. for pred, target in zip(sim_preds, labels):
  82. if pred == target:
  83. n_correct += 1
  84. if (i + 1) % config.PRINT_FREQ == 0:
  85. print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))
  86. if i == config.TEST.NUM_TEST:
  87. break
  88. raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
  89. for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
  90. print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
  91. num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
  92. if num_test_sample > len(dataset):
  93. num_test_sample = len(dataset)
  94. print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
  95. accuracy = n_correct / float(num_test_sample)
  96. print('Test loss: {:.4f}, accuray: {:.4f}'.format(losses.avg, accuracy))
  97. if writer_dict:
  98. writer = writer_dict['writer']
  99. global_steps = writer_dict['valid_global_steps']
  100. writer.add_scalar('valid_acc', accuracy, global_steps)
  101. writer_dict['valid_global_steps'] = global_steps + 1
  102. return accuracy