layer_utils.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import os
  2. import sys
  3. import tensorflow as tf
  4. from keras.callbacks import Callback
  5. from keras.layers import Layer, warnings
  6. import numpy as np
  7. sys.path.append(os.path.dirname(__file__))
  8. from pre_process import get_best_padding_size
  9. class BatchReshape1(Layer):
  10. """
  11. 将表格的行列维度合并到Batch维度中
  12. (batch, rows, cols, character_num, character_embed) -> (batch*rows*cols, character_num, character_embed)
  13. """
  14. def __init__(self, character_num, character_embed):
  15. super(BatchReshape1, self).__init__()
  16. self.character_num = character_num
  17. self.character_embed = character_embed
  18. def call(self, inputs, mask=None, **kwargs):
  19. batch = tf.shape(inputs)[0]
  20. height = tf.shape(inputs)[1]
  21. width = tf.shape(inputs)[2]
  22. outputs = tf.reshape(inputs, (batch*height*width,
  23. self.character_num, self.character_embed))
  24. return outputs
  25. def compute_output_shape(self, input_shape):
  26. return None, self.character_num, self.character_embed
  27. class BatchReshape2(Layer):
  28. """
  29. 将Batch维度中的行列拆分出来
  30. (batch*rows*cols, cell_embed) -> (batch, rows, cols, cell_embed)
  31. """
  32. def __init__(self, cell_embed):
  33. super(BatchReshape2, self).__init__()
  34. self.cell_embed = cell_embed
  35. def call(self, inputs, mask=None, **kwargs):
  36. input1 = inputs[0]
  37. input2 = inputs[1]
  38. batch = tf.shape(input1)[0]
  39. height = tf.shape(input1)[1]
  40. width = tf.shape(input1)[2]
  41. outputs = tf.reshape(input2, (batch, height, width, self.cell_embed))
  42. return outputs
  43. def compute_output_shape(self, input_shape):
  44. return None, None, None, self.cell_embed
  45. class BatchReshape3(Layer):
  46. """
  47. 将表格的行维度合并到Batch维度中
  48. (batch, rows, cols, cell_embed) -> (batch*rows, cols, cell_embed)
  49. """
  50. def __init__(self, cell_embed):
  51. super(BatchReshape3, self).__init__()
  52. self.cell_embed = cell_embed
  53. def call(self, inputs, mask=None, **kwargs):
  54. batch = tf.shape(inputs)[0]
  55. height = tf.shape(inputs)[1]
  56. width = tf.shape(inputs)[2]
  57. outputs = tf.reshape(inputs, (batch*height, width, self.cell_embed))
  58. return outputs
  59. def compute_output_shape(self, input_shape):
  60. return None, None, self.cell_embed
  61. class BatchReshape4(Layer):
  62. """
  63. 将Batch维度中的行拆出来
  64. (batch*rows, cols, cell_embed) -> (batch, rows, cols, cell_embed)
  65. """
  66. def __init__(self, cell_embed):
  67. super(BatchReshape4, self).__init__()
  68. self.supports_masking = True
  69. self.cell_embed = cell_embed
  70. def compute_mask(self, inputs, mask=None):
  71. print(mask)
  72. # if mask[0] is None:
  73. # return mask
  74. # input1 = inputs[0]
  75. # input2 = inputs[1]
  76. # batch = tf.shape(input1)[0]
  77. # height = tf.shape(input1)[1]
  78. # width = tf.shape(input1)[2]
  79. #
  80. # mask_tensor = tf.reshape(mask[1], (batch, height, width, self.cell_embed))
  81. return mask
  82. def call(self, inputs, mask=None, **kwargs):
  83. input1 = inputs[0]
  84. input2 = inputs[1]
  85. batch = tf.shape(input1)[0]
  86. height = tf.shape(input1)[1]
  87. width = tf.shape(input1)[2]
  88. outputs = tf.reshape(input2, (batch, height, width, self.cell_embed))
  89. return outputs
  90. def compute_output_shape(self, input_shape):
  91. return None, None, None, self.cell_embed
  92. class BatchReshape5(Layer):
  93. """
  94. 将表格的行维度合并到Batch维度中
  95. (batch, rows, cols, cell_embed) -> (batch, rows*cols, cell_embed)
  96. """
  97. def __init__(self, cell_embed):
  98. super(BatchReshape5, self).__init__()
  99. self.cell_embed = cell_embed
  100. def call(self, inputs, mask=None, **kwargs):
  101. batch = tf.shape(inputs)[0]
  102. height = tf.shape(inputs)[1]
  103. width = tf.shape(inputs)[2]
  104. outputs = tf.reshape(inputs, (batch, height*width, self.cell_embed))
  105. return outputs
  106. def compute_output_shape(self, input_shape):
  107. return None, None, self.cell_embed
  108. class BatchReshape6(Layer):
  109. """
  110. 将Batch维度中的行拆出来
  111. (batch, rows*cols, cell_embed) -> (batch, rows, cols, cell_embed)
  112. """
  113. def __init__(self, cell_embed):
  114. super(BatchReshape6, self).__init__()
  115. self.cell_embed = cell_embed
  116. def call(self, inputs, mask=None, **kwargs):
  117. input1 = inputs[0]
  118. input2 = inputs[1]
  119. batch = tf.shape(input1)[0]
  120. height = tf.shape(input1)[1]
  121. width = tf.shape(input1)[2]
  122. outputs = tf.reshape(input2, (batch, height, width, self.cell_embed))
  123. return outputs
  124. def compute_output_shape(self, input_shape):
  125. return None, None, None, self.cell_embed
  126. class MyPadding(Layer):
  127. def __init__(self, pad_height, pad_width, cell_embed):
  128. super(MyPadding, self).__init__()
  129. self.pad_height = pad_height
  130. self.pad_width = pad_width
  131. self.cell_embed = cell_embed
  132. def call(self, inputs, mask=None, **kwargs):
  133. batch = tf.shape(inputs)[0]
  134. height = tf.shape(inputs)[1]
  135. width = tf.shape(inputs)[2]
  136. outputs = tf.pad(inputs, [[0, 0],
  137. [0, self.pad_height - height],
  138. [0, self.pad_width - width],
  139. [0, 0]])
  140. return outputs
  141. def compute_output_shape(self, input_shape):
  142. return None, None, None, self.cell_embed
  143. class MySplit(Layer):
  144. def __init__(self, height, width, **kwargs):
  145. super(MySplit, self).__init__(**kwargs)
  146. self.height = height
  147. self.width = width
  148. def call(self, inputs, mask=None, **kwargs):
  149. outputs = inputs[:, 0:self.height, 0:self.width]
  150. return outputs
  151. def compute_output_shape(self, input_shape):
  152. return None, None, None
  153. class MyModelCheckpoint(Callback):
  154. def __init__(self, filepath, monitor='val_loss', verbose=0,
  155. save_best_only=False, save_weights_only=False,
  156. mode='auto', period=1):
  157. super(MyModelCheckpoint, self).__init__()
  158. self.monitor = monitor
  159. self.verbose = verbose
  160. self.filepath = filepath
  161. self.save_best_only = save_best_only
  162. self.save_weights_only = save_weights_only
  163. self.period = period
  164. self.epochs_since_last_save = 0
  165. if mode not in ['auto', 'min', 'max']:
  166. warnings.warn('ModelCheckpoint mode %s is unknown, '
  167. 'fallback to auto mode.' % (mode),
  168. RuntimeWarning)
  169. mode = 'auto'
  170. if mode == 'min':
  171. self.monitor_op = np.less
  172. self.best = np.Inf
  173. elif mode == 'max':
  174. self.monitor_op = np.greater
  175. self.best = -np.Inf
  176. else:
  177. if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
  178. self.monitor_op = np.greater
  179. self.best = -np.Inf
  180. else:
  181. self.monitor_op = np.less
  182. self.best = np.Inf
  183. def on_epoch_end(self, epoch, logs=None):
  184. logs = logs or {}
  185. self.epochs_since_last_save += 1
  186. if self.epochs_since_last_save >= self.period:
  187. self.epochs_since_last_save = 0
  188. filepath = self.filepath.format(epoch=epoch + 1, **logs)
  189. if self.save_best_only:
  190. current = (logs.get(self.monitor[0]) + logs.get(self.monitor[1])) / 2
  191. if current is None:
  192. warnings.warn('Can save best model only with %s available, '
  193. 'skipping.' % (self.monitor), RuntimeWarning)
  194. else:
  195. if self.monitor_op(current, self.best):
  196. if self.verbose > 0:
  197. print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
  198. ' saving model to %s'
  199. % (epoch + 1, self.monitor, self.best,
  200. current, filepath))
  201. self.best = current
  202. if self.save_weights_only:
  203. self.model.save_weights(filepath, overwrite=True)
  204. else:
  205. self.model.save(filepath, overwrite=True)
  206. else:
  207. if self.verbose > 0:
  208. print('\nEpoch %05d: %s did not improve from %0.5f' %
  209. (epoch + 1, self.monitor, self.best))
  210. else:
  211. if self.verbose > 0:
  212. print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
  213. if self.save_weights_only:
  214. self.model.save_weights(filepath, overwrite=True)
  215. else:
  216. self.model.save(filepath, overwrite=True)