|
@@ -27,6 +27,13 @@ import calendar
|
|
|
import datetime
|
|
|
# import fool # 统一用 selffool ,阿里云上只有selffool 包
|
|
|
|
|
|
+cpu_num = int(os.environ.get("CPU_NUM",0))
|
|
|
+sess_config = tf.ConfigProto(
|
|
|
+ inter_op_parallelism_threads = cpu_num,
|
|
|
+ intra_op_parallelism_threads = cpu_num,
|
|
|
+ log_device_placement=True)
|
|
|
+sess_config = None
|
|
|
+
|
|
|
from threading import RLock
|
|
|
dict_predictor = {"codeName":{"predictor":None,"Lock":RLock()},
|
|
|
"prem":{"predictor":None,"Lock":RLock()},
|
|
@@ -51,11 +58,11 @@ def getPredictor(_type):
|
|
|
with dict_predictor[_type]["Lock"]:
|
|
|
if dict_predictor[_type]["predictor"] is None:
|
|
|
if _type == "codeName":
|
|
|
- dict_predictor[_type]["predictor"] = CodeNamePredict()
|
|
|
+ dict_predictor[_type]["predictor"] = CodeNamePredict(config=sess_config)
|
|
|
if _type == "prem":
|
|
|
- dict_predictor[_type]["predictor"] = PREMPredict()
|
|
|
+ dict_predictor[_type]["predictor"] = PREMPredict(config=sess_config)
|
|
|
if _type == "epc":
|
|
|
- dict_predictor[_type]["predictor"] = EPCPredict()
|
|
|
+ dict_predictor[_type]["predictor"] = EPCPredict(config=sess_config)
|
|
|
if _type == "roleRule":
|
|
|
dict_predictor[_type]["predictor"] = RoleRulePredictor()
|
|
|
if _type == "roleRuleFinal":
|
|
@@ -63,17 +70,17 @@ def getPredictor(_type):
|
|
|
if _type == "tendereeRuleRecall":
|
|
|
dict_predictor[_type]["predictor"] = TendereeRuleRecall()
|
|
|
if _type == "form":
|
|
|
- dict_predictor[_type]["predictor"] = FormPredictor()
|
|
|
+ dict_predictor[_type]["predictor"] = FormPredictor(config=sess_config)
|
|
|
if _type == "time":
|
|
|
- dict_predictor[_type]["predictor"] = TimePredictor()
|
|
|
+ dict_predictor[_type]["predictor"] = TimePredictor(config=sess_config)
|
|
|
if _type == "punish":
|
|
|
dict_predictor[_type]["predictor"] = Punish_Extract()
|
|
|
if _type == "product":
|
|
|
- dict_predictor[_type]["predictor"] = ProductPredictor()
|
|
|
+ dict_predictor[_type]["predictor"] = ProductPredictor(config=sess_config)
|
|
|
if _type == "product_attrs":
|
|
|
dict_predictor[_type]["predictor"] = ProductAttributesPredictor()
|
|
|
if _type == "channel":
|
|
|
- dict_predictor[_type]["predictor"] = DocChannel()
|
|
|
+ dict_predictor[_type]["predictor"] = DocChannel(config=sess_config)
|
|
|
if _type == 'deposit_payment_way':
|
|
|
dict_predictor[_type]["predictor"] = DepositPaymentWay()
|
|
|
if _type == 'total_unit_money':
|
|
@@ -87,7 +94,7 @@ def getPredictor(_type):
|
|
|
# 编号名称模型
|
|
|
class CodeNamePredict():
|
|
|
|
|
|
- def __init__(self,EMBED_DIM=None,BiRNN_UNITS=None,lazyLoad=getLazyLoad()):
|
|
|
+ def __init__(self,EMBED_DIM=None,BiRNN_UNITS=None,lazyLoad=getLazyLoad(),config=None):
|
|
|
|
|
|
self.model = None
|
|
|
self.MAX_LEN = None
|
|
@@ -123,8 +130,8 @@ class CodeNamePredict():
|
|
|
|
|
|
self.inputs = None
|
|
|
self.outputs = None
|
|
|
- self.sess_codename = tf.Session(graph=tf.Graph())
|
|
|
- self.sess_codesplit = tf.Session(graph=tf.Graph())
|
|
|
+ self.sess_codename = tf.Session(graph=tf.Graph(),config=config)
|
|
|
+ self.sess_codesplit = tf.Session(graph=tf.Graph(),config=config)
|
|
|
self.inputs_code = None
|
|
|
self.outputs_code = None
|
|
|
if not lazyLoad:
|
|
@@ -535,11 +542,11 @@ class CodeNamePredict():
|
|
|
class PREMPredict():
|
|
|
|
|
|
|
|
|
- def __init__(self):
|
|
|
+ def __init__(self,config=None):
|
|
|
#self.model_role_file = os.path.abspath("../role/models/model_role.model.hdf5")
|
|
|
self.model_role_file = os.path.dirname(__file__)+"/../role/log/new_biLSTM-ep012-loss0.028-val_loss0.040-f10.954.h5"
|
|
|
- self.model_role = Model_role_classify_word()
|
|
|
- self.model_money = Model_money_classify()
|
|
|
+ self.model_role = Model_role_classify_word(config=config)
|
|
|
+ self.model_money = Model_money_classify(config=config)
|
|
|
|
|
|
return
|
|
|
|
|
@@ -734,8 +741,8 @@ class PREMPredict():
|
|
|
#联系人模型
|
|
|
class EPCPredict():
|
|
|
|
|
|
- def __init__(self):
|
|
|
- self.model_person = Model_person_classify()
|
|
|
+ def __init__(self,config=None):
|
|
|
+ self.model_person = Model_person_classify(config=config)
|
|
|
|
|
|
|
|
|
|
|
@@ -1074,13 +1081,13 @@ class EPCPredict():
|
|
|
#表格预测
|
|
|
class FormPredictor():
|
|
|
|
|
|
- def __init__(self,lazyLoad=getLazyLoad()):
|
|
|
+ def __init__(self,lazyLoad=getLazyLoad(),config=None):
|
|
|
self.model_file_line = os.path.dirname(__file__)+"/../form/model/model_form.model_line.hdf5"
|
|
|
self.model_file_item = os.path.dirname(__file__)+"/../form/model/model_form.model_item.hdf5"
|
|
|
- self.model_form_item = Model_form_item()
|
|
|
- self.model_form_context = Model_form_context()
|
|
|
+ self.model_form_item = Model_form_item(config=config)
|
|
|
self.model_dict = {"line":[None,self.model_file_line]}
|
|
|
-
|
|
|
+ self.model_form_context = Model_form_context(config=config)
|
|
|
+
|
|
|
|
|
|
def getModel(self,type):
|
|
|
if type=="item":
|
|
@@ -1690,8 +1697,8 @@ class TendereeRuleRecall():
|
|
|
|
|
|
# 时间类别
|
|
|
class TimePredictor():
|
|
|
- def __init__(self):
|
|
|
- self.sess = tf.Session(graph=tf.Graph())
|
|
|
+ def __init__(self,config=None):
|
|
|
+ self.sess = tf.Session(graph=tf.Graph(),config=config)
|
|
|
self.inputs_code = None
|
|
|
self.outputs_code = None
|
|
|
self.input_shape = (2,40,128)
|
|
@@ -1795,11 +1802,11 @@ class TimePredictor():
|
|
|
|
|
|
# 产品字段提取
|
|
|
class ProductPredictor():
|
|
|
- def __init__(self):
|
|
|
+ def __init__(self,config=None):
|
|
|
vocabpath = os.path.dirname(__file__) + "/codename_vocab.pk"
|
|
|
self.vocab = load(vocabpath)
|
|
|
self.word2index = dict((w, i) for i, w in enumerate(np.array(self.vocab)))
|
|
|
- self.sess = tf.Session(graph=tf.Graph())
|
|
|
+ self.sess = tf.Session(graph=tf.Graph(),config=config)
|
|
|
self.load_model()
|
|
|
|
|
|
def load_model(self):
|
|
@@ -2515,9 +2522,9 @@ class ProductAttributesPredictor():
|
|
|
|
|
|
# docchannel类型提取
|
|
|
class DocChannel():
|
|
|
- def __init__(self, life_model='/channel_savedmodel/channel.pb', type_model='/channel_savedmodel/doctype.pb'):
|
|
|
+ def __init__(self, life_model='/channel_savedmodel/channel.pb', type_model='/channel_savedmodel/doctype.pb',config=None):
|
|
|
self.lift_sess, self.lift_title, self.lift_content, self.lift_prob, self.lift_softmax,\
|
|
|
- self.mask, self.mask_title = self.load_life(life_model)
|
|
|
+ self.mask, self.mask_title = self.load_life(life_model,config)
|
|
|
self.type_sess, self.type_title, self.type_content, self.type_prob, self.type_softmax,\
|
|
|
self.type_mask, self.type_mask_title = self.load_type(type_model)
|
|
|
self.sequen_len = 200 # 150 200
|
|
@@ -2578,7 +2585,7 @@ class DocChannel():
|
|
|
'招标公告': '(采购|招标|询价|议价|竞价|比价|比选|遴选|邀请|邀标|磋商|洽谈|约谈|谈判|拍卖|招租|交易|出让)的?(公告|公示|$)|公开(采购|招标|招租|拍卖|挂牌|出让)|(资审|预审|后审)公告',
|
|
|
}
|
|
|
|
|
|
- def load_life(self,life_model):
|
|
|
+ def load_life(self,life_model,config):
|
|
|
with tf.Graph().as_default() as graph:
|
|
|
output_graph_def = graph.as_graph_def()
|
|
|
with open(os.path.dirname(__file__)+life_model, 'rb') as f:
|
|
@@ -2586,7 +2593,7 @@ class DocChannel():
|
|
|
tf.import_graph_def(output_graph_def, name='')
|
|
|
# print("%d ops in the final graph" % len(output_graph_def.node))
|
|
|
del output_graph_def
|
|
|
- sess = tf.Session(graph=graph)
|
|
|
+ sess = tf.Session(graph=graph,config=config)
|
|
|
sess.run(tf.global_variables_initializer())
|
|
|
inputs = sess.graph.get_tensor_by_name('inputs/inputs:0')
|
|
|
prob = sess.graph.get_tensor_by_name('inputs/dropout:0')
|