# distillation_loss.py

  1. import torch
  2. import cv2
  3. import numpy as np
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .DBLoss import DBLoss
  7. def _sum_loss(loss_dict):
  8. if "loss" in loss_dict.keys():
  9. return loss_dict
  10. else:
  11. loss_dict["loss"] = 0.
  12. for k, value in loss_dict.items():
  13. if k == "loss":
  14. continue
  15. else:
  16. loss_dict["loss"] += value
  17. return loss_dict
  18. class KLJSLoss(object):
  19. def __init__(self, mode='kl'):
  20. assert mode in ['kl', 'js', 'KL', 'JS'
  21. ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
  22. self.mode = mode
  23. def __call__(self, p1, p2, reduction="mean"):
  24. loss = torch.mul(p2, torch.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
  25. if self.mode.lower() == "js":
  26. loss += torch.mul(
  27. p1, torch.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
  28. loss *= 0.5
  29. if reduction == "mean":
  30. loss = torch.mean(loss)
  31. elif reduction == "none" or reduction is None:
  32. return loss
  33. else:
  34. loss = torch.sum(loss)
  35. return loss
  36. class DMLLoss(nn.Module):
  37. """
  38. DMLLoss
  39. """
  40. def __init__(self, act=None, use_log=False):
  41. super().__init__()
  42. if act is not None:
  43. assert act in ["softmax", "sigmoid"]
  44. if act == "softmax":
  45. self.act = nn.Softmax(axis=-1)
  46. elif act == "sigmoid":
  47. self.act = nn.Sigmoid()
  48. else:
  49. self.act = None
  50. self.use_log = use_log
  51. self.jskl_loss = KLJSLoss(mode="js")
  52. def forward(self, out1, out2):
  53. if self.act is not None:
  54. out1 = self.act(out1)
  55. out2 = self.act(out2)
  56. if self.use_log:
  57. # for recognition distillation, log is needed for feature map
  58. log_out1 = torch.log(out1)
  59. log_out2 = torch.log(out2)
  60. loss = (F.kl_div(
  61. log_out1, out2, reduction='batchmean') + F.kl_div(
  62. log_out2, out1, reduction='batchmean')) / 2.0
  63. else:
  64. # for detection distillation log is not needed
  65. loss = self.jskl_loss(out1, out2)
  66. return loss
  67. class DistanceLoss(nn.Module):
  68. """
  69. DistanceLoss:
  70. mode: loss mode
  71. """
  72. def __init__(self, mode="l2", **kargs):
  73. super().__init__()
  74. assert mode in ["l1", "l2", "smooth_l1"]
  75. if mode == "l1":
  76. self.loss_func = nn.L1Loss(**kargs)
  77. elif mode == "l2":
  78. self.loss_func = nn.MSELoss(**kargs)
  79. elif mode == "smooth_l1":
  80. self.loss_func = nn.SmoothL1Loss(**kargs)
  81. def forward(self, x, y):
  82. return self.loss_func(x, y)
  83. class DistillationDMLLoss(DMLLoss):
  84. """
  85. """
  86. def __init__(self,
  87. model_name_pairs=[],
  88. act=None,
  89. use_log=False,
  90. key=None,
  91. maps_name=None,
  92. name="dml"):
  93. super().__init__(act=act, use_log=use_log)
  94. assert isinstance(model_name_pairs, list)
  95. self.key = key
  96. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  97. self.name = name
  98. self.maps_name = self._check_maps_name(maps_name)
  99. def _check_model_name_pairs(self, model_name_pairs):
  100. if not isinstance(model_name_pairs, list):
  101. return []
  102. elif isinstance(model_name_pairs[0], list) and isinstance(
  103. model_name_pairs[0][0], str):
  104. return model_name_pairs
  105. else:
  106. return [model_name_pairs]
  107. def _check_maps_name(self, maps_name):
  108. if maps_name is None:
  109. return None
  110. elif type(maps_name) == str:
  111. return [maps_name]
  112. elif type(maps_name) == list:
  113. return [maps_name]
  114. else:
  115. return None
  116. def _slice_out(self, outs):
  117. new_outs = {}
  118. for k in self.maps_name:
  119. if k == "thrink_maps":
  120. new_outs[k] = outs[:, 0, :, :]
  121. elif k == "threshold_maps":
  122. new_outs[k] = outs[:, 1, :, :]
  123. elif k == "binary_maps":
  124. new_outs[k] = outs[:, 2, :, :]
  125. else:
  126. continue
  127. return new_outs
  128. def forward(self, predicts, batch):
  129. loss_dict = dict()
  130. for idx, pair in enumerate(self.model_name_pairs):
  131. out1 = predicts[pair[0]]
  132. out2 = predicts[pair[1]]
  133. if self.maps_name is None:
  134. loss = super().forward(out1, out2)
  135. if isinstance(loss, dict):
  136. for key in loss:
  137. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],idx)] = loss[key]
  138. else:
  139. loss_dict["{}_{}".format(self.name, idx)] = loss
  140. else:
  141. outs1 = self._slice_out(out1)
  142. outs2 = self._slice_out(out2)
  143. for _c, k in enumerate(outs1.keys()):
  144. loss = super().forward(outs1[k], outs2[k])
  145. if isinstance(loss, dict):
  146. for key in loss:
  147. loss_dict["{}_{}_{}_{}_{}".format(key, pair[0], pair[1], self.maps_name[_c], idx)] = loss[key]
  148. else:
  149. loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], idx)] = loss
  150. loss_dict = _sum_loss(loss_dict)
  151. return loss_dict
  152. class DistillationDBLoss(DBLoss):
  153. def __init__(self,
  154. model_name_list=[],
  155. balance_loss=True,
  156. main_loss_type='DiceLoss',
  157. alpha=5,
  158. beta=10,
  159. ohem_ratio=3,
  160. eps=1e-6,
  161. name="db",
  162. **kwargs):
  163. super().__init__()
  164. self.model_name_list = model_name_list
  165. self.name = name
  166. self.key = None
  167. def forward(self, predicts, batch):
  168. loss_dict = {}
  169. for idx, model_name in enumerate(self.model_name_list):
  170. out = predicts[model_name]
  171. loss = super().forward(out, batch)
  172. if isinstance(loss, dict):
  173. for key in loss.keys():
  174. if key == "loss":
  175. continue
  176. name = "{}_{}_{}".format(self.name, model_name, key)
  177. loss_dict[name] = loss[key]
  178. else:
  179. loss_dict["{}_{}".format(self.name, model_name)] = loss
  180. loss_dict = _sum_loss(loss_dict)
  181. return loss_dict
  182. class DistillationDilaDBLoss(DBLoss):
  183. def __init__(self,
  184. model_name_pairs=[],
  185. key=None,
  186. balance_loss=True,
  187. main_loss_type='DiceLoss',
  188. alpha=5,
  189. beta=10,
  190. ohem_ratio=3,
  191. eps=1e-6,
  192. name="dila_dbloss"):
  193. super().__init__()
  194. self.model_name_pairs = model_name_pairs
  195. self.name = name
  196. self.key = key
  197. def forward(self, predicts, batch):
  198. loss_dict = dict()
  199. for idx, pair in enumerate(self.model_name_pairs):
  200. # stu_outs = predicts[pair[0]]
  201. # tch_outs = predicts[pair[1]]
  202. # if self.key is not None:
  203. # stu_preds = stu_outs[self.key]
  204. # tch_preds = tch_outs[self.key]
  205. stu_preds = predicts[pair[0]]
  206. tch_preds = predicts[pair[1]]
  207. stu_shrink_maps = stu_preds[:, 0, :, :]
  208. stu_binary_maps = stu_preds[:, 2, :, :]
  209. # dilation to teacher prediction
  210. dilation_w = np.array([[1, 1], [1, 1]])
  211. th_shrink_maps = tch_preds[:, 0, :, :]
  212. th_shrink_maps = th_shrink_maps.cpu().detach().numpy() > 0.3 # thresh = 0.3
  213. dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
  214. for i in range(th_shrink_maps.shape[0]):
  215. dilate_maps[i] = cv2.dilate(
  216. th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
  217. th_shrink_maps = torch.tensor(dilate_maps).cuda()
  218. label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch['threshold_map'], batch['threshold_mask'], batch['shrink_map'], batch['shrink_mask']
  219. # calculate the shrink map loss
  220. bce_loss = self.alpha * self.bce_loss(
  221. stu_shrink_maps, th_shrink_maps, label_shrink_mask)
  222. loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
  223. label_shrink_mask)
  224. # k = f"{self.name}_{pair[0]}_{pair[1]}"
  225. k = "{}_{}_{}".format(self.name, pair[0], pair[1])
  226. loss_dict[k] = bce_loss + loss_binary_maps
  227. loss_dict = _sum_loss(loss_dict)
  228. return loss_dict
  229. class DistillationDistanceLoss(DistanceLoss):
  230. """
  231. """
  232. def __init__(self,
  233. mode="l2",
  234. model_name_pairs=[],
  235. key=None,
  236. name="loss_distance",
  237. **kargs):
  238. super().__init__(mode=mode, **kargs)
  239. assert isinstance(model_name_pairs, list)
  240. self.key = key
  241. self.model_name_pairs = model_name_pairs
  242. self.name = name + "_l2"
  243. def forward(self, predicts, batch):
  244. loss_dict = dict()
  245. for idx, pair in enumerate(self.model_name_pairs):
  246. out1 = predicts[pair[0]]
  247. out2 = predicts[pair[1]]
  248. # if self.key is not None:
  249. # out1 = out1[self.key]
  250. # out2 = out2[self.key]
  251. loss = super().forward(out1, out2)
  252. if isinstance(loss, dict):
  253. for key in loss:
  254. loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
  255. key]
  256. else:
  257. loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
  258. idx)] = loss
  259. return loss_dict