import module.model as model
import featureEngine
from keras.callbacks import ModelCheckpoint
import numpy as np
import pickle
import os
from module.Utils import *
from keras import models


def save(object_to_save, path):
    """Pickle an object to disk.

    Args:
        object_to_save: The object to serialize.
        path: Destination file path.

    Returns:
        The path the object was written to.
    """
    with open(path, 'wb') as f:
        pickle.dump(object_to_save, f)
    return path


def load(path):
    """Load a pickled object from disk.

    Args:
        path: File path to read from.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def train():
    """Train the BiRNN model and checkpoint it under ``log/``.

    The data set is read from the pickle cache when it exists, otherwise it
    is rebuilt through ``featureEngine.getAllData()``. The first 80% of the
    samples are used for training and the 80%-90% slice for validation; the
    last 10% is not touched here.

    Returns:
        The object returned by ``model.fit`` (the training history).
    """
    pk_file = "iterator/data_28849_16.pk"
    if os.path.exists(pk_file):
        data = load(pk_file)
    else:
        data = featureEngine.getAllData()
        # NOTE(review): this cache name differs from pk_file, so the branch
        # above will not pick it up on the next run -- confirm intent.
        save(data, "data_" + str(len(data[1])) + ".pk")

    model1 = model.getBiRNNModel()

    log_dir = "log/"
    # The checkpoint callback writes into log_dir; make sure it exists.
    os.makedirs(log_dir, exist_ok=True)

    train_percent = 0.8
    test_percent = 0.9
    print(np.shape(data[0]))
    train_len = round(len(data[0]) * train_percent)
    test_len = round(len(data[0]) * test_percent)

    # The file name template relies on val_acc / val_f1_score being reported
    # by the model (presumably compiled with those metrics in module.model).
    checkpoint = ModelCheckpoint(
        log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_acc{val_acc:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
        monitor='loss',
        save_best_only=True,
        period=1,
        mode="min",
    )
    return model1.fit(
        x=data[0][:train_len],
        y=data[1][:train_len],
        validation_data=(data[0][train_len:test_len], data[1][train_len:test_len]),
        epochs=400,
        batch_size=256,
        shuffle=True,
        callbacks=[checkpoint],
    )


def predict(x):
    """Run the saved checkpoint model on ``x``.

    The model is loaded from a fixed checkpoint path on every call.

    Args:
        x: Model input, passed straight to ``model.predict``.

    Returns:
        The prediction produced by the loaded model (batch size 1).
    """
    path = "log/ep011-loss0.160-val_acc0.900-val_loss0.156-f10.4536.h5"
    # The custom metric/loss callables are expected to come from the
    # ``module.Utils`` star import.
    model1 = models.load_model(
        path,
        custom_objects={
            "acc": acc,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "my_loss": my_loss,
        },
    )
    return model1.predict(x, batch_size=1)


def test(url):
    """Predict on a single URL and print the selected element.

    Args:
        url: Page address handed to ``featureEngine.getInput_byJS``.
    """
    # Hide all CUDA devices so the prediction runs on CPU.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    data = featureEngine.getInput_byJS(url)
    if data is None:
        return

    x, list_inner, list_xpath, _ = data
    print("x:", x)
    p = predict(x)
    print(p)

    best = np.argmax(p, 1)
    print(best)
    # Index (along axis 1) with the highest score in column 1.
    best_index = best[0][1]
    print(p[0][best_index])
    print(list_inner[best_index])
    print(list_xpath[best_index])


def val():
    """Evaluate the saved model on the last 10% of the cached data set.

    Prints every mispredicted URL, then per-host [correct, wrong] counts
    sorted by error rate (ascending), then the totals ``correct total``.
    """
    pk_file = "iterator/data_28849_16.pk"
    data = load(pk_file)
    train_percent = 0.9
    train_len = round(len(data[0]) * train_percent)

    predict_y = np.argmax(predict(data[0][train_len:]), 1)
    label_y = np.argmax(data[1][train_len:], 1)
    list_url = data[2][train_len:]

    size_predict = 0
    size_considence = 0
    # host -> [number correct, number wrong]
    dict_root_true_wrong = dict()
    for _predict, _label, _url in zip(predict_y, label_y, list_url):
        # Third "/"-separated piece, i.e. the host of "scheme://host/...".
        root = _url.split("/")[2]
        counts = dict_root_true_wrong.setdefault(root, [0, 0])
        if _predict[1] == _label[1]:
            size_considence += 1
            counts[0] += 1
        else:
            counts[1] += 1
            print(_url)
        size_predict += 1

    list_root_true_wrong = [[_key, _value] for _key, _value in dict_root_true_wrong.items()]
    # Every host has at least one sample, so the denominator is never zero.
    list_root_true_wrong.sort(key=lambda x: x[1][1] / (x[1][0] + x[1][1]))
    print(list_root_true_wrong)
    print(size_considence, size_predict)


def iteratorLabel():
    """Iteratively repair the labels of the cached data set.

    Each round trains two fresh models, one per half of the data, and
    predicts the full set with both. A sample is relabelled when both models
    agree with each other, disagree with the current label, and both exceed
    the confidence threshold. The loop stops after 10 consecutive rounds in
    which fewer than 10 URLs were changed that were not changed in the
    previous round. The data is written to disk after every round, so the
    file always holds the latest labels.
    """
    data_file = "iterator/data_28849_35.pk"
    threshold = 0.93
    train_epochs = 10
    batch_size = 96
    data = load(data_file)
    data_split = round(len(data[1]) * 0.5)
    # Relabelling only rewrites values in place, so the sample count (and
    # therefore the output name) is fixed for the whole run.
    output_file = "iterator/data_" + str(len(data[1])) + "_1.pk"

    last_change_set = set()
    max_not_change_times = 10
    _not_change_times = 0
    while True:
        # Train one model on each half of the data.
        model_1 = model.getBiRNNModel()
        model_2 = model.getBiRNNModel()
        model_1.fit(x=data[0][:data_split], y=data[1][:data_split],
                    epochs=train_epochs, batch_size=batch_size, shuffle=True)
        model_2.fit(x=data[0][data_split:], y=data[1][data_split:],
                    epochs=train_epochs, batch_size=batch_size, shuffle=True)
        predict_1 = model_1.predict(data[0])
        predict_2 = model_2.predict(data[0])

        this_change_set = set()
        # The label argmax is evaluated once here, before any relabelling
        # in this round.
        samples = zip(
            np.max(predict_1, 1),
            np.max(predict_2, 1),
            np.argmax(predict_1, 1),
            np.argmax(predict_2, 1),
            np.argmax(data[1], 1),
            data[2],
        )
        for _index, (_max_1, _max_2, _y1, _y2, Y, _url) in enumerate(samples):
            if (_y1[1] == _y2[1] and _y1[1] != Y[1]
                    and _max_1[1] > threshold and _max_2[1] > threshold):
                # Fix the annotation.
                # NOTE(review): this assigns the whole entry at that position
                # (not a single class slot) -- confirm against label layout.
                data[1][_index][Y[1]] = 0
                data[1][_index][_y1[1]] = 1
                this_change_set.add(_url)

        # Persist every round, including the one that ends the loop.
        save(data, output_file)

        if len(this_change_set - last_change_set) < 10:
            _not_change_times += 1
        else:
            _not_change_times = 0
        if _not_change_times >= max_not_change_times:
            break
        last_change_set = this_change_set


if __name__ == "__main__":
    # Other entry points in this script: train(), val(), iteratorLabel().
    test(url="https://www.600757.com.cn/show-106-14208-1.html")