# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.init as init


def _init_bias_normal(m):
    """Draw the bias of ``m`` from a standard normal, if ``m`` has a bias."""
    if m.bias is not None:
        # Fan-based schemes (kaiming/xavier) raise ValueError on 1-D tensors,
        # so biases always use a plain normal draw.
        init.normal_(m.bias)


def _init_batchnorm(m, *, random_scale):
    """Initialise a batch-norm layer's affine parameters.

    The scale is drawn from N(1, 0.02) when ``random_scale`` is true and set
    to exactly 1 otherwise; the shift is set to 0. Layers built with
    ``affine=False`` have no scale/shift and are left untouched.
    """
    if m.weight is not None:
        if random_scale:
            init.normal_(m.weight, mean=1, std=0.02)
        else:
            init.constant_(m.weight, 1)
    if m.bias is not None:
        init.constant_(m.bias, 0)


def _init_recurrent(m):
    """Initialise every parameter of a recurrent layer or cell.

    Parameters with two or more dimensions (weight matrices) are made
    orthogonal; 1-D parameters (biases) are drawn from a standard normal.
    """
    for param in m.parameters():
        if param.dim() >= 2:
            init.orthogonal_(param)
        else:
            init.normal_(param)


def weight_init(m):
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply`` so that it visits every
    submodule of a model.

    Usage:
        model = Model()
        model.apply(weight_init)

    Scheme per module type:
        * Conv1d / Conv2d / Conv3d / ConvTranspose1d: Kaiming-normal weight.
        * ConvTranspose2d: Xavier-uniform weight.
        * ConvTranspose3d: Xavier-normal weight.
        * All convolutions: standard-normal bias (when a bias exists).
        * BatchNorm1d / BatchNorm3d: weight ~ N(1, 0.02), bias 0.
        * BatchNorm2d: weight 1, bias 0.
        * Linear: weight ~ N(0, 0.01), bias 0.
        * LSTM / LSTMCell / GRU / GRUCell: orthogonal weight matrices,
          standard-normal biases.

    Modules of any other type are left unchanged.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d)):
        init.kaiming_normal_(m.weight)
        _init_bias_normal(m)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_uniform_(m.weight)
        _init_bias_normal(m)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight)
        _init_bias_normal(m)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm3d)):
        _init_batchnorm(m, random_scale=True)
    elif isinstance(m, nn.BatchNorm2d):
        _init_batchnorm(m, random_scale=False)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        _init_recurrent(m)