loss.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. import os
  2. import sys
  3. import tensorflow as tf
  4. import keras.backend as K
  5. from keras import Input
  6. import numpy as np
  7. np.set_printoptions(threshold=np.inf)
  8. from keras.engine.base_layer import Layer
  9. from tensorflow.python.ops.control_flow_ops import while_loop
  10. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  12. from click_captcha.utils import box_iou
  13. from click_captcha.post_process import yolo_head
  14. def contrastive_loss(y_true, y_pred):
  15. """Contrastive loss from Hadsell-et-al.'06
  16. http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
  17. """
  18. margin = 1
  19. square_pred = K.square(y_pred)
  20. margin_square = K.square(K.maximum(margin - y_pred, 0))
  21. return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
  22. def focal_loss(gamma=3., alpha=.5, only_tf=True):
  23. def focal_loss_fixed(y_true, y_pred):
  24. pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
  25. pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
  26. if only_tf:
  27. return - tf.reduce_sum(alpha * tf.pow(1. - pt_1, gamma) * tf.math.log(1e-07 + pt_1)) \
  28. - tf.reduce_sum((1 - alpha) * tf.pow(pt_0, gamma) * tf.math.log(1. - pt_0 + 1e-07))
  29. else:
  30. return - K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1)) \
  31. - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
  32. return focal_loss_fixed
  33. def l1_loss():
  34. def mae(y_true, y_pred):
  35. return tf.reduce_mean(tf.abs(y_pred-y_true)) * 100
  36. return mae
  37. def l2_loss():
  38. def mse(y_true, y_pred):
  39. return tf.reduce_mean(tf.square(y_true - y_pred))
  40. return mse
  41. def l2_focal_loss(threshold=0.2, ratio=1000, reverse=False):
  42. def mse(y_true, y_pred):
  43. if reverse:
  44. y_minus = tf.where(tf.abs(y_pred-y_true) <= threshold, 1/ratio*tf.abs(y_pred-y_true), 0.1*tf.abs(y_pred-y_true))
  45. else:
  46. y_minus = tf.where(tf.abs(y_pred-y_true) <= threshold, tf.abs(y_pred-y_true), ratio*tf.abs(y_pred-y_true))
  47. return tf.reduce_mean(tf.square(y_minus))
  48. return mse
  49. def l1_focal_loss(threshold=0.2):
  50. def mae(y_true, y_pred):
  51. y_minus = tf.where(tf.abs(y_pred-y_true) <= threshold, 0., tf.abs(y_pred-y_true))
  52. return tf.reduce_sum(tf.abs(y_minus))
  53. return mae
  54. def l3_loss():
  55. def l3_loss_fixed(y_true, y_pred):
  56. return tf.reduce_mean(tf.abs(tf.pow(y_pred-y_true, 3)))
  57. return l3_loss_fixed
