- #coding:utf8
- import sys
- import os
- import glob
- sys.path.append(os.path.abspath("../.."))
- import psycopg2
- from keras import models
- from keras import layers
- from keras import optimizers,losses,metrics
- from keras.callbacks import ModelCheckpoint
- import codecs
- import copy
- from BiddingKG.dl.common.Utils import *
- import re
- import numpy as np
- import pandas as pd
- sourcetable = "label_guest_role"
- domain = sourcetable.split("_")[2]
- model_file = "model_"+domain+".model"
- input_shape = (2,10,128)
- output_shape = [6]
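- # input: two context windows (the tokens to the left and to the right of the entity mention),
- # each 10 tokens long with 128-dim token embeddings; output: a one-hot vector over 6 role labels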
- def getTokensLabels(t,isTrain=True,predict=False):
- '''
- @param:
- t: name of the table holding the labeled data
- isTrain: whether to build the training split (otherwise the holdout split)
- predict: whether the data is fetched for prediction/validation
- @return: the processed inputs and labels of the labeled data (plus the entity ids as context)
- '''
- conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
- cursor = conn.cursor()
-
- if predict:
- sql = '''
- select A.tokens,B.begin_index,B.end_index,0,B.entity_id from sentences A,entity_mention_copy B where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index
- and B.doc_id in (select doc_id from articles_validation ) order by B.doc_id
- '''
-
- else:
- select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,C.entity_id "
-
- '''
- if isTrain:
- train_sql = " and C.id not in(select variable_id from dd_graph_variables_holdout) "
- else:
- train_sql = " and C.id in(select variable_id from dd_graph_variables_holdout)"
-
- '''
- '''
- if isTrain:
- train_sql = " and A.doc_id not in(select id from articles_processed order by id limit 1000) "
- else:
- train_sql = " and A.doc_id in(select id from articles_processed order by id limit 1000)"
- '''
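- # training split: exclude entities whose labels sit in the dd_graph_variables_holdout set;
- # otherwise fetch up to 2000 rows from documents that also have predicted money entities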
- if isTrain:
- train_sql = " and C.entity_id not in(select entity_id from is_wintenderer_label_inference where id in(select variable_id from dd_graph_variables_holdout))"
- else:
- #train_sql = " and C.entity_id in(select entity_id from is_wintenderer_label_inference where id in(select variable_id from dd_graph_variables_holdout))"
- train_sql = " and exists(select 1 from test_predict_money h,entity_mention g where h.entity_id=g.entity_id and A.doc_id=g.doc_id) order by B.doc_id limit 2000 "
- sql = select_sql+" from sentences A,entity_mention_copy B,"+t+" C where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id "+train_sql
-
- print(sql)
- cursor.execute(sql)
-
- data_x = []
- data_y = []
- data_context = []
-
- rows = cursor.fetchmany(1000)
- allLimit = 330000
- all = 0
- while(rows):
- for row in rows:
- if all>=allLimit:
- break
- item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=input_shape[1]),shape=input_shape)
- item_y = np.zeros(output_shape)
- item_y[row[3]] = 1
- all += 1
-
- if not isTrain:
- item_context = []
- item_context.append(row[4])
- data_context.append(item_context)
- data_x.append(item_x)
- data_y.append(item_y)
- rows = cursor.fetchmany(1000)
- return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),data_context
-
- def getBiRNNModel():
- '''
- @summary: build the bidirectional LSTM model
- '''
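- # two inputs (left / right context window), each passed through a BiLSTM(16) and global
- # average pooling; the pooled vectors are concatenated and fed to a softmax over the 6 labels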
- L_input = layers.Input(shape=input_shape[1:],dtype="float32")
- #C_input = layers.Input(shape=(10,128),dtype="float32")
- R_input = layers.Input(shape=input_shape[1:],dtype="float32")
- #lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(ThreeBilstm(0)(input))
- lstm_0 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(L_input)
- avg_0 = layers.GlobalAveragePooling1D()(lstm_0)
- #lstm_1 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(C_input)
- #avg_1 = layers.GlobalAveragePooling1D()(lstm_1)
- lstm_2 = layers.Bidirectional(layers.LSTM(16,return_sequences=True))(R_input)
- avg_2 = layers.GlobalAveragePooling1D()(lstm_2)
- #concat = layers.merge([avg_0,avg_1,avg_2],mode="concat")
- concat = layers.concatenate([avg_0,avg_2])
-
- output = layers.Dense(output_shape[0],activation="softmax")(concat)
-
- model = models.Model(inputs=[L_input,R_input],outputs=output)
- model.compile(optimizer=optimizers.Adam(lr=0.001),loss=losses.categorical_crossentropy,metrics=[precision,recall,f1_score])
- return model
- def loadTrainData(percent=0.9):
- files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1"]
- data_x = []
- data_y = []
- #data_id = []
- test_x = []
- test_y = []
- test_id = []
- for file in files:
- data = load(file)
-
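- # each pickled row is assumed to be (entity_id, tokens, text, begin_index, end_index, label),
- # matching the file name id_token_text_begin_end_label.pk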
- for row in data:
- item_x = embedding(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=input_shape[1]),shape=input_shape)
- item_y = np.zeros(output_shape)
- label = int(row[5])
- if label not in [0,1,2,3,4,5]:
- continue
- item_y[label] = 1
- if np.random.random()<percent:
- data_x.append(item_x)
- data_y.append(item_y)
- #data_id.append(row[0])
- else:
- test_x.append(item_x)
- test_y.append(item_y)
- test_id.append(row[0])
- return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),np.transpose(np.array(test_x),(1,0,2,3)),np.array(test_y),None,test_id
-
- def training():
- '''
- @summary: train the model
- '''
- model = getBiRNNModel()
- model.summary()
- #train_x,train_y,_ = getTokensLabels(isTrain=True,t="hand_label_role")
- train_x,train_y,test_x,test_y,_,test_id = loadTrainData()
- save([test_x,test_y,test_id],"val_data.pk")
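- # the held-out split is saved so that validation() can reload it from val_data.pk later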
- checkpoint = ModelCheckpoint(
- "../../dl_dev/role/log/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.3f}.h5", monitor="val_loss", verbose=1, save_best_only=True, mode='min')
- print(np.shape(train_x))
- history_model = model.fit(x=[train_x[0],train_x[1]],y=train_y,validation_data=([test_x[0],test_x[1]],test_y),epochs=120,batch_size=512,shuffle=True,callbacks=[checkpoint])
- predict_y = model.predict([test_x[0],test_x[1]])
- model.save(model_file)
- #print_metrics(history_model)
-
- def val():
- files = []
- for file in glob.glob("C:\\Users\\User\\Desktop\\20190416要素\\*.html"):
- filename = file.split("\\")[-1]
- files.append(filename)
-
- conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
- cursor = conn.cursor()
-
- sql = '''
- select A.entity_id,A.entity_text,A.begin_index,A.end_index,A.label,A.values,B.tokens,A.doc_id
- from entity_mention A,sentences B
- where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index
- and A.entity_type in ('org','company')
- and A.label!='None'
- and not exists(select 1 from turn_label where entity_id=A.entity_id)
- order by A.label
- '''
-
- cursor.execute(sql)
- rows = cursor.fetchall()
-
- list_entity_id = []
- list_before = []
- list_after = []
- list_text = []
- list_label = []
- list_prob = []
- repeat = set()
- data_x = []
- cnn_x = []
- for row in rows:
- entity_id = row[0]
- entity_text = row[1]
- begin_index = row[2]
- end_index = row[3]
- label = int(row[4])
- values = row[5][1:-1].split(",")
- tokens = row[6]
- doc_id = row[7]
-
- if doc_id not in files:
- continue
-
- if float(values[label])<0.5:
- continue
-
- beforeafter = spanWindow(tokens, begin_index, end_index, 10,center_include=True,text=entity_text)
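- # spanWindow with center_include=True returns [before_tokens, entity_tokens, after_tokens];
- # the joined strings are used below to skip entities with identical contexts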
-
- if ("".join(beforeafter[0]),"".join(beforeafter[1]),"".join(beforeafter[2])) in repeat:
- continue
-
- repeat.add(("".join(beforeafter[0]),"".join(beforeafter[1]),"".join(beforeafter[2])))
-
- item_x = embedding(spanWindow(tokens=tokens,begin_index=begin_index,end_index=end_index,size=input_shape[1]),shape=input_shape)
- data_x.append(item_x)
-
- cnn_x.append(encodeInput(spanWindow(tokens=tokens,begin_index=begin_index,end_index=end_index,size=10,center_include=True,word_flag=True,text=entity_text), word_len=50, word_flag=True))
- list_entity_id.append(entity_id)
- list_before.append("".join(beforeafter[0]))
- list_after.append("".join(beforeafter[2]))
- list_text.append("".join(beforeafter[1]))
- list_label.append(label)
- list_prob.append(values[label])
- model = models.load_model("../../dl_dev/role/log/new_biLSTM-ep012-loss0.028-val_loss0.040-f10.954.h5", custom_objects={"precision":precision, "recall":recall, "f1_score":f1_score})
- data_x = np.transpose(np.array(data_x),(1,0,2,3))
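- # transpose from (sample, window, token, dim) to (window, sample, token, dim) so that
- # data_x[0] / data_x[1] are the left / right inputs expected by the model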
- predict_value = model.predict([data_x[0],data_x[1]])
- predict_y = np.argmax(predict_value,1)
- list_newprob = []
- for label,value in zip(predict_y,predict_value):
- list_newprob.append(value[label])
- print("len",len(list_entity_id))
-
-
- model_cnn = models.load_model("../../dl_dev/role/log/ep071-loss0.107-val_loss0.122-f10.956.h5", custom_objects={"precision":precision, "recall":recall, "f1_score":f1_score})
- cnn_x = np.transpose(np.array(cnn_x),(1,0,2))
- predict_value = model_cnn.predict([cnn_x[0],cnn_x[1],cnn_x[2]])
- predict_y_cnn = np.argmax(predict_value,1)
- list_newprob_cnn = []
- for label,value in zip(predict_y_cnn,predict_value):
- list_newprob_cnn.append(value[label])
- print("len",len(list_entity_id))
-
- data = []
- for id,before,text,after,label_bi,prob,label_bi1,newprob,label_cnn,newprob_cnn in zip(list_entity_id,list_before,list_text,list_after,list_label,list_prob,predict_y,list_newprob,predict_y_cnn,list_newprob_cnn):
- if label_bi1!=label_cnn:
- data.append([id,before,text,after,label_bi,prob,label_bi1,newprob,label_cnn,newprob_cnn])
- data.sort(key=lambda x:x[6])
- list_entity_id = []
- list_before = []
- list_after = []
- list_text = []
- list_label = []
- list_prob = []
- list_newlabel = []
- list_newprob = []
- list_newlabel_cnn = []
- list_newprob_cnn = []
- for item in data:
- list_entity_id.append(item[0])
- list_before.append(item[1])
- list_text.append(item[2])
- list_after.append(item[3])
- list_label.append(item[4])
- list_prob.append(item[5])
- list_newlabel.append(item[6])
- list_newprob.append(item[7])
- list_newlabel_cnn.append(item[8])
- list_newprob_cnn.append(item[9])
-
-
- parts = 1
- parts_num = len(list_entity_id)//parts
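- # with parts=1 the loop below never runs; all rows are written by the final block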
- for i in range(parts-1):
-
- data = {"entity_id":list_entity_id[i*parts_num:(i+1)*parts_num],"list_before":list_before[i*parts_num:(i+1)*parts_num],"list_after":list_after[i*parts_num:(i+1)*parts_num],"list_text":list_text[i*parts_num:(i+1)*parts_num],"list_label":list_label[i*parts_num:(i+1)*parts_num],"list_prob":list_prob[i*parts_num:(i+1)*parts_num]}
- df = pd.DataFrame(data)
- df.to_excel("未标注错误_"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label","list_prob"])
- i = parts - 1
- data = {"entity_id":list_entity_id[i*parts_num:],"list_before":list_before[i*parts_num:],"list_after":list_after[i*parts_num:],"list_text":list_text[i*parts_num:],"list_label":list_label[i*parts_num:],"list_prob":list_prob[i*parts_num:],"list_newlabel":list_newlabel[i*parts_num:],"list_newprob":list_newprob[i*parts_num:],"list_newlabel_cnn":list_newlabel_cnn[i*parts_num:],"list_newprob_cnn":list_newprob_cnn[i*parts_num:]}
- df = pd.DataFrame(data)
- df.to_excel("测试数据_role-cnnw-biw"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label","list_prob","list_newlabel","list_newprob","list_newlabel_cnn","list_newprob_cnn"])
-
-
- def validation():
- conn1 = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
- conn2 = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
- cursor1 = conn1.cursor()
- cursor2 = conn2.cursor()
- model = getBiRNNModel()
- model.load_weights("log/ep010-loss0.033-val_loss0.043-f10.950.h5")
- [test_x,test_y,test_id] = load("val_data.pk")
- predict_y = model.predict([test_x[0],test_x[1]])
- list_id = []
- list_before = []
- list_text = []
- list_after = []
- list_same = []
- list_predict = []
- list_label = []
- data = []
- for id,predict,label in zip(test_id,np.argmax(predict_y,1),np.argmax(test_y,1)):
- if predict==label:
- same = 0
- text = ""
- beforeafter = [[],[]]
- else:
- same = 1
- if re.search("比地",id) is not None:
- sql = " select A.tokens,B.entity_text,B.begin_index,B.end_index from sentences A,entity_mention B where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+id+"' "
- cursor1.execute(sql)
- rows = cursor1.fetchall()
- else:
- sql = " select A.tokens,B.entity_text,B.begin_index,B.end_index from sentences A,entity_mention_copy B where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+id+"' "
- cursor2.execute(sql)
- rows = cursor2.fetchall()
- retu = rows[0]
- text = retu[1]
- beforeafter = spanWindow(retu[0], retu[2], retu[3], 10)
-
- data.append([id,same,"".join(beforeafter[0]),text,"".join(beforeafter[1]),label,predict])
- data.sort(key=lambda x:x[1])
- for item in data:
- list_id.append(item[0])
- list_same.append(item[1])
- list_before.append(item[2])
- list_text.append(item[3])
- list_after.append(item[4])
- list_label.append(item[5])
- list_predict.append(item[6])
- df = pd.DataFrame({"list_id":list_id,"list_same":list_same,"list_before":list_before,"list_text":list_text,"list_after":list_after,"list_label":list_label,"list_predict":list_predict})
- columns = ["list_id","list_same","list_before","list_text","list_after","list_label","list_predict"]
- df.to_excel("result.xls",index=False,columns=columns)
- conn1.close()
- conn2.close()
-
-
- def trainingIteration_category(iterate=2,label_table=sourcetable):
- '''
- @summary: iteratively train the model and revise the labels; intended for cases where the label accuracy is low
- @param:
- iterate: number of iterations
- label_table: table holding the label data
-
- '''
- def getDatasets():
- conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
- cursor = conn.cursor()
-
- select_sql = " select A.tokens,B.begin_index,B.end_index,C.label,C.entity_id "
-
- sql = select_sql+" from sentences A,entity_mention B,"+label_table+" C where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id order by A.doc_id "
- cursor.execute(sql)
-
- print(sql)
-
- data_x = []
- data_y = []
- id_set = []
- rows = cursor.fetchmany(1000)
- allLimit = 320000
- all = 0
- while(rows):
- for row in rows:
- if all>=allLimit:
- break
- item_x = embedding(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2]))
- item_y = np.zeros(output_shape)
- item_y[row[3]] = 1
- all += 1
-
- data_x.append(item_x)
- data_y.append(item_y)
- id_set.append(row[4])
- rows = cursor.fetchmany(1000)
- return np.transpose(np.array(data_x),(1,0,2,3)),np.array(data_y),id_set
- train_x,train_y,id_set = getDatasets()
- alllength = len(train_x[0])
- parts = 6
- num_parts = alllength//parts
- copy_y = copy.copy(train_y)
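- # cross-validation-style relabeling: for each fold, train on the remaining folds and, where the
- # model predicts a class with probability >= 0.8, overwrite that fold's labels in copy_y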
- for ite in range(iterate):
-
- for j in range(parts-1):
- print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
- model = getBiRNNModel()
- model.summary()
- test_begin = j*num_parts
- test_end = (j+1)*num_parts
- checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode='min')
- history_model = model.fit(x=[np.concatenate((train_x[0][0:test_begin],train_x[0][test_end:])),np.concatenate((train_x[1][0:test_begin],train_x[1][test_end:]))],y=np.concatenate((copy_y[0:test_begin],copy_y[test_end:])),validation_data=([train_x[0][test_begin:test_end],train_x[1][test_begin:test_end]],copy_y[test_begin:test_end]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
- model.load_weights(model_file+".hdf5")
- predict_y = model.predict([train_x[0][test_begin:test_end],train_x[1][test_begin:test_end]])
- for i in range(len(predict_y)):
- if np.max(predict_y[i])>=0.8:
- max_index = np.argmax(predict_y[i])
- for h in range(len(predict_y[i])):
- if h==max_index:
- copy_y[i+test_begin][h] = 1
- else:
- copy_y[i+test_begin][h] = 0
- print("iterate:",str(ite)+"/"+str(iterate-1),str(j)+"/"+str(parts-1))
- model = getBiRNNModel()
- model.summary()
- test_begin = j*num_parts
- checkpoint = ModelCheckpoint(model_file+".hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode="min")
- history_model = model.fit(x=[train_x[0][0:test_begin],train_x[1][0:test_begin]],y=copy_y[0:test_begin],validation_data=([train_x[0][test_begin:],train_x[1][test_begin:]],copy_y[test_begin:]),epochs=30,batch_size=300,shuffle=True,callbacks=[checkpoint])
- model.load_weights(model_file+".hdf5")
- predict_y = model.predict([train_x[0][test_begin:],train_x[1][test_begin:]])
- for i in range(len(predict_y)):
- if np.max(predict_y[i])>=0.8:
- max_index = np.argmax(predict_y[i])
- for h in range(len(predict_y[i])):
- if h==max_index:
- copy_y[i+test_begin][h] = 1
- else:
- copy_y[i+test_begin][h] = 0
-
- with codecs.open("final_label_"+domain+".txt","w",encoding="utf8") as f:
- for i in range(len(id_set)):
- f.write(id_set[i])
- f.write("\t")
- f.write(str(np.argmax(copy_y[i])))
- f.write("\n")
- f.flush()
- f.close()
- def predict():
- '''
- @summary: run prediction on the test data
- '''
- test_x,_,ids = getTokensLabels("final_label_role", isTrain=False,predict=True)
- model = models.load_model(model_file,custom_objects={'precision':precision,'recall':recall,'f1_score':f1_score})
- predict_y = model.predict([test_x[0],test_x[1]])
- with codecs.open("test_predict_"+domain+".txt","w",encoding="utf8") as f:
- for i in range(len(predict_y)):
- f.write(ids[i][0])
- f.write("\t")
- f.write(str(np.argmax(predict_y[i])))
- f.write("\t")
- value = ""
- for item in predict_y[i]:
- value += str(item)+","
- f.write(value[:-1])
- f.write("\n")
- f.flush()
- f.close()
- def importIterateLabel():
- '''
- @summary: import the label values produced by the iteration
- '''
-
- file = "final_label_"+domain+".txt"
-
- conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
-
- cursor = conn.cursor()
- tablename = file.split(".")[0]
- # create the target table if it does not exist, otherwise clear it
- cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
- flag = cursor.fetchall()[0][0]
- if flag:
- cursor.execute(" create table "+tablename+"(entity_id text,label int)")
- else:
- cursor.execute(" delete from "+tablename)
-
-
-
- with codecs.open(file,"r",encoding="utf8") as f:
- while(True):
- line = f.readline()
- if not line:
- break
- line_split = line.split("\t")
- entity_id=line_split[0]
- label = line_split[1]
- sql = " insert into "+tablename+"(entity_id,label) values('"+str(entity_id)+"',"+str(label)+")"
- cursor.execute(sql)
- f.close()
- conn.commit()
- conn.close()
-
- def importtestPredict():
- '''
- @summary: import the predicted values for the test data
-
- '''
- file = "test_predict_"+domain+".txt"
- conn = psycopg2.connect(dbname="BiddingKG",user="postgres",password="postgres",host="192.168.2.101")
-
- cursor = conn.cursor()
-
- tablename = file.split(".")[0]
- # create the target table if it does not exist, otherwise clear it
- cursor.execute(" SELECT to_regclass('"+tablename+"') is null ")
-
- flag = cursor.fetchall()[0][0]
- if flag:
- cursor.execute(" create table "+tablename+"(entity_id text,label int,value text)")
- else:
- cursor.execute(" delete from "+tablename)
-
- with codecs.open(file,"r",encoding="utf8") as f:
- while(True):
- line = f.readline()
- if not line:
- break
- line_split = line.split("\t")
- entity_id=line_split[0]
- predict = line_split[1]
- value = line_split[2]
- sql = " insert into "+tablename+"(entity_id,label,value) values('"+str(entity_id)+"',"+str(predict)+",'"+str(value)+"')"
- cursor.execute(sql)
- f.close()
- conn.commit()
- conn.close()
-
-
- def autoIterate():
- #trainingIteration_binary()
- trainingIteration_category()
- importIterateLabel()
- training()
- predict()
-
- def test1(entity_id):
- conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
- cursor = conn.cursor()
-
- sql = "select A.tokens,B.begin_index,B.end_index,0,B.entity_id from sentences A,entity_mention B where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id='"+entity_id+"'"
- print(sql)
- cursor.execute(sql)
-
- data_x = []
- data_y = []
-
- rows = cursor.fetchmany(1000)
- while(rows):
- for row in rows:
- item_x = encodeInput(spanWindow(tokens=row[0],begin_index=row[1],end_index=row[2],size=10,center_include=True,word_flag=True), word_len=50, word_flag=True)
- item_y = np.zeros(output_shape)
- item_y[row[3]] = 1
-
- data_x.append(item_x)
- data_y.append(item_y)
- rows = cursor.fetchmany(1000)
- model = models.load_model("../../dl_dev/role/log/ep017-loss0.088-val_loss0.125-f10.955.h5", custom_objects={'precision':precision, 'recall':recall, 'f1_score':f1_score})
- test_x = np.transpose(np.array(data_x),(1,0,2))
- predict_y = model.predict([test_x[0],test_x[1],test_x[2]])
- print(predict_y)
-
- if __name__=="__main__":
- #training()
- val()
- #validation()
- #test()
- #trainingIteration_category()
- #importIterateLabel()
- #predict()
- #importtestPredict()
- #autoIterate()
- #test1("比地_101_61333318.html_0_116_122")