import os
import sys
import warnings

import tensorflow as tf
from keras.callbacks import Callback
from keras.layers import Layer
import numpy as np

sys.path.append(os.path.dirname(__file__))
from pre_process import get_best_padding_size  # noqa: E402,F401


class BatchReshape1(Layer):
    """Merge the table's row and column dimensions into the batch dimension.

    (batch, rows, cols, character_num, character_embed)
        -> (batch*rows*cols, character_num, character_embed)
    """

    def __init__(self, character_num, character_embed, **kwargs):
        # **kwargs (name, trainable, ...) is forwarded so Keras can rebuild
        # the layer from its config when a saved model is loaded.
        super(BatchReshape1, self).__init__(**kwargs)
        self.character_num = character_num
        self.character_embed = character_embed

    def call(self, inputs, mask=None, **kwargs):
        shape = tf.shape(inputs)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(
            inputs,
            (batch * height * width, self.character_num, self.character_embed))

    def compute_output_shape(self, input_shape):
        return None, self.character_num, self.character_embed

    def get_config(self):
        config = super(BatchReshape1, self).get_config()
        config.update({'character_num': self.character_num,
                       'character_embed': self.character_embed})
        return config


class BatchReshape2(Layer):
    """Split rows and columns back out of the batch dimension.

    Called on ``[reference, cells]``: the first three dimensions of
    ``reference`` supply (batch, rows, cols).

    (batch*rows*cols, cell_embed) -> (batch, rows, cols, cell_embed)
    """

    def __init__(self, cell_embed, **kwargs):
        super(BatchReshape2, self).__init__(**kwargs)
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        reference, cells = inputs[0], inputs[1]
        shape = tf.shape(reference)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(cells, (batch, height, width, self.cell_embed))

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed

    def get_config(self):
        config = super(BatchReshape2, self).get_config()
        config.update({'cell_embed': self.cell_embed})
        return config


class BatchReshape3(Layer):
    """Merge the table's row dimension into the batch dimension.

    (batch, rows, cols, cell_embed) -> (batch*rows, cols, cell_embed)
    """

    def __init__(self, cell_embed, **kwargs):
        super(BatchReshape3, self).__init__(**kwargs)
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        shape = tf.shape(inputs)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(inputs, (batch * height, width, self.cell_embed))

    def compute_output_shape(self, input_shape):
        return None, None, self.cell_embed

    def get_config(self):
        config = super(BatchReshape3, self).get_config()
        config.update({'cell_embed': self.cell_embed})
        return config


class BatchReshape4(Layer):
    """Split the row dimension back out of the batch dimension.

    Called on ``[reference, rows]``: the first three dimensions of
    ``reference`` supply (batch, rows, cols).

    (batch*rows, cols, cell_embed) -> (batch, rows, cols, cell_embed)
    """

    def __init__(self, cell_embed, **kwargs):
        super(BatchReshape4, self).__init__(**kwargs)
        self.supports_masking = True
        self.cell_embed = cell_embed

    def compute_mask(self, inputs, mask=None):
        # Pass the incoming mask through unchanged.
        # NOTE(review): for the two-input call this is the list of input
        # masks, not a mask reshaped to the output shape -- confirm that
        # downstream layers expect this.
        return mask

    def call(self, inputs, mask=None, **kwargs):
        reference, rows = inputs[0], inputs[1]
        shape = tf.shape(reference)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(rows, (batch, height, width, self.cell_embed))

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed

    def get_config(self):
        config = super(BatchReshape4, self).get_config()
        config.update({'cell_embed': self.cell_embed})
        return config


class BatchReshape5(Layer):
    """Flatten the table's rows and columns into one sequence dimension.

    (batch, rows, cols, cell_embed) -> (batch, rows*cols, cell_embed)
    """

    def __init__(self, cell_embed, **kwargs):
        super(BatchReshape5, self).__init__(**kwargs)
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        shape = tf.shape(inputs)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(inputs, (batch, height * width, self.cell_embed))

    def compute_output_shape(self, input_shape):
        return None, None, self.cell_embed

    def get_config(self):
        config = super(BatchReshape5, self).get_config()
        config.update({'cell_embed': self.cell_embed})
        return config