def yolo_loss(args, anchors, num_classes, ignore_thresh=.5, print_loss=False):
    """Return yolo_loss tensor

    Parameters
    ----------
    args: list; the first ``len(anchors)//3`` entries are used as
        yolo_outputs and the remaining entries as y_true
    yolo_outputs: list of tensor, the output of yolo_body or tiny_yolo_body
    y_true: list of array, the output of preprocess_true_boxes
    anchors: array, shape=(N, 2), wh
    num_classes: integer
    ignore_thresh: float, the iou threshold whether to ignore object confidence loss
    print_loss: not read by any active statement (the print call below is
        commented out)

    Returns
    -------
    loss: tensor accumulated over all layers as
        ``xy_loss * 10 + wh_loss * 10 + confidence_loss``

    """
    from keras import backend as K
    # default setting
    num_layers = len(anchors)//3
    yolo_outputs = args[:num_layers]
    y_true = args[num_layers:]
    # Which anchors belong to which output layer (3-layer vs. 2-layer model).
    anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if num_layers == 3 else [[3, 4, 5], [1, 2, 3]]
    # Input size is derived from the first output's spatial size times 32.
    input_shape = K.cast(K.shape(yolo_outputs[0])[1:3] * 32, K.dtype(y_true[0]))
    grid_shapes = [K.cast(K.shape(yolo_outputs[l])[1:3], K.dtype(y_true[0])) for l in range(num_layers)]
    loss = 0
    # batch size, tensor
    m = K.shape(yolo_outputs[0])[0]
    mf = K.cast(m, K.dtype(yolo_outputs[0]))
    for l in range(num_layers):
        # Channel 4 of y_true is the objectness flag, channels 5+ the classes.
        object_mask = y_true[l][..., 4:5]
        true_class_probs = y_true[l][..., 5:]
        grid, raw_pred, pred_xy, pred_wh = yolo_head(yolo_outputs[l],
            anchors[anchor_mask[l]], num_classes, input_shape, calc_loss=True)
        pred_box = K.concatenate([pred_xy, pred_wh])
        # Darknet raw box to calculate loss.
        raw_true_xy = y_true[l][..., :2]*grid_shapes[l][::-1] - grid
        raw_true_wh = K.log(y_true[l][..., 2:4] / anchors[anchor_mask[l]] * input_shape[::-1])
        # avoid log(0)=-inf
        raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh))
        # Weight is 2 - w*h of the true box, so boxes with smaller w*h weigh more.
        box_loss_scale = 2 - y_true[l][..., 2:3]*y_true[l][..., 3:4]
        # Find ignore mask, iterate over each of batch.
        ignore_mask = tf.TensorArray(K.dtype(y_true[0]), size=1, dynamic_size=True)
        object_mask_bool = K.cast(object_mask, 'bool')
        def loop_body(b, ignore_mask):
            # For batch item b: 1 where the best IoU against any true box is
            # below ignore_thresh, else 0.
            true_box = tf.boolean_mask(y_true[l][b, ..., 0:4], object_mask_bool[b,...,0])
            iou = box_iou(pred_box[b], true_box)
            best_iou = K.max(iou, axis=-1)
            ignore_mask = ignore_mask.write(b, K.cast(best_iou<ignore_thresh, K.dtype(true_box)))
            return b+1, ignore_mask
        _, ignore_mask = while_loop(lambda b, *args: b < m, loop_body, [0, ignore_mask])
        ignore_mask = ignore_mask.stack()
        ignore_mask = K.expand_dims(ignore_mask, -1)
        # K.binary_crossentropy is helpful to avoid exp overflow.
        xy_loss = object_mask * box_loss_scale * K.binary_crossentropy(raw_true_xy, raw_pred[..., 0:2], from_logits=True)
        wh_loss = object_mask * box_loss_scale * 0.5 * K.square(raw_true_wh-raw_pred[..., 2:4])
        # The no-object part of the confidence loss is gated by ignore_mask.
        confidence_loss = object_mask * K.binary_crossentropy(object_mask, raw_pred[..., 4:5], from_logits=True) + \
            (1-object_mask) * K.binary_crossentropy(object_mask, raw_pred[..., 4:5], from_logits=True) * ignore_mask
        class_loss = object_mask * K.binary_crossentropy(true_class_probs, raw_pred[..., 5:], from_logits=True)
        # Each component is summed over all elements and divided by batch size.
        xy_loss = K.sum(xy_loss) / mf
        wh_loss = K.sum(wh_loss) / mf
        confidence_loss = K.sum(confidence_loss) / mf
        class_loss = K.sum(class_loss) / mf
        # NOTE(review): class_loss is computed but not added to the total, and
        # the box terms are weighted by 10 — confirm this is intentional.
        loss += xy_loss * 10 + wh_loss * 10 + confidence_loss
        # if print_loss:
        #     loss = tf.Print(loss, [loss, xy_loss, wh_loss, confidence_loss, class_loss, K.sum(ignore_mask)], message='loss: ')
    return loss
  121. def ctc_lambda_func(args):
  122. """
  123. 定义ctc损失函数
  124. 参数:y_pred:预测值,labels:标签,input_length:lstm tiemstep,label_length:标签长度
  125. """
  126. y_pred, labels, input_length, label_length = args
  127. # return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
  128. return my_ctc_batch_cost(labels, y_pred, input_length, label_length, mode=0)
  129. def my_ctc_batch_cost(y_true, y_pred, input_length, label_length, mode=0):
  130. """Runs CTC loss algorithm on each batch element.
  131. Args:
  132. y_true: tensor `(samples, max_string_length)`
  133. containing the truth labels.
  134. y_pred: tensor `(samples, time_steps, num_categories)`
  135. containing the prediction, or output of the softmax.
  136. input_length: tensor `(samples, 1)` containing the sequence length for
  137. each batch item in `y_pred`.
  138. label_length: tensor `(samples, 1)` containing the sequence length for
  139. each batch item in `y_true`.
  140. Returns:
  141. Tensor with shape (samples,1) containing the
  142. CTC loss of each element.
  143. """
  144. input_length = tf.cast(
  145. tf.squeeze(input_length, axis=-1), tf.int32)
  146. label_length = tf.cast(
  147. tf.squeeze(label_length, axis=-1), tf.int32)
  148. sparse_labels = tf.cast(
  149. K.ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
  150. y_pred = tf.math.log(tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
  151. loss = tf.compat.v1.nn.ctc_loss(inputs=y_pred,
  152. labels=sparse_labels,
  153. sequence_length=input_length,
  154. preprocess_collapse_repeated=False,
  155. ctc_merge_repeated=True)
  156. loss = tf.expand_dims(loss, 1)
  157. if mode == 1:
  158. loss = focal_ctc(sparse_labels, y_pred, input_length, loss)
  159. # if mode == 2:
  160. # loss = loss + ctc_decode_mse_loss((y_pred, y_true, input_length, label_length))
  161. # print("loss1", loss.shape)
  162. return loss
  163. # @tf.function
def ctc_decode_mse_loss(args):
    """Mean-absolute-error loss between y_pred and labels re-aligned to it.

    ``args`` is unpacked as ``(y_pred, labels, input_length, label_length)``;
    only ``y_pred`` and ``labels`` are read. The labels are one-hot encoded
    and a per-sample target is assembled step by step according to which
    time steps of the arg-max prediction are non-repeated characters.
    Returns the absolute difference between that target and ``y_pred``,
    averaged over the last two axes, with a trailing axis of size 1.

    NOTE(review): sizes (37 classes, 11 steps) are hard-coded, and the shape
    comments below come from the original author — confirm against callers.
    """
    num_classes = 35+2
    time_step = 11
    # y_pred [32, 21, 37]
    y_pred, labels, input_length, label_length = args
    # print("y_pred", y_pred.shape)
    # y_pred [37, 32, 21]
    # y_pred = tf.compat.v1.transpose(y_pred, perm=[2, 0, 1])
    # y_max [32, 21]
    y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction')
    # Determine whether each step is a predicted character (arg-max index > 0)
    is_char = tf.greater(y_max, 0)
    # Shifted comparison: find steps equal to the following step (repeats)
    char_rep = tf.equal(y_max[:, :-1], y_max[:, 1:])
    # Extra column so char_rep has the same number of steps as y_max.
    tail = tf.greater(y_max[:, :1], num_classes - 1)
    char_rep = tf.concat([char_rep, tail], axis=1)
    # Character positions after removing repeats; for a repeated character
    # the position of its LAST occurrence is kept
    # [32, 21]
    char_no_rep = tf.math.logical_and(is_char, tf.math.logical_not(char_rep))
    # char_no_rep = tf.expand_dims(char_no_rep, axis=-1)
    # char_no_rep = tf.concat([char_no_rep]*37, axis=-1)
    # [32, 21, 37]
    # y_pred = tf.compat.v1.transpose(y_pred, perm=[1, 2, 0])
    # y_pred_no_rep [32*?, 37]
    # y_pred_no_rep = tf.boolean_mask(y_pred, char_no_rep)
    # y_pred_no_rep [32, ?, 37]
    # y_pred_no_rep = tf.compat.v1.transpose(y_pred_no_rep, perm=[1, 0, 2])
    # time_step = tf.cast(K.shape(y_pred_no_rep)[0]/K.shape(y_pred)[0], tf.int32)
    # y_pred_no_rep [32, 21, 37]
    # y_pred_no_rep = tf.reshape(y_pred_no_rep, (K.shape(y_pred)[0], time_step, K.shape(y_pred_no_rep)[-1]))
    # Pad the time-step dimension of the two tensors to the same size
    # y_pred_no_rep = tf.concat([y_pred_no_rep, tf.zeros((K.shape(labels)[0], K.shape(labels)[1], K.shape(y_pred)[2]-K.shape(labels)[2]))],
    #                           axis=2)
    # [32, 37, 21]
    labels = tf.cast(labels, tf.int32)
    labels = tf.one_hot(labels, depth=num_classes, axis=1, dtype=tf.float32)
    # Zero-pad the last axis of the one-hot labels.
    # NOTE(review): the pad width uses K.shape(y_pred)[2] — confirm this is
    # the intended dimension.
    labels = tf.concat([labels, tf.zeros((K.shape(labels)[0], K.shape(labels)[1], K.shape(y_pred)[2]-K.shape(labels)[2]))],
                       axis=2)
    # [32, 21, 37]
    labels = tf.compat.v1.transpose(labels, perm=[0, 2, 1])
    # Initial loop value: a single all-zero sample.
    new_label = tf.zeros((1, time_step, num_classes), dtype=tf.float32)
    # tf.autograph.experimental.set_loop_options(
    #     shape_invariants=[(new_label, tf.TensorShape([None, None, 37]))]
    # )
    @tf.function
    def body(_i, _label):
        # Builds the target for sample _i and appends it to _label.
        # print("_i", _i)
        sample = char_no_rep[_i, :]
        if sample[0]:
            new_sample = labels[_i:_i+1, 0:1, :]
            new_sample = tf.cast(new_sample, tf.float32)
        else:
            new_sample = tf.zeros((1, 1, 37), dtype=tf.float32)
        for j in range(1, 11):
            step = char_no_rep[_i, j]
            # NOTE(review): k is reset to 0 on every iteration, so the
            # `k += 1` below never carries over and label step 0 is always
            # the one appended — confirm intent.
            k = 0
            if step and k < K.shape(labels)[1]:
                new_sample = tf.concat([new_sample, labels[_i:_i+1, k:k+1, :]], axis=1)
                k += 1
            else:
                new_sample = tf.concat([new_sample, tf.zeros((1, 1, 37), dtype=tf.float32)], axis=1)
        if _i == 0:
            _label = new_sample
        else:
            _label = tf.concat([_label, new_sample], axis=0)
        _i = tf.add(_i, 1)
        return _i, _label
    def cond(_i, _label):
        # Continue while _i is below the first dimension of labels.
        return tf.less(_i, K.shape(labels)[0])
    # NOTE(review): the counter starts at 1, so the `_i == 0` branch in body
    # is presumably never taken and the initial zero row stays as row 0 —
    # confirm.
    i = tf.constant(1, dtype=tf.int32)
    # time_step_tensor = tf.constant(time_step, dtype=tf.int32)
    # num_classes_tensor = tf.constant(num_classes, dtype=tf.int32)
    _, new_label = tf.while_loop(cond, body, [i, new_label],
                                 shape_invariants=[i.get_shape(), tf.TensorShape([None, None, 37]),])
    # print("new_label", new_label.shape)
    # for i in range(32):
    #     sample = char_no_rep[i, :]
    #     if sample[0]:
    #         new_sample = labels[i:i+1, 0:1, :]
    #         new_sample = tf.cast(new_sample, tf.float32)
    #     else:
    #         new_sample = tf.zeros((1, 1, 37), dtype=tf.float32)
    #     for j in range(1, 21):
    #         step = char_no_rep[i, j]
    #         k = 0
    #         if step and k < K.shape(labels)[1]:
    #             new_sample = tf.concat([new_sample, labels[i:i+1, k:k+1, :]], axis=1)
    #             k += 1
    #         else:
    #             new_sample = tf.concat([new_sample, tf.zeros((1, 1, 37), dtype=tf.float32)], axis=1)
    #     # if i == 0:
    #     #     new_label = new_sample
    #     # else:
    #     new_label = tf.concat([new_label, new_sample], axis=0)
    # def cond(_i, _j):
    #     return tf.less(_i, K.shape(char_no_rep)[-1])
    #
    # def body(_i, _j):
    #     def func1(j):
    #         tf.add(j, 1)
    #         return tf.cast(labels[:, j-1], tf.int32)
    #
    #     def func2():
    #         return tf.zeros((K.shape(labels)[0], K.shape(labels)[0]-31), dtype=tf.int32)
    #
    #     cond_func = tf.cond(char_no_rep[:, _i], lambda: func1(_j), func2)
    #     return cond_func
    #
    # i = K.constant(1, tf.int32)
    # j = K.constant(1, tf.int32)
    # y_pred_no_rep, _ = tf.while_loop(cond, body, [i, j])
    # pred_sum = tf.reduce_sum(y_pred)
    # label_sum = tf.reduce_sum(raw_labels)
    # labels [32, 37, 21]
    # y_pred [32, 37, ]
    # new_label = tf.reshape(new_label, (None, 777))
    # Absolute error averaged over the last axis, then the next one, leaving
    # one value per sample with a trailing axis of size 1.
    loss = tf.reduce_mean(tf.abs((new_label-y_pred)), axis=-1)
    loss = tf.reduce_mean(loss, axis=-1)
    loss = tf.expand_dims(loss, -1)
    # loss = tf.reduce_mean(loss, axis=-1)
    # print("loss2", loss.shape)
    # loss.set_shape(None, 1)
    # print("loss22", loss.shape)
    return loss
def ctc_decode_mse_loss2(args):
    """Mean-absolute-error loss against labels placed at matching time steps.

    ``args`` is unpacked as ``(y_pred, labels, input_length, label_length)``;
    only ``y_pred`` and ``labels`` are read. For each sample a row of
    ``time_step`` entries is filled with ``blank_index`` and each label is
    written at the time steps where the arg-max prediction equals it; labels
    with no match are written into the first position still holding
    ``blank_index``. The rows are one-hot encoded and compared with
    ``y_pred``. Returns the absolute difference averaged over the last two
    axes, with a trailing axis of size 1.

    NOTE(review): batch size (32), time steps (21), label length (8) and the
    class count (37) are hard-coded — confirm against callers.
    """
    batch_size = 32
    num_classes = 35+2
    time_step = 21
    label_len = 8
    # The last class index is treated as the blank symbol.
    blank_index = num_classes-1
    # [32, 21, 37]
    y_pred, labels, input_length, label_length = args
    # [32, 21]
    y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction', output_type=tf.int32)
    # [32, 8]
    labels = tf.cast(labels, tf.int32)
    # [batch, step]
    # new_label = tf.zeros((batch_size, time_step), dtype=tf.int32)
    new_label = tf.fill((batch_size, time_step), blank_index)
    @tf.function
    def body(_i, _label):
        # Builds the aligned label row for sample _i and writes it into _label.
        # new_sample = tf.zeros((1, time_step), dtype=tf.int32)
        new_sample = tf.fill((1, time_step), blank_index)
        for j in range(0, label_len):
            # if tf.greater(0, y_max[_i, j]):
            find_flag = False
            for k in range(0, time_step):
                # Loop over y_pred looking for matching labels; misses some
                # if k < K.shape(labels)[1] and tf.equal(y_max[_i, j], labels[_i, k]):
                #     # tf.print("equal", y_max[_i, j], labels[_i, k])
                #     if j == 0:
                #         new_sample = tf.concat([labels[_i:_i+1, k:k+1], new_sample[:, j+1:]], axis=-1)
                #     elif j >= time_step-1:
                #         new_sample = tf.concat([new_sample[:, :j], labels[_i:_i+1, k:k+1]], axis=-1)
                #     else:
                #         new_sample = tf.concat([new_sample[:, :j], labels[_i:_i+1, k:k+1], new_sample[:, j+1:]], axis=-1)
                # Loop over labels looking for matching y_pred; labels that
                # are missed overwrite a blank position further below
                # tf.print("labels", labels[_i], last_k, j, labels[_i].shape, new_sample.shape)
                if tf.equal(y_max[_i, k], labels[_i, j]) and tf.not_equal(y_max[_i, k], blank_index):
                    find_flag = True
                    # Replace column k of new_sample with label j.
                    if k == 0:
                        new_sample = tf.concat([labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
                    elif k >= time_step-1:
                        new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1]], axis=-1)
                    else:
                        new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
                    # tf.print("new_sample", new_sample, last_k, j, K.shape(labels[_i]), K.shape(new_sample))
            # Label j matched no prediction: put it in the first blank slot.
            if not find_flag and tf.not_equal(labels[_i, j], blank_index):
                find_flag2 = False
                for k in range(0, time_step):
                    if not find_flag2 and tf.equal(new_sample[0, k], blank_index):
                        find_flag2 = True
                        if k == 0:
                            new_sample = tf.concat([labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
                        elif k >= time_step-1:
                            new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1]], axis=-1)
                        else:
                            new_sample = tf.concat([new_sample[:, :k], labels[_i:_i+1, j:j+1], new_sample[:, k+1:]], axis=-1)
            # tf.print("new_sample", new_sample, labels[_i, j], find_flag, find_flag2, summarize=100)
        # tf.print("new_sample", new_sample, summarize=100)
        # Debug output emitted for every sample.
        tf.print("y_max[_i]", y_max[_i], summarize=100)
        tf.print("new_samele", new_sample, summarize=100)
        tf.print("labels[_i]", labels[_i], summarize=100)
        tf.print("loss", tf.reduce_mean(tf.abs((y_max[_i]-new_sample)), axis=-1))
        # Replace row _i of _label with new_sample.
        # NOTE(review): the last-row test compares the batch index _i with
        # time_step-1 rather than a batch-size bound — confirm.
        if _i == 0:
            _label = tf.concat([new_sample[:, :], _label[_i+1:, :]], axis=0)
        elif _i >= time_step-1:
            _label = tf.concat([_label[:_i, :], new_sample[:, :]], axis=0)
        else:
            _label = tf.concat([_label[:_i, :], new_sample[:, :], _label[_i+1:, :]], axis=0)
        _i = tf.add(_i, 1)
        return _i, _label
    def cond(_i, _label):
        # Continue while _i is below the first dimension of labels.
        return tf.less(_i, K.shape(labels)[0])
    # NOTE(review): the counter starts at 1, so row 0 presumably keeps its
    # blank fill — confirm.
    i = tf.constant(1, dtype=tf.int32)
    _, new_label = tf.while_loop(cond, body, [i, new_label],
                                 shape_invariants=[i.get_shape(), tf.TensorShape([None, None])])
    # One-hot on axis 1, then swap the last two axes so classes come last.
    new_label = tf.one_hot(new_label, depth=num_classes, axis=1, dtype=tf.float32)
    new_label = tf.compat.v1.transpose(new_label, perm=[0, 2, 1])
    # print("y_pred", y_pred.shape)
    # print("new_label", new_label.shape)
    loss = tf.reduce_mean(tf.abs((new_label-y_pred)), axis=-1)
    loss = tf.reduce_mean(loss*1, axis=-1)
    loss = tf.expand_dims(loss, -1)
    return loss
