# init.py (2.6 KB)
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.init as init
  4. def weight_init(m):
  5. """
  6. Usage:
  7. model = Model()
  8. model.apply(weight_init)
  9. """
  10. if isinstance(m, nn.Conv1d):
  11. init.kaiming_normal_(m.weight.data)
  12. if m.bias is not None:
  13. init.kaiming_normal_(m.bias.data)
  14. elif isinstance(m, nn.Conv2d):
  15. init.kaiming_normal_(m.weight.data)
  16. if m.bias is not None:
  17. init.normal_(m.bias.data)
  18. elif isinstance(m, nn.Conv3d):
  19. init.kaiming_normal_(m.weight.data)
  20. if m.bias is not None:
  21. init.kaiming_normal_(m.bias.data)
  22. elif isinstance(m, nn.ConvTranspose1d):
  23. init.kaiming_normal_(m.weight.data)
  24. if m.bias is not None:
  25. init.kaiming_normal_(m.bias.data)
  26. elif isinstance(m, nn.ConvTranspose2d):
  27. init.xavier_uniform_(m.weight.data)
  28. if m.bias is not None:
  29. init.normal_(m.bias.data)
  30. elif isinstance(m, nn.ConvTranspose3d):
  31. init.xavier_normal_(m.weight.data)
  32. if m.bias is not None:
  33. init.normal_(m.bias.data)
  34. elif isinstance(m, nn.BatchNorm1d):
  35. init.normal_(m.weight.data, mean=1, std=0.02)
  36. init.constant_(m.bias.data, 0)
  37. elif isinstance(m, nn.BatchNorm2d):
  38. init.constant_(m.weight.data,1)
  39. # init.normal_(m.weight.data, mean=1, std=0.02)
  40. init.constant_(m.bias.data, 0)
  41. elif isinstance(m, nn.BatchNorm3d):
  42. init.normal_(m.weight.data, mean=1, std=0.02)
  43. init.constant_(m.bias.data, 0)
  44. elif isinstance(m, nn.Linear):
  45. # init.xavier_uniform_(m.weight.data)
  46. # if m.bias is not None:
  47. # init.normal_(m.bias.data)
  48. init.normal_(m.weight.data,0,0.01)
  49. if m.bias is not None:
  50. init.constant_(m.bias.data,0)
  51. elif isinstance(m, nn.LSTM):
  52. for param in m.parameters():
  53. if len(param.shape) >= 2:
  54. init.orthogonal_(param.data)
  55. else:
  56. init.normal_(param.data)
  57. elif isinstance(m, nn.LSTMCell):
  58. for param in m.parameters():
  59. if len(param.shape) >= 2:
  60. init.orthogonal_(param.data)
  61. else:
  62. init.xavier_uniform_(param.data)
  63. elif isinstance(m, nn.GRU):
  64. for param in m.parameters():
  65. if len(param.shape) >= 2:
  66. init.orthogonal_(param.data)
  67. else:
  68. init.xavier_uniform_(param.data)
  69. elif isinstance(m, nn.GRUCell):
  70. for param in m.parameters():
  71. if len(param.shape) >= 2:
  72. init.orthogonal_(param.data)
  73. else:
  74. init.xavier_uniform_(param.data)