class BatchReshape6(Layer):
    """Split a flattened rows*cols sequence back into rows and columns.

    Called on ``[reference, sequence]``: the first three dimensions of
    ``reference`` supply (batch, rows, cols).

    (batch, rows*cols, cell_embed) -> (batch, rows, cols, cell_embed)
    """

    def __init__(self, cell_embed, **kwargs):
        super(BatchReshape6, self).__init__(**kwargs)
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        reference, sequence = inputs[0], inputs[1]
        shape = tf.shape(reference)
        batch, height, width = shape[0], shape[1], shape[2]
        return tf.reshape(sequence, (batch, height, width, self.cell_embed))

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed

    def get_config(self):
        config = super(BatchReshape6, self).get_config()
        config.update({'cell_embed': self.cell_embed})
        return config


class MyPadding(Layer):
    """Zero-pad dimensions 1 and 2 at the end up to (pad_height, pad_width).

    Dimensions 0 and 3 are left unpadded. ``tf.pad`` rejects negative
    paddings, so an input larger than the target size raises an error.
    """

    def __init__(self, pad_height, pad_width, cell_embed, **kwargs):
        super(MyPadding, self).__init__(**kwargs)
        self.pad_height = pad_height
        self.pad_width = pad_width
        self.cell_embed = cell_embed

    def call(self, inputs, mask=None, **kwargs):
        shape = tf.shape(inputs)
        height, width = shape[1], shape[2]
        return tf.pad(inputs, [[0, 0],
                               [0, self.pad_height - height],
                               [0, self.pad_width - width],
                               [0, 0]])

    def compute_output_shape(self, input_shape):
        return None, None, None, self.cell_embed

    def get_config(self):
        config = super(MyPadding, self).get_config()
        config.update({'pad_height': self.pad_height,
                       'pad_width': self.pad_width,
                       'cell_embed': self.cell_embed})
        return config


class MySplit(Layer):
    """Keep only the first ``height`` x ``width`` entries of dims 1 and 2."""

    def __init__(self, height, width, **kwargs):
        super(MySplit, self).__init__(**kwargs)
        self.height = height
        self.width = width

    def call(self, inputs, mask=None, **kwargs):
        return inputs[:, 0:self.height, 0:self.width]

    def compute_output_shape(self, input_shape):
        return None, None, None

    def get_config(self):
        config = super(MySplit, self).get_config()
        config.update({'height': self.height, 'width': self.width})
        return config


class MyModelCheckpoint(Callback):
    """Save the model (or its weights) every ``period`` epochs.

    Adapted from ``keras.callbacks.ModelCheckpoint``. ``monitor`` may be a
    single metric name or a sequence of metric names. With
    ``save_best_only=True``, the value compared against the best so far is
    the arithmetic mean of all monitored metrics in the epoch logs.

    Args:
        filepath: path template formatted with ``epoch`` (1-based) and every
            key of the epoch ``logs``.
        monitor: metric name, or sequence of metric names, to track.
        verbose: print a message for each save/skip decision when > 0.
        save_best_only: only save when the monitored value improves.
        save_weights_only: use ``model.save_weights`` instead of
            ``model.save``.
        mode: 'min', 'max' or 'auto' (inferred from the metric names).
        period: number of epochs between checkpoint attempts.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MyModelCheckpoint, self).__init__()
        self.monitor = monitor
        # A bare string is a single metric; indexing it would yield
        # individual characters rather than metric names.
        if isinstance(monitor, str):
            self._monitor_keys = [monitor]
        else:
            self._monitor_keys = list(monitor)
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode,),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'auto':
            # Accuracy / F-measure style metrics improve upwards; anything
            # else (e.g. a loss) is treated as improving downwards.
            higher_is_better = any(
                'acc' in key or key.startswith('fmeasure')
                for key in self._monitor_keys)
            mode = 'max' if higher_is_better else 'min'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        else:
            self.monitor_op = np.greater
            self.best = -np.inf

    def _current_value(self, logs):
        """Return the mean of the monitored metrics, or None if any is missing."""
        values = [logs.get(key) for key in self._monitor_keys]
        if not values or any(value is None for value in values):
            return None
        return sum(values) / len(values)

    def _save(self, filepath):
        """Write either the weights only or the full model to ``filepath``."""
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save(filepath)
            return

        current = self._current_value(logs)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor,), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save(filepath)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))