class CtcDecodeMseLoss(Layer):
    """Keras layer computing an absolute-error loss between y_pred and labels
    re-aligned to the non-repeated character steps of the prediction.

    The layer defines no weights of its own. ``call`` expects
    ``(y_pred, labels, input_length, label_length)`` and reads only the
    first two.

    NOTE(review): the batch size (32), step count (21) and class count (37)
    are hard-coded in ``call`` — confirm against the model that uses it.
    """
    def __init__(self, **kwargs):
        super(CtcDecodeMseLoss, self).__init__(**kwargs)
    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        super(CtcDecodeMseLoss, self).build(input_shape) # Be sure to call this somewhere!
    def call(self, inputs):
        """Return the mean of ``abs((new_label - y_pred) * 100)`` over all
        elements, where ``new_label`` is the re-aligned one-hot label tensor
        built below."""
        # y_pred [32, 21, 37]
        y_pred, labels, input_length, label_length = inputs
        # y_max [32, 21]
        y_max = tf.argmax(y_pred, axis=-1, name='raw_prediction')
        num_classes = 35+2
        # Determine whether each step is a predicted character (arg-max > 0)
        is_char = tf.greater(y_max, 0)
        # Shifted comparison: find steps equal to the following step (repeats)
        char_rep = tf.equal(y_max[:, :-1], y_max[:, 1:])
        # Extra column so char_rep has the same number of steps as y_max.
        tail = tf.greater(y_max[:, :1], num_classes - 1)
        char_rep = tf.concat([char_rep, tail], axis=1)
        # Character positions after removing repeats; for a repeated
        # character the position of its LAST occurrence is kept
        # [32, 21]
        char_no_rep = tf.math.logical_and(is_char, tf.math.logical_not(char_rep))
        # [32, 37, 21]
        labels = tf.cast(labels, tf.int32)
        labels = tf.one_hot(labels, depth=37, axis=1, dtype=tf.float32)
        # Zero-pad the last axis of the one-hot labels.
        labels = tf.concat([labels, tf.zeros((K.shape(labels)[0], K.shape(labels)[1], K.shape(y_pred)[2]-K.shape(labels)[2]))],
                           axis=2)
        # [32, 21, 37]
        labels = tf.compat.v1.transpose(labels, perm=[0, 2, 1])
        # NOTE(review): Python-level `if` on tensor values below presumably
        # requires eager execution or autograph — confirm.
        for i in range(32):
            sample = char_no_rep[i, :]
            if sample[0]:
                new_sample = labels[i:i+1, 0:1, :]
                new_sample = tf.cast(new_sample, tf.float32)
            else:
                new_sample = tf.zeros((1, 1, 37), dtype=tf.float32)
            for j in range(1, 21):
                step = char_no_rep[i, j]
                # NOTE(review): k is reset to 0 on every iteration, so the
                # `k += 1` below never carries over — confirm intent.
                k = 0
                if step and k < K.shape(labels)[1]:
                    new_sample = tf.concat([new_sample, labels[i:i+1, k:k+1, :]], axis=1)
                    k += 1
                else:
                    new_sample = tf.concat([new_sample, tf.zeros((1, 1, 37), dtype=tf.float32)], axis=1)
            # Stack the per-sample targets along the batch axis.
            if i == 0:
                new_label = new_sample
            else:
                new_label = tf.concat([new_label, new_sample], axis=0)
        loss = tf.reduce_mean(tf.abs((new_label-y_pred)*100))
        # loss = tf.expand_dims(loss, 1)
        print("loss2", loss.shape)
        return loss
    def compute_output_shape(self, input_shape):
        # NOTE(review): this takes K.shape of the input_shape argument itself
        # rather than reading a batch dimension from it — confirm.
        return (K.shape(input_shape)[0], 1)
  422. def focal_ctc(targets, logits, seq_len, ctc_loss, alpha=0.8, gamma=2.0):
  423. # FOCAL LOSS
  424. # This function computes Focal Loss
  425. # Inputs: alpha, gamma, targets, logits, seq_len
  426. # Default Values: alpha=0.5 and gamma=2.0
  427. # Output: loss
  428. # ctc_loss = tf.compat.v1.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len, time_major=True)
  429. p = tf.exp(-ctc_loss)
  430. # ((alpha)*((1-p)**gamma)*(ctc_loss))
  431. focal_ctc_loss = tf.multiply(tf.multiply(alpha, tf.pow((1-p), gamma)), ctc_loss)
  432. loss = tf.reduce_mean(focal_ctc_loss)
  433. return loss
