'''
Created on Jan 15, 2019

@author: User
'''
import os
import sys
sys.path.append(os.path.abspath("../../"))
import tensorflow as tf
# crf_log_likelihood and viterbi_decode are used below; import them explicitly
# instead of relying on the wildcard import from bi_lstm_crf.
from tensorflow.contrib.crf import crf_log_likelihood, viterbi_decode
import json
import numpy as np
from BiddingKG.dl.foolnltk.label import readlabeldata, getContext
from sklearn.metrics import accuracy_score
from bi_lstm_crf import *
from zipfile import ZipFile
import math
import random
import codecs

OOV_STR = "<OOV>"

def _load_map_file(path, char_map_name, id_map_name):
    # Read the char map and tag map out of the JSON file stored in the zip archive.
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)
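
# Note on loss_layer below: the logits are extended with an extra "start" tag
# (id == num_tags) that gets a large negative score (-1000) at every real step,
# and a virtual start step is prepended to both logits and targets. This lets
# crf_log_likelihood model transitions out of a start state without touching
# the real tag set.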
def loss_layer(project_logits, y_target, trans, lengths, batch_size, num_steps, num_tags=25):
    with tf.variable_scope("crf_loss1"):
        small = -1000.0
        # Virtual start step: impossible scores for all real tags, 0 for the start tag.
        start_logits = tf.concat(
            [small * tf.ones(shape=[batch_size, 1, num_tags]), tf.zeros(shape=[batch_size, 1, 1])],
            axis=-1)

        # Give every real step an impossible score for the start tag, then
        # prepend the virtual start step.
        pad_logits = tf.cast(small * tf.ones([batch_size, num_steps, 1]), tf.float32)
        logits = tf.concat([project_logits, pad_logits], axis=-1)
        logits = tf.concat([start_logits, logits], axis=1)
        # Prepend the start tag (id == num_tags) to every target sequence.
        targets = tf.concat(
            [tf.cast(num_tags * tf.ones([batch_size, 1]), tf.int32), y_target], axis=-1)

        log_likelihood, trans = crf_log_likelihood(
            inputs=logits,
            tag_indices=targets,
            transition_params=trans,
            sequence_lengths=lengths + 1)

        return tf.reduce_mean(-log_likelihood)

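# getAcc mirrors the start-tag layout from loss_layer: each sequence gets an
# extra start row before Viterbi decoding, and accuracy is the token-level
# agreement between the flattened decoded paths and the gold tags.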
def getAcc(y_batch, logits, trans, lengths):
    small = -1000.0
    # One extra column/row for the start tag (25 real tags + 1 start tag).
    start = np.asarray([[small] * 25 + [0]])

    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        pad = small * np.ones([length, 1])
        logit = np.concatenate([score, pad], axis=1)
        logit = np.concatenate([start, logit], axis=0)
        path, _ = viterbi_decode(logit, trans)
        # Drop the virtual start step from the decoded path.
        preds += path[1:]
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc

def encode(list_sent_label):
    data = []
    char_to_id, id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
    tag_to_id = {v: int(k) for k, v in id_to_seg.items()}

    # Split the labeled characters into sentences at the Chinese full stop "。".
    list_sent = []
    list_label = []
    for sent_label in list_sent_label:
        _sent = []
        _label = []
        for item in sent_label:
            _sent.append(item[0])
            _label.append(item[1])
            if item[0] == "。":
                list_sent.append(_sent)
                list_label.append(_label)
                _sent = []
                _label = []
        # Flush the trailing fragment if the text does not end with "。".
        if sent_label[-1][0] != "。":
            list_sent.append(_sent)
            list_label.append(_label)

    # Map characters and tags to ids; unknown characters fall back to <OOV>.
    for _sent, _label in zip(list_sent, list_label):
        sent_ids = [char_to_id.get(w, char_to_id.get(OOV_STR)) for w in _sent]
        sent_tags = [tag_to_id.get(t) for t in _label]
        data.append([sent_ids, sent_tags, len(_sent)])

    return data
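
# Usage sketch (illustrative only; the tag names below are assumptions, the
# real tag set comes from "ner_map" inside data/map.zip):
#   sample = [[("招", "B-ORG"), ("标", "I-ORG"), ("人", "O"), ("。", "O")]]
#   data = encode(sample)  # -> [[char_ids, tag_ids, sentence_length], ...]
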
class NERBatchManager(object):

    def __init__(self, data, batch_size):
        self.MAX_LEN = 300  # unused here; pad_data enforces its own 1000-char cap
        self.char_to_id, self.id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
        self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
        self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)
        print("number of batches:", self.len_data)

    def sort_and_pad(self, data, batch_size):
        # Sort by sentence length so each batch needs minimal padding.
        num_batch = int(math.ceil(len(data) / batch_size))
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        batch_data = list()
        for i in range(num_batch):
            _item = self.pad_data(sorted_data[i * batch_size: (i + 1) * batch_size])
            if _item is not None:
                batch_data.append(_item)
        return batch_data

    def pad_data(self, data):
        batch_strings = []
        batch_targets = []
        batch_lengths = []
        max_length = max([len(sentence[0]) for sentence in data])

        # Drop batches that contain overly long sentences.
        if max_length > 1000:
            return None

        for line in data:
            text_ids, text_labels, text_len = line
            # Pad ids with 0 and labels with the "O" tag up to the batch maximum.
            padding = [0] * (max_length - len(text_ids))
            padding_l = [self.tag_to_id.get("O")] * (max_length - len(text_labels))

            batch_strings.append(text_ids + padding)
            batch_targets.append(text_labels + padding_l)
            batch_lengths.append(text_len)
        return [np.array(batch_strings), np.array(batch_targets), np.array(batch_lengths)]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]
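
# Usage sketch: batches come back as (char_ids, tag_ids, lengths) numpy arrays,
# sorted by length and padded per batch (this is how train() consumes them):
#   manager = NERBatchManager(data, batch_size=96)
#   for x_batch, y_batch, length_batch in manager.iter_batch(shuffle=True):
#       ...  # feed each batch to the model
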
def train():
    '''
    list_sent_label = labelEntity()
    data = encode(list_sent_label)
    '''
    # BertCRF is provided by the bi_lstm_crf module (wildcard import above).
    bilstm = BertCRF()

    list_context = getContext("ner_train.txt")
    '''
    data_train = readlabeldata("ner_train.txt", list_context)
    data_train_split = readlabeldata("ner_train_split.txt", list_context)
    data_label_split = readlabeldata("ner_label_split.txt", list_context)
    data_label = readlabeldata("ner_label.txt", list_context)
    data_entity_found = readlabeldata("company_found.txt", list_context, keep_prob=0.33)
    data_entity_cleaned = readlabeldata("cleanedEntity.txt", list_context)
    data = data_train + data_entity_found + data_train_split + data_entity_cleaned + data_label_split + data_label
    '''
    data = readlabeldata("selffool_train.txt", list_context)
    # data = readlabeldata("ner_label.txt", [])

    # Hold out the last test_len samples as a test set.
    test_len = 128
    ner_bm_train = NERBatchManager(data[:-test_len], batch_size=96)

    test_x = []
    test_y = []
    test_length = []

    # Pad the test set to its longest sentence under 300 characters;
    # longer sentences are skipped.
    max_length = 0
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if max_length < 300 and text_len > max_length and text_len < 300:
            max_length = text_len
    print("test max_length", max_length)
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if text_len > max_length:
            continue
        padding = [0] * (max_length - len(text_ids))
        padding_l = [ner_bm_train.tag_to_id.get("O")] * (max_length - len(text_labels))

        test_x.append(text_ids + padding)
        test_y.append(text_labels + padding_l)
        test_length.append(text_len)
    test_x = np.array(test_x)
    test_y = np.array(test_y)
    test_length = np.array(test_length)

    with bilstm.sess as sess:
        input_x, y_target, lengths, dropout, logits, trans, loss, train_op = bilstm.getNodes()
        bilstm.initVariables()

        # bilstm.restore()

        epochs = 100
        saver = tf.train.Saver(max_to_keep=epochs)
        with codecs.open("log.txt", "w", encoding="utf8") as f:
            test_index = 0
            for i in range(epochs):
                _batch = 0
                for x_batch, y_batch, length_batch in ner_bm_train.iter_batch(shuffle=True):
                    print(np.shape(x_batch))
                    _batch += 1
                    feed_dict = {
                        input_x: x_batch,
                        y_target: y_batch,
                        lengths: length_batch.reshape(-1, ),
                        dropout: 0.75
                    }
                    _, _loss, _logits, _trans = sess.run([train_op, loss, logits, trans], feed_dict=feed_dict)
                    acc = getAcc(y_batch, _logits, _trans, length_batch.reshape(-1, ))
                    print("epochs:", i, "batch:", _batch, _loss, acc)
                    f.write("epochs:" + str(i) + "/batch:" + str(_batch) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                # Evaluate on the held-out test set after every epoch.
                print("testing")
                feed_dict = {
                    input_x: test_x,
                    y_target: test_y,
                    lengths: test_length.reshape(-1, ),
                    dropout: 1.0
                }
                _loss, _logits, _trans = sess.run([loss, logits, trans], feed_dict=feed_dict)
                acc = getAcc(test_y, _logits, _trans, test_length.reshape(-1, ))
                print("---test", "epochs:", i, acc)
                f.write("---test" + "epochs:" + str(test_index) + "-" + str(i) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                print("============", "epochs:", str(i), "loss:", str(_loss), "done")
                # Save one checkpoint per epoch.
                saver.save(sess, 'model/' + str(test_index) + "-" + str(i) + '/model.ckpt')
                # saver.save(sess, "/model/model-epoch:" + str(i) + "-loss:" + str(_loss)[0:6] + "-acc:" + str(acc)[0:6] + ".ckpt", global_step=1)
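
# Inference sketch (assumptions: BertCRF exposes the same restore()/getNodes()
# interface used in train(); variable names here are illustrative):
#   bilstm = BertCRF()
#   bilstm.restore()  # load a checkpoint saved by train()
#   input_x, y_target, lengths, dropout, logits, trans, loss, train_op = bilstm.getNodes()
#   _logits, _trans = bilstm.sess.run(
#       [logits, trans],
#       feed_dict={input_x: x, lengths: seq_lens, dropout: 1.0})
#   # then decode each sequence with viterbi_decode, mirroring getAcc()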

if __name__ == "__main__":
    train()