123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # -*- coding: utf-8 -*-
- import torch.nn as nn
- import torch.nn.init as init
def weight_init(m):
    """Initialize the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply`` so that every submodule of a
    model is visited::

        model = Model()
        model.apply(weight_init)

    Modules whose type is not listed below are left untouched.

    Scheme per module type:

    * Conv1d / Conv2d / Conv3d / ConvTranspose1d: Kaiming-normal weight,
      standard-normal bias.
    * ConvTranspose2d: Xavier-uniform weight, standard-normal bias.
    * ConvTranspose3d: Xavier-normal weight, standard-normal bias.
    * BatchNorm1d / BatchNorm3d: weight ~ N(1, 0.02), bias = 0.
    * BatchNorm2d: weight = 1, bias = 0.
    * Linear: weight ~ N(0, 0.01), bias = 0.
    * LSTM / LSTMCell / GRU / GRUCell: every parameter with two or more
      dimensions is initialized orthogonally; every 1-D parameter (the
      biases) is drawn from a standard normal.

    Args:
        m: The module to initialize. Its parameters are modified in place.
    """
    # Kaiming/Xavier initializers need a fan-in/fan-out and raise ValueError
    # on 1-D tensors, so biases below always use plain normal/constant fills.
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.normal_(m.bias)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.normal_(m.bias)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight)
        if m.bias is not None:
            init.normal_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # Batch norms built with affine=False have weight/bias set to None.
        if m.weight is not None:
            if isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
            else:
                init.normal_(m.weight, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if param.dim() >= 2:
                init.orthogonal_(param)
            else:
                # 1-D bias vectors: orthogonal/xavier are undefined here.
                init.normal_(param)
|