def ctc_center_loss(labels, features, _lambda=0.0005):
    """Softmax cross-entropy plus ``_lambda`` times a center loss.

    Time steps whose arg-max prediction is a non-repeated character are
    selected from ``features``; the categorical cross-entropy against
    ``labels`` and the center loss on the arg-max of ``labels`` are then
    combined as ``s_loss + _lambda * c_loss``.

    NOTE(review): shapes are not fixed by this code; ``features`` is indexed
    on axis 2 for the arg-max, so it presumably has three axes — confirm.
    """
    def center_loss(labels, features, alpha=0.6, num_classes=240):
        """
        Compute the center loss and update the per-class centers.
        :param labels: Tensor with the sample labels, not one-hot encoded; shape should be (batch_size,).
        :param features: Tensor with the sample features, the output of the last fc layer; shape should be (batch_size, num_classes).
        :param alpha: number between 0 and 1 controlling the learning rate of the class centers; see the original paper for details.
        :param num_classes: integer, the total number of classes; use the number of output neurons of the classifier.
        :return: Tensor, the center loss
        """
        # Feature dimensionality, e.g. 256
        len_features = features.get_shape()[1]
        # Create a Variable of shape [num_classes, len_features] that stores
        # the class centers for the whole network;
        # trainable=False because the centers are not updated by gradients
        centers = tf.compat.v1.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
                                            initializer=tf.constant_initializer(0), trainable=False)
        # Flatten labels to 1-D; unnecessary if labels is already 1-D
        labels = tf.reshape(labels, [-1])
        # Look up the center of each sample in the mini-batch by its label
        centers_batch = tf.gather(centers, labels)
        # Difference between the centers and the mini-batch features
        diff = centers_batch - features
        # Count how often each class occurs in the mini-batch; see equation
        # (4) of the original paper for the rationale
        unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
        appear_times = tf.gather(unique_count, unique_idx)
        appear_times = tf.reshape(appear_times, [-1, 1])
        diff = diff / tf.cast((1 + appear_times), tf.float32)
        diff = alpha * diff
        # Update the centers
        centers_update_op = tf.compat.v1.scatter_sub(centers, labels, diff)
        # tf.control_dependencies makes the loss depend on the centers update
        with tf.control_dependencies([centers_update_op]):
            # Compute the center loss
            c_loss = tf.nn.l2_loss(features - centers_batch)
        return c_loss
    def get_slice(pos):
        # Picks one feature vector using pos[1] and pos[0] as indices.
        feature_one_char = features[pos[1], pos[0], :]
        return feature_one_char
    num_classes = 35+2
    # Determine whether each step is a predicted character (arg-max > 0)
    raw_pred = tf.argmax(features, axis=2, name='raw_prediction')
    is_char = tf.greater(raw_pred, 0)
    # Shifted comparison: find steps equal to the following step (repeats)
    char_rep = tf.equal(raw_pred[:, :-1], raw_pred[:, 1:])
    # Extra column so char_rep has the same number of steps as raw_pred.
    tail = tf.greater(raw_pred[:, :1], num_classes - 1)
    char_rep = tf.concat([char_rep, tail], axis=1)
    # Character positions after removing repeats; for a repeated character
    # the position of its LAST occurrence is kept
    char_no_rep = tf.math.logical_and(is_char, tf.math.logical_not(char_rep))
    # NOTE(review): boolean_mask returns the selected feature vectors, not
    # index pairs, yet get_slice uses each mapped element as indices into
    # features — confirm this is what was intended.
    char_pos = tf.boolean_mask(features, char_no_rep)
    features = tf.map_fn(get_slice, char_pos, dtype=tf.float32)
    labels = K.cast(labels, dtype=tf.float32)
    # softmax loss
    s_loss = K.categorical_crossentropy(labels, K.softmax(features, axis=-1))
    # center loss
    c_loss = center_loss(K.argmax(labels, axis=-1), features)
    return s_loss + _lambda * c_loss
  490. def ctc_center_accuracy(y_true, y_pred):
  491. """
  492. 重写categorical_accuracy函数,以适应去掉softmax层的模型
  493. :param y_true: 等同于labels,
  494. :param y_pred: 等同于features。
  495. :return: 准确率
  496. """
  497. # 计算y_pred的softmax值
  498. sm_y_pred = K.softmax(y_pred, axis=-1)
  499. # 返回准确率
  500. return K.cast(K.equal(K.argmax(y_true, axis=-1), K.argmax(sm_y_pred, axis=-1)), K.floatx())
  501. def ctc_accuracy(y_true, y_pred):
  502. # 使用CTC decoder
  503. decoded = K.ctc_decode(y_pred, input_length=21, greedy=False, beam_width=6)
  504. # 计算编辑距离
  505. distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), y_true)
  506. # 计算label error rate (accuracy)
  507. label_error_rate = tf.reduce_mean(distance, name='label_error_rate')
  508. return label_error_rate
  509. def perceptual_loss(gamma=2., alpha=.25):
  510. from click_captcha.model import Vgg19
  511. def perceptual_loss_fixed(y_true, y_pred):
  512. if globals().get("vgg") is None:
  513. vgg = Vgg19("./vgg19.npy")
  514. globals().update({"vgg": vgg})
  515. print("init vgg19 success!")
  516. else:
  517. vgg = globals().get("vgg")
  518. # mask_1 = tf.where(y_true[:, :, :, 0] >= 0.75, 1, 0)
  519. # mask_2 = tf.where(y_true[:, :, :, 1] >= 0.75, 1, 0)
  520. # mask_3 = tf.where(y_true[:, :, :, 2] >= 0.75, 1, 0)
  521. # mask_white = tf.expand_dims(mask_1 * mask_2 * mask_3, -1)
  522. # mask_white = tf.concat([mask_white, mask_white, mask_white], -1)
  523. # y_true_mask = tf.where(mask_white == 1, 1., y_true)
  524. # y_pred_mask = tf.where(mask_white == 1, 1., y_pred)
  525. # print("y_pred.shape", y_pred.shape)
  526. y_pred = tf.concat([y_pred, y_pred, y_pred], -1)
  527. y_true = tf.concat([y_true, y_true, y_true], -1)
  528. vgg.build(y_true)
  529. vgg_true_1 = vgg.conv1_1
  530. vgg_true_2 = vgg.conv2_1
  531. vgg_true_3 = vgg.conv3_1
  532. vgg_true_4 = vgg.conv4_1
  533. vgg_true_5 = vgg.conv5_1
  534. vgg.build(y_pred)
  535. vgg_pred_1 = vgg.conv1_1
  536. vgg_pred_2 = vgg.conv2_1
  537. vgg_pred_3 = vgg.conv3_1
  538. vgg_pred_4 = vgg.conv4_1
  539. vgg_pred_5 = vgg.conv5_1
  540. loss_0 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(y_true, y_pred)
  541. loss_1 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_1, vgg_pred_1)
  542. loss_2 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_2, vgg_pred_2)
  543. loss_3 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_3, vgg_pred_3)
  544. loss_4 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_4, vgg_pred_4)
  545. loss_5 = l2_focal_loss(threshold=0.2, ratio=1000, reverse=True)(vgg_true_5, vgg_pred_5)
  546. return (loss_0+loss_1+loss_2+loss_3+loss_4+loss_5) / 6
  547. return perceptual_loss_fixed