
Merge remote-tracking branch 'origin/master'

fangjiasheng 3 years ago
parent
commit
dd8859b4e6

+ 1 - 1
.idea/misc.xml

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.5 (py3.5)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.5.0 (sftp://yons@192.168.2.103:22/data/home/python/anaconda3/envs/dl_nlp/bin/python)" project-jdk-type="Python SDK" />
   <component name="PythonCompatibilityInspectionAdvertiser">
     <option name="version" value="3" />
   </component>

+ 1 - 1
BiddingKG.iml

@@ -7,7 +7,7 @@
   </component>
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Python 3.5 (py3.5)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.5.0 (sftp://yons@192.168.2.103:22/data/home/python/anaconda3/envs/dl_nlp/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
     <orderEntry type="library" exported="" name="Python 3.5 (dl_nlp) interpreter library" level="application" />
   </component>

+ 16 - 5
BiddingKG/dl/interface/Preprocessing.py

@@ -1036,7 +1036,7 @@ def segment(soup,final=True):
             text = soup.get_text()
             # 2020/11/24 大网站规则添加
             if 'title' in soup.attrs:
-                if '...' in soup.get_text() and (soup.get_text()[:-3]).strip() in soup.attrs['title']:
+                if '...' in soup.get_text() and soup.get_text().strip()[:-3] in soup.attrs['title']:
                     text = soup.attrs['title']
 
             _list = []
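Editor's note: a hedged illustration of why the operand order changed in the title check above. When the visible anchor text ends with "..." plus trailing whitespace, stripping before slicing removes exactly the three dots, while slicing first can leave a stray dot behind. The strings below are made up.

    text = "某某项目招标公... "          # visible text, truncated with "..."
    title = "某某项目招标公告"           # full title attribute
    old = (text[:-3]).strip()            # -> '某某项目招标公.' (trailing space ate into the dots)
    new = text.strip()[:-3]              # -> '某某项目招标公'
    print(old in title, new in title)    # False True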
@@ -1086,8 +1086,8 @@ def segment(soup,final=True):
     text = re.sub("(?<=[\u4e00-\u9fa5]);|;(?=[\u4e00-\u9fa5])",";",text)
     # 感叹号替换为中文句号
     text = re.sub("(?<=[\u4e00-\u9fa5])[!!]|[!!](?=[\u4e00-\u9fa5])","。",text)
-    #替换"?"为 " " ,update:2021/7/20
-    text = re.sub("?{1,}"," ",text)
+    #替换格式未识别的问号为" " ,update:2021/7/20
+    text = re.sub("[?\?]{2,}"," ",text)
 
 
     #替换"""为"“",否则导入deepdive出错
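Editor's note: a hedged illustration of the tightened rule above. The old pattern blanked every full-width question mark, including meaningful ones; the new [?\?]{2,} only removes runs of two or more (full- or half-width), which typically indicate mis-decoded characters.

    import re

    text = "路面工程???二标段,完工时间是多久?"
    # runs of >=2 question marks are treated as encoding noise; a lone ? survives
    print(re.sub("[?\?]{2,}", " ", text))  # -> 路面工程 二标段,完工时间是多久?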
@@ -1678,7 +1678,7 @@ def get_preprocessed_entitys(list_sentences,useselffool=True,cost_time=dict()):
 
             '''正则识别角色实体  经营部|经销部|电脑部|服务部|复印部|印刷部|彩印部|装饰部|修理部|汽修部|修理店|零售店|设计店|服务店|家具店|专卖店|分店|文具行|商行|印刷厂|修理厂|维修中心|修配中心|养护中心|服务中心|会馆|文化馆|超市|门市|商场|家具城|印刷社|经销处'''
             for it in re.finditer(
-                    '(?P<text_key_word>[^,。、;《]{,5}(单一来源|中标|中选|中价|成交)?(供应商|供货商|服务商|候选人|单位|人)(名称)?为?[::]+)(?P<text>([^,。、;《]{5,20})(厂|中心|超市|门市|商场|工作室|文印室|城|部|店|站|馆|行|社|处))[,。]',
+                    '(?P<text_key_word>[^,。、;《]{,5}(单一来源|中标|中选|中价|成交)?(供应商|供货商|服务商|候选人|单位|人)(名称)?[为::]+)(?P<text>([^,。、;《]{5,20})(厂|中心|超市|门市|商场|工作室|文印室|城|部|店|站|馆|行|社|处))[,。]',
                     sentence_text):
                 for k, v in it.groupdict().items():
                     if k == 'text_key_word':
@@ -1689,8 +1689,19 @@ def get_preprocessed_entitys(list_sentences,useselffool=True,cost_time=dict()):
                 e = it.end() - 1
                 if (b, e, 'org', entity) not in ner_entitys and (b, e, 'company', entity) not in ner_entitys:
                     ner_entitys.append((b, e, 'company', entity))
-                    # print('正则新增 :',(b, e, 'company', entity))
 
+            for it in re.finditer(
+                    '(?P<text_key_word>((建设|招租|招标|采购)(单位|人)|业主)(名称)?[为::]+)(?P<text>\w{2,4}[省市县区镇]([^,。、;《]{2,20})(管理处|办公室|委员会|村委会|纪念馆|监狱|管教所|修养所|社区|农场|林场|羊场|猪场|石场|村|幼儿园))[,。]',
+                    sentence_text):
+                for k, v in it.groupdict().items():
+                    if k == 'text_key_word':
+                        keyword = v
+                    if k == 'text':
+                        entity = v
+                b = it.start() + len(keyword)
+                e = it.end() - 1
+                if (b, e, 'org', entity) not in ner_entitys and (b, e, 'company', entity) not in ner_entitys:
+                    ner_entitys.append((b, e, 'org', entity))
 
             #识别package
 
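Editor's note: the block added above follows the same finditer/groupdict pattern as the earlier company rule, computing entity offsets from the match minus the keyword prefix. A cut-down, illustrative version (the production pattern and sentence differ):

    import re

    # simplified stand-in for the new org rule: keyword prefix + place-prefixed
    # name ending in an org suffix, terminated by a comma or full stop
    pattern = ('(?P<text_key_word>(建设|招标|采购)(单位|人)(名称)?[为::]+)'
               '(?P<text>\\w{2,4}[省市县区镇][^,。、;《]{2,20}(委员会|管理处|村委会))[,。]')
    sentence = "建设单位为:某省某市某街道绿化管理处,联系人王某。"
    for it in re.finditer(pattern, sentence):
        b = it.start() + len(it.group('text_key_word'))  # entity start offset
        e = it.end() - 1                                 # drop trailing punctuation
        print((b, e, 'org', it.group('text')))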

+ 18 - 14
BiddingKG/dl/interface/extract.py

@@ -4,6 +4,7 @@ Created on 2019年1月4日
 @author: User
 '''
 import os
+
 from bs4 import BeautifulSoup, Comment
 import copy
 import re
@@ -41,7 +42,6 @@ class MyEncoder(json.JSONEncoder):
             return obj
         return json.JSONEncoder.default(self, obj)
 
-
 def predict(doc_id,text,title="",page_time="",**kwargs):
     cost_time = dict()
 
@@ -52,46 +52,51 @@ def predict(doc_id,text,title="",page_time="",**kwargs):
     cost_time["preprocess"] = round(time.time()-start_time,2)
     cost_time.update(_cost_time)
 
-    # 依赖句子顺序
-    start_time = time.time()  # 公告类型/生命周期提取
+    #依赖句子顺序
+    start_time = time.time() # 公告类型/生命周期提取
     list_channel_dic = predictor.getPredictor("channel").predict(title=title, content=list_sentences[0])
     cost_time["channel"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 项目编号、名称提取
+    start_time = time.time() # 项目编号、名称提取
     codeName = predictor.getPredictor("codeName").predict(list_sentences,MAX_AREA=5000,list_entitys=list_entitys)
     log("get codename done of doc_id%s"%(doc_id))
     cost_time["codename"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 角色金额模型提取
+    start_time = time.time() # 角色金额模型提取
     predictor.getPredictor("prem").predict(list_sentences,list_entitys)
     log("get prem done of doc_id%s"%(doc_id))
     cost_time["prem"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 产品名称及废标原因提取
+    start_time = time.time() # 产品名称及废标原因提取
     predictor.getPredictor("product").predict(list_sentences,list_entitys)
     log("get product done of doc_id%s"%(doc_id))
     cost_time["product"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 产品相关要素正则提取 单价、数量、品牌规格 ; 项目、需求、预算、时间
+    start_time = time.time() # 产品相关要素正则提取 单价、数量、品牌规格 ; 项目、需求、预算、时间
     product_attrs = predictor.getPredictor("product_attrs").predict(doc_id, text, page_time)
     log("get product attributes done of doc_id%s"%(doc_id))
     cost_time["product_attrs"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 正则角色提取
+    start_time = time.time() #正则角色提取
     predictor.getPredictor("roleRule").predict(list_articles,list_sentences, list_entitys,codeName)
     cost_time["rule"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 联系人模型提取
+    '''正则补充最后一句实体日期格式为招标或代理 2021/12/30'''
+    start_time = time.time() #正则角色提取
+    predictor.getPredictor("roleRuleFinal").predict(list_articles, list_entitys)
+    cost_time["roleRuleFinal"] = round(time.time()-start_time,2)
+
+    start_time = time.time() #联系人模型提取
     predictor.getPredictor("epc").predict(list_sentences,list_entitys)
     log("get epc done of doc_id%s"%(doc_id))
     cost_time["person"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 时间类别提取
+    start_time = time.time() # 时间类别提取
     predictor.getPredictor("time").predict(list_sentences, list_entitys)
     log("get time done of doc_id%s"%(doc_id))
     cost_time["time"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 保证金支付方式
+    start_time = time.time() # 保证金支付方式
     payment_way_dic = predictor.getPredictor("deposit_payment_way").predict(content=list_articles[0].content)
     cost_time["deposit"] = round(time.time()-start_time,2)
 
@@ -121,7 +126,7 @@ def predict(doc_id,text,title="",page_time="",**kwargs):
     log("get attributes done of doc_id%s"%(doc_id))
     cost_time["attrs"] = round(time.time()-start_time,2)
 
-    start_time = time.time()  # 失信数据要素提取
+    start_time = time.time() #失信数据要素提取
     list_punish_dic = predictor.getPredictor("punish").get_punish_extracts(list_articles,list_sentences, list_entitys)
     cost_time["punish"] = round(time.time()-start_time,2)
 
@@ -129,9 +134,8 @@ def predict(doc_id,text,title="",page_time="",**kwargs):
         for d in product_attrs[1]['demand_info']['data']:
             for product in set(prem[0]['product']):
                 if product in d['project_name']:
-                    d['product'].append(product)  # 把产品在项目名称中的添加进需求要素中
+                    d['product'].append(product)  #把产品在项目名称中的添加进需求要素中
 
-    # print(prem)
     # data_res = Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic)[0]
     # data_res = Preprocessing.union_result(Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic), list_channel_dic)[0]
     data_res = dict(codeName[0], **prem[0], **list_channel_dic[0], **product_attrs[0], **product_attrs[1], **payment_way_dic)
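Editor's note: predict() above repeats the same start_time/cost_time bookkeeping around every stage. A minimal sketch of that pattern, with run_stage as a hypothetical helper (not part of the diff):

    import time

    def run_stage(cost_time, key, fn, *args, **kwargs):
        # time one extraction stage and record it under cost_time[key]
        start = time.time()
        result = fn(*args, **kwargs)
        cost_time[key] = round(time.time() - start, 2)
        return result

    # usage, mirroring one stage of predict():
    # codeName = run_stage(cost_time, "codename",
    #                      predictor.getPredictor("codeName").predict,
    #                      list_sentences, MAX_AREA=5000, list_entitys=list_entitys)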

+ 72 - 16
BiddingKG/dl/interface/getAttributes.py

@@ -2237,7 +2237,8 @@ def my_timeFormat(_time):
     return time_list
 
 def getTimeAttributes(list_entity,list_sentence):
-    # list_entity = [i for i in list_entity if i.entity_type=='time']
+    time_entitys = [i for i in list_entity if i.entity_type=='time']
+    time_entitys = sorted(time_entitys,key=lambda x:(x.sentence_index, x.begin_index))
     list_sentence = sorted(list_sentence,key=lambda x:x.sentence_index)
     dict_time = {
         "time_release": [],
@@ -2256,105 +2257,160 @@ def getTimeAttributes(list_entity,list_sentence):
         'time_commencement':[] , #13 开工日期
         'time_completion': []  # 14 竣工日期
     }
-    for entity in list_entity:
-        if entity.label!=0:
-            entity_text = entity.entity_text
-            extract_time = my_timeFormat(entity_text)
-            if extract_time:
-                sentence_text = list_sentence[entity.sentence_index].sentence_text
-                entity_left = sentence_text[max(0,entity.wordOffset_begin-2):entity.wordOffset_begin]
-                entity_right = sentence_text[entity.wordOffset_end:entity.wordOffset_end+3]
-                label_prob = entity.values[entity.label]
+    last_sentence_index = 0
+    last_time_type = ""
+    last_time_index = {
+        'time_bidstart':"time_bidclose",
+        'time_publicityStart':"time_publicityEnd",
+        'time_getFileStart':"time_getFileEnd",
+        'time_registrationStart':"time_registrationEnd",
+        'time_earnestMoneyStart':"time_earnestMoneyEnd",
+        'time_commencement':"time_completion",
+    }
+    for entity in time_entitys:
+        sentence_text = list_sentence[entity.sentence_index].sentence_text
+        entity_left = sentence_text[max(0, entity.wordOffset_begin - 2):entity.wordOffset_begin]
+        entity_right = sentence_text[entity.wordOffset_end:entity.wordOffset_end + 3]
+        label_prob = entity.values[entity.label]
+        entity_text = entity.entity_text
+        extract_time = my_timeFormat(entity_text)
+        if extract_time:
+            if re.search("至|到", entity_left):
+                if entity.sentence_index == last_sentence_index:
+                    time_type = last_time_index.get(last_time_type)
+                    if time_type:
+                        dict_time[time_type].append((extract_time[0], 0.5 + label_prob / 10))
+                        last_time_type = ""
+                        continue
+            if entity.label!=0:
                 if entity.label==1 and label_prob>0.5:
                     dict_time['time_release'].append((extract_time[0],label_prob))
+                    last_time_type = 'time_release'
                 elif entity.label==2 and label_prob>0.5:
                     dict_time['time_bidopen'].append((extract_time[0],label_prob))
+                    last_time_type = 'time_bidopen'
                 elif entity.label==3 and label_prob>0.5:
                     dict_time['time_bidclose'].append((extract_time[0],label_prob))
+                    last_time_type = 'time_bidclose'
                 elif entity.label==12 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_bidclose'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_bidclose'
                         else:
                             dict_time['time_bidstart'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_bidstart'
                     else:
                         dict_time['time_bidstart'].append((extract_time[0],label_prob))
                         dict_time['time_bidclose'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==4 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_publicityEnd'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_publicityEnd'
                         else:
                             dict_time['time_publicityStart'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_publicityStart'
                     else:
                         dict_time['time_publicityStart'].append((extract_time[0],label_prob))
                         dict_time['time_publicityEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==5 and label_prob>0.5:
                     if len(extract_time)==1:
                         dict_time['time_publicityEnd'].append((extract_time[0], label_prob))
+                        last_time_type = 'time_publicityEnd'
                     else:
                         dict_time['time_publicityStart'].append((extract_time[0],label_prob))
                         dict_time['time_publicityEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==6 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_getFileEnd'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_getFileEnd'
                         else:
                             dict_time['time_getFileStart'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_getFileStart'
                     else:
                         dict_time['time_getFileStart'].append((extract_time[0],label_prob))
                         dict_time['time_getFileEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==7 and label_prob>0.5:
                     if len(extract_time)==1:
                         dict_time['time_getFileEnd'].append((extract_time[0], label_prob))
+                        last_time_type = 'time_getFileEnd'
                     else:
                         dict_time['time_getFileStart'].append((extract_time[0],label_prob))
                         dict_time['time_getFileEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==8 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_registrationEnd'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_registrationEnd'
                         else:
                             dict_time['time_registrationStart'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_registrationStart'
                     else:
                         dict_time['time_registrationStart'].append((extract_time[0],label_prob))
                         dict_time['time_registrationEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==9 and label_prob>0.5:
                     if len(extract_time)==1:
                         dict_time['time_registrationEnd'].append((extract_time[0], label_prob))
+                        last_time_type = 'time_registrationEnd'
                     else:
                         dict_time['time_registrationStart'].append((extract_time[0],label_prob))
                         dict_time['time_registrationEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==10 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_earnestMoneyEnd'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_earnestMoneyEnd'
                         else:
                             dict_time['time_earnestMoneyStart'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_earnestMoneyStart'
                     else:
                         dict_time['time_earnestMoneyStart'].append((extract_time[0],label_prob))
                         dict_time['time_earnestMoneyEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==11 and label_prob>0.5:
                     if len(extract_time)==1:
                         dict_time['time_earnestMoneyEnd'].append((extract_time[0], label_prob))
+                        last_time_type = 'time_earnestMoneyEnd'
                     else:
                         dict_time['time_earnestMoneyStart'].append((extract_time[0],label_prob))
                         dict_time['time_earnestMoneyEnd'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==13 and label_prob>0.5:
                     if len(extract_time)==1:
-                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                        if re.search("前|止|截止",entity_right) or re.search("至|止|到",entity_left) or re.search("前",entity_text[-2:]):
                             dict_time['time_completion'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_completion'
                         else:
                             dict_time['time_commencement'].append((extract_time[0], label_prob))
+                            last_time_type = 'time_commencement'
                     else:
                         dict_time['time_commencement'].append((extract_time[0],label_prob))
                         dict_time['time_completion'].append((extract_time[1],label_prob))
+                        last_time_type = ''
                 elif entity.label==14 and label_prob>0.5:
                     if len(extract_time)==1:
                         dict_time['time_completion'].append((extract_time[0], label_prob))
+                        last_time_type = 'time_completion'
                     else:
                         dict_time['time_commencement'].append((extract_time[0],label_prob))
                         dict_time['time_completion'].append((extract_time[1],label_prob))
+                        last_time_type = ''
+                else:
+                    last_time_type = ""
+            else:
+                last_time_type = ""
+        else:
+            last_time_type = ""
+        last_sentence_index = entity.sentence_index
 
 
     result_dict = dict((key,"") for key in dict_time.keys())
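Editor's note: the rewrite above threads last_time_type/last_sentence_index through the loop so that a bare date preceded by 至/到 in the same sentence inherits the paired end category via last_time_index, at a damped confidence of 0.5 + label_prob/10. A hedged, self-contained sketch of just that heuristic:

    import re

    # start-category -> matching end-category, mirroring last_time_index above
    PAIR_END = {
        'time_bidstart': 'time_bidclose',
        'time_publicityStart': 'time_publicityEnd',
        'time_getFileStart': 'time_getFileEnd',
        'time_registrationStart': 'time_registrationEnd',
        'time_earnestMoneyStart': 'time_earnestMoneyEnd',
        'time_commencement': 'time_completion',
    }

    def pair_end_time(entity_left, last_type, same_sentence, label_prob):
        """Return (end_category, score) if the 至/到 pairing applies, else None."""
        if same_sentence and re.search("至|到", entity_left):
            end_type = PAIR_END.get(last_type)
            if end_type:
                return end_type, 0.5 + label_prob / 10
        return None

    print(pair_end_time("日至", 'time_bidstart', True, 0.9))  # ('time_bidclose', 0.59)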

+ 69 - 30
BiddingKG/dl/interface/predictor.py

@@ -27,15 +27,16 @@ import calendar
 import datetime
 
 from threading import RLock
-dict_predictor = {"codeName": {"predictor": None, "Lock": RLock()},
-                  "prem": {"predictor": None, "Lock": RLock()},
-                  "epc": {"predictor": None, "Lock": RLock()},
-                  "roleRule": {"predictor": None, "Lock": RLock()},
-                  "form": {"predictor": None, "Lock": RLock()},
-                  "time": {"predictor": None, "Lock": RLock()},
-                  "punish": {"predictor": None, "Lock": RLock()},
-                  "product": {"predictor": None, "Lock": RLock()},
-                  "product_attrs": {"predictor": None, "Lock": RLock()},
+dict_predictor = {"codeName":{"predictor":None,"Lock":RLock()},
+              "prem":{"predictor":None,"Lock":RLock()},
+              "epc":{"predictor":None,"Lock":RLock()},
+              "roleRule":{"predictor":None,"Lock":RLock()},
+              "roleRuleFinal":{"predictor":None,"Lock":RLock()},
+                  "form":{"predictor":None,"Lock":RLock()},
+                  "time":{"predictor":None,"Lock":RLock()},
+                  "punish":{"predictor":None,"Lock":RLock()},
+                  "product":{"predictor":None,"Lock":RLock()},
+                "product_attrs":{"predictor":None,"Lock":RLock()},
                   "channel": {"predictor": None, "Lock": RLock()},
                   "deposit_payment_way": {"predictor": None, "Lock": RLock()},
                   "total_unit_money": {"predictor": None, "Lock": RLock()}
@@ -54,6 +55,8 @@ def getPredictor(_type):
                     dict_predictor[_type]["predictor"] = EPCPredict()
                 if _type == "roleRule":
                     dict_predictor[_type]["predictor"] = RoleRulePredictor()
+                if _type == "roleRuleFinal":
+                    dict_predictor[_type]["predictor"] = RoleRuleFinalAdd()
                 if _type == "form":
                     dict_predictor[_type]["predictor"] = FormPredictor()
                 if _type == "time":
@@ -531,7 +534,7 @@ class CodeNamePredict():
         return result
     '''
         
-# 角色金额模型
+#角色金额模型        
 class PREMPredict():
 
     
@@ -664,6 +667,8 @@ class PREMPredict():
                 elif re.search('尊敬的供应商:.{,25}我公司', text):
                     label = 0
                     values[label] = 0.801
+            if label == 1 and re.search('委托(单位|人|方)[是为:]+', text[:10]) and re.search('受委托(单位|人|方)[是为:]+', text[:10])==None:
+                label = 0
             entity.set_Role(label, values)
 
     def predict_money(self,list_sentences,list_entitys):
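Editor's note: a hedged illustration of the new override above. A context opening with 委托单位/人/方 (the delegating party) re-labels an agency prediction (label 1) as tenderee (label 0), while 受委托 (the commissioned party) keeps the agency role. Sample strings are made up.

    import re

    for text in ["委托单位为:某某大学", "受委托单位为:某某咨询公司"]:
        delegator = (re.search('委托(单位|人|方)[是为:]+', text[:10])
                     and re.search('受委托(单位|人|方)[是为:]+', text[:10]) is None)
        print(text, '-> tenderee (label 0)' if delegator else '-> agency kept')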
@@ -710,7 +715,7 @@ class PREMPredict():
         self.predict_money(list_sentences,list_entitys)
         
         
-# 联系人模型
+#联系人模型    
 class EPCPredict():
     
     def __init__(self):
@@ -1050,7 +1055,7 @@ class EPCPredict():
     def predict(self,list_sentences,list_entitys):
         self.predict_person(list_sentences,list_entitys)
             
-# 表格预测
+#表格预测
 class FormPredictor():
     
     def __init__(self,lazyLoad=getLazyLoad()):
@@ -1082,18 +1087,20 @@ class FormPredictor():
         else:
             return self.getModel(type).predict(form_datas)
 
+    
 
-# 角色规则
-# 依据正则给所有无角色的实体赋予角色,给予等于阈值的最低概率
+#角色规则
+#依据正则给所有无角色的实体赋予角色,给予等于阈值的最低概率
 class RoleRulePredictor():
     
     def __init__(self):
-        self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|采购|招标|项目|竞价|议价|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|比选|委托|询价)(?:人|公司|单位|组织|用户|业主|方|部门)|文章来源|需方)(名称)?(是|为|信息|:|:|\s*)$)"
+        # self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|采购|招标|项目|竞价|议价|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|比选|委托|询价)(?:人|公司|单位|组织|用户|业主|方|部门)|文章来源|需方)(名称)?(是|为|信息|:|:|\s*)$)"
+        self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|采购|招标|项目|竞价|议价|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|比选|委托|询价|评选|挂牌|出租|出让|谈判|邀标|邀请|洽谈|约谈|买受|选取|抽取|抽选|出售|标卖|比价)(人|公司|单位|组织|用户|业主|主体|方|部门)|文章来源|委托机构|产权所有人|需方|买方|业主|权属人|甲方当事人)[))]?(名称|信息)?([((](全称|盖章)[))])?(是|为|:|:|,|\s*)+$)"
         self.pattern_tenderee_center = "(?P<tenderee_center>(受.{,20}委托))"
-        self.pattern_tenderee_right = "(?P<tenderee_right>^(\((以下简称)?[\"”]?(招标|采购)(人|单位|机构)\)?))"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
-        
-        self.pattern_agency_left = "(?P<agency_left>(代理(?:人|机构|公司|单位|组织)|专业采购机构|集中采购机构|集采机构|招标机构)(.{,4}名,?称|全称|是|为|:|:|[,,]?\s*)$|(受.{,20}委托))"
-        self.pattern_agency_right = "(?P<agency_right>^(\((以下简称)?[\"”]?(代理)(人|单位|机构)\))|受.{,15}委托)"
+        self.pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)\)?))|^委托"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
+
+        self.pattern_agency_left = "(?P<agency_left>(代理(?:人|机构|公司|单位|组织)|专业采购机构|集中采购机构|集采机构|[议))]+标机构)(.{,4}名,?称|全称|是|为|:|:|[,,]?\s*)$|(受.{,20}委托))"
+        self.pattern_agency_right = "(?P<agency_right>^([((](以下简称)?[,\"“]*(代理)(人|单位|机构)\))|受.{,15}委托)|^受托"
         # 2020//11/24 大网站规则 中标关键词添加 选定单位|指定的中介服务机构
         self.pattern_winTenderer_left = "(?P<winTenderer_left>((中标|中选|中价|乙|成交|承做|施工|供货|承包|竞得|受让)(候选)?(人|单位|机构|各?供应商|方|公司|厂商|商)[::是为]+$|(选定单位|指定的中介服务机构))[::是为,]+$|(第[一1](名|((中标|中选|中价|成交)?(候选)?(人|单位|机构|供应商))))[::是为]+$|((评审结果|名次|排名)[::]第?[一1]名?)$|单一来源(采购)?方式向$|((中标|成交)(结果|信息))(是|为|:|:)$|(单一来源采购(供应商|供货商|服务商))$|[^候选]((分包|标包){,5}供应商|供货商|服务商|供应商名称|服务机构|供方)[::]$)"
         # self.pattern_winTenderer_center = "(?P<winTenderer_center>第[一1].{,20}[是为]((中标|中选|中价|成交|施工)(人|单位|机构|供应商|公司)|供应商)[::是为])"
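Editor's note: the widened pattern_tenderee_left above adds many triggers (出租, 出让, 买受, 比价, …) plus optional (全称/盖章) qualifiers. A cut-down, illustrative probe of the left-context idea, not the production regex:

    import re

    probe = re.compile("((采购|招标|出租|出让|买受|比价)(人|单位|方)"
                       "(名称)?([((](全称|盖章)[))])?(是|为|:|:|,|\\s*)+$)")
    for left in ["采购人名称:", "出租方(盖章):", "中标人:"]:
        print(left, bool(probe.search(left)))  # True True False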
@@ -1376,6 +1383,46 @@ class RoleRulePredictor():
                 if p_entity.entity_text in self.SET_NOT_TENDERER:
                     p_entity.label=5
 
+'''正则补充最后一句实体日期格式为招标或代理 2021/12/30'''
+class RoleRuleFinalAdd():
+    def predict(self, list_articles, list_entitys):
+        text_end = list_articles[0].content[-30:]
+        sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}),\s*.{2,4}年.{1,2}月.{1,2}日', text_end)
+        if sear_ent:
+            ent_re = sear_ent.group(1)
+            tenderee_notfound = True
+            agency_notfound = True
+            ents = []
+            for ent in list_entitys[0]:
+                if ent.entity_type in ['org', 'company']:
+                    if ent.label == 0:
+                        tenderee_notfound = False
+                    elif ent.label == 1:
+                        agency_notfound = False
+                    elif ent.label == 5:
+                        ents.append(ent)
+            if agency_notfound == True and re.search('(采购|招标|投标|交易|代理|拍卖|咨询|顾问|管理)', ent_re):
+                n = 0
+                for i in range(len(ents) - 1, -1, -1):
+                    n += 1
+                    if n > 3:
+                        break
+                    if ents[i].entity_text == ent_re:
+                        ents[i].label = 1
+                        ents[i].values[1] = 0.5
+                        break
+
+            elif tenderee_notfound == True and re.search('(采购|招标|投标|交易|代理|拍卖|咨询|顾问|管理)', ent_re) == None:
+                n = 0
+                for i in range(len(ents) - 1, -1, -1):
+                    n += 1
+                    if n > 3:
+                        break
+                    if ents[i].entity_text == ent_re:
+                        ents[i].label = 0
+                        ents[i].values[0] = 0.5
+                        break
+
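Editor's note: a hedged demonstration of the closing-signature pattern that drives RoleRuleFinalAdd. The last ~30 characters of many announcements read "…,<org name>,<date>", and the captured name is promoted to agency when it contains an agency-like keyword, otherwise to tenderee. The sample tail below is invented.

    import re

    tail = "特此公告。,某某市公共资源交易中心,2021年12月30日"
    sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}),\s*.{2,4}年.{1,2}月.{1,2}日', tail)
    if sear_ent:
        name = sear_ent.group(1)
        role = 'agency' if re.search('(采购|招标|投标|交易|代理|拍卖|咨询|顾问|管理)', name) else 'tenderee'
        print(name, '->', role)  # 某某市公共资源交易中心 -> agency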
 
 # 时间类别
 class TimePredictor():
@@ -1482,7 +1529,6 @@ class TimePredictor():
                         values[0] = 0.5
                 entity.set_Role(label, values)
 
-
 # 产品字段提取
 class ProductPredictor():
     def __init__(self):
@@ -1559,7 +1605,6 @@ class ProductPredictor():
                     result.append(item) # 修正bug
                 return result
 
-
 # 产品数量单价品牌规格提取 #2021/11/10 添加表格中的项目、需求、预算、时间要素提取
 class ProductAttributesPredictor():
     def __init__(self,):
@@ -2031,7 +2076,6 @@ class ProductAttributesPredictor():
             demand_dic = {'demand_info':{'data':[], 'header':[], 'header_col':[]}}
         return [attr_dic, demand_dic]
 
-
 # docchannel类型提取
 class DocChannel():
   def __init__(self, life_model='/channel_savedmodel/channel.pb', type_model='/channel_savedmodel/doctype.pb'):
@@ -2213,7 +2257,6 @@ class DocChannel():
       # return self.id2type[id], prob
       return [{'docchannel':self.id2type[id]}]
 
-
 # 保证金支付方式提取
 class DepositPaymentWay():
     def __init__(self,):
@@ -2293,8 +2336,7 @@ def getSavedModel():
           inputs={"image": model.input},
           outputs={"scores": model.output}
         )
-
-
+        
 def getBiLSTMCRFModel(MAX_LEN,vocab,EMBED_DIM,BiRNN_UNITS,chunk_tags,weights):
     '''
     model = models.Sequential()
@@ -2533,8 +2575,7 @@ def save_person_model():
                                    inputs={"input0":model.input[0],
                                            "input1":model.input[1]},
                                    outputs = {"outputs":model.output})
-
-
+    
 def save_form_model():
     model_form = FormPredictor()
     with model_form.graph.as_default():
@@ -2546,8 +2587,7 @@ def save_form_model():
                                    "./form_savedmodel/",
                                    inputs={"inputs":model.input},
                                    outputs = {"outputs":model.output})
-
-
+    
 def save_codesplit_model():
     filepath_code = "../projectCode/models/model_code.hdf5"
     
@@ -2565,7 +2605,6 @@ def save_codesplit_model():
                                            "input2":model_code.input[2]},
                                    outputs={"outputs":model_code.output})
 
-
 def save_timesplit_model():
     filepath = '../time/model_label_time_classify.model.hdf5'
     with tf.Graph().as_default() as graph:

+ 163 - 4
BiddingKG/dl/product/data_util.py

@@ -4,16 +4,130 @@
 # @Time    : 2021/1/13 0013 14:19
 import re
 import math
+import json
 import random
 import numpy as np
+import pandas as pd
 from BiddingKG.dl.common.Utils import getVocabAndMatrix,getModel_word,viterbi_decode
 
-id_to_tag = {0:'O',1:'B',2:'I',3:'E'}
+tag2index = {'S':0,'B-pro':1, 'I-pro':2, 'E-pro':3, 'B-rea':4, 'I-rea':5, 'E-rea':6}
+id_to_tag = {v:k for k,v in tag2index.items()}
+# id_to_tag = {0:'O',1:'B',2:'I',3:'E'}
 word_model = getModel_word()
+
 vocab, matrix = getVocabAndMatrix(word_model, Embedding_size=60)
 word2id = {k: v for v, k in enumerate(vocab)}
 max_id = len(vocab)
 
+# vocab = ["<pad>"] + word_model.index2word+ ["<unk>"]
+# matrix = np.zeros((len(vocab), 60))
+# for i in range(1, len(vocab)-1):
+#     matrix[i] = word_model[vocab[i]]
+# max_id = len(vocab)
+# word2id = {k: v for v, k in enumerate(vocab)}
+
+def df2data(df):
+    import pandas as pd
+    import json
+    datas = []
+    for idx in df.index:
+        docid = df.loc[idx, 'docid']
+        text = df.loc[idx, 'text']
+        # string = list(text)
+        tags = [0]*len(text)
+        labels = json.loads(df.loc[idx, 'label'])
+        for label in labels:
+            _, _, begin, end, _ = re.split('\s',label)
+            begin = int(begin)
+            end = int(end)
+            if end-begin>=2:
+                tags[begin]=1
+                tags[end-1]=3
+                for i in range(begin+1,end-1):
+                    tags[i]=2
+        # datas.append([string, tags])
+        text_sentence = []
+        ids_sentence = []
+        tag_sentence = []
+        for i in range(len(text)):
+            text_sentence.append(text[i])
+            ids_sentence.append(word2id.get(text[i], max_id))
+            tag_sentence.append(tags[i])
+            if text[i] in ['。','!']:
+                if text_sentence:
+                    # if len(text_sentence) > 100:
+                    if len(text_sentence)>5 and len(text_sentence)<1000:
+                        datas.append([text_sentence, ids_sentence,tag_sentence])
+                    else:
+                        print('单句小于5或大于1000,句子长度为:%d,文章ID:%s'%(len(text_sentence), docid))
+                    text_sentence = []
+                    ids_sentence = []
+                    tag_sentence = []
+        if text_sentence:
+            # if len(text_sentence) > 5:
+            if len(text_sentence) > 5 and len(text_sentence) < 1000:
+                datas.append([text_sentence, ids_sentence, tag_sentence])
+            else:
+                print('单句小于5或大于1000,句子长度为:%d,文章ID:%s' % (len(text_sentence), docid))
+    return datas
+
+def find_kw_from_text(kw, s):
+    '''
+    输入关键词及句子信息,返回句子中关键词的所有出现位置
+    :param kw: 关键词
+    :param s: 文本
+    :return:
+    '''
+    begin = s.find(kw, 0)
+    kws = []
+    while begin!=-1:
+        end = begin + len(kw)
+        # print(s[begin:end])
+        kws.append((begin, end))
+        begin = s.find(kw, end)
+    return kws
+
+def get_feature(text, lbs):
+    '''
+    输入文章预处理后文本内容及产品名称列表,返回句子列表,数字化句子列表,数字化标签列表
+    :param text: 文本内容
+    :param lbs: 产品名称列表
+    :return:
+    '''
+    lbs = sorted(set(lbs), key=lambda x: len(x), reverse=True)
+    sentences = []
+    ids_list = []
+    tags_list = []
+    for sentence in text.split('。'):
+        if len(sentence) < 5:
+            continue
+        if len(sentence) > 1000:
+            sentence = sentence[:1000]
+        tags = [0] * len(sentence)
+        ids = [word2id.get(word, max_id) for word in sentence]
+        for lb in lbs:
+            kw_indexs = find_kw_from_text(lb, sentence)
+            for indexs in kw_indexs:
+                b, e = indexs
+                if tags[b] == 0 and tags[e - 1] == 0:
+                    tags[b] = 1
+                    tags[e - 1] = 3
+                    for i in range(b+1, e - 1):
+                        tags[i] = 2
+        sentences.append(list(sentence))
+        ids_list.append(ids)
+        tags_list.append(tags)
+    return sentences, ids_list, tags_list
+
+def dfsearchlb(df):
+    datas = []
+    for i in df.index:
+        text = df.loc[i, 'text']
+        lbs = json.loads(df.loc[i, 'lbset'])
+        sentences, ids_list, tags_list = get_feature(text, lbs)
+        for sen, ids, tags in zip(sentences, ids_list, tags_list):
+            datas.append([sen, ids, tags])
+    return datas
 
 def get_label_data():
     import psycopg2
@@ -112,12 +226,16 @@ def result_to_json(line, tags):
     result = []
     ner = []
     tags = ''.join([str(it) for it in tags])
-    for it in re.finditer("12*3", tags):
+    # for it in re.finditer("12*3", tags):
+    #     start = it.start()
+    #     end = it.end()
+    #     ner.append([line[start:end], (start, end)])
+    for it in re.finditer("45*6", tags):
         start = it.start()
         end = it.end()
         ner.append([line[start:end], (start, end)])
     result.append([line, ner])
-    print(tags)
+    # print(tags)
     return result
 
 
@@ -153,4 +271,45 @@ class BatchManager(object):
         if shuffle:
             random.shuffle(self.batch_data)
         for idx in range(self.len_data):
-            yield self.batch_data[idx]
+            yield self.batch_data[idx]
+
+def 获取原始标注数据():
+    import psycopg2
+    import json
+    conn = psycopg2.connect(dbname='iepy_product', user='postgres', password='postgres', host='192.168.2.103')
+    cursor = conn.cursor()
+    sql = "select human_identifier, text from corpus_iedocument where edittime NOTNULL AND jump_signal=0 ;"
+    cursor.execute(sql)
+    writer = open('label_data.txt', 'w', encoding='utf-8')
+    datas = []
+    for row in cursor.fetchall():
+        docid = row[0]
+        text = row[1]
+        sql_lb = "select b.value from brat_bratannotation as b where document_id = '{}' and b.value like 'T%product%';".format(docid)
+        cursor.execute(sql_lb)
+        rows = cursor.fetchall()
+        print('len(rows)', len(rows))
+        datas.append((docid, text, json.dumps(rows, ensure_ascii=False), len(rows)))
+    df = pd.DataFrame(datas, columns=['docid', 'text', 'rows', 'product_num'])
+    df.to_excel('data/产品数据自己人标注的原始数据.xlsx')
+
+
+if __name__=="__main__":
+    # import os
+    import pickle
+    # with open('data/dev_data2.pkl', 'rb') as f:
+    #     dev_data = pickle.load(f)
+    # print(len(dev_data))
+    # print(os.path.exists('data/testdata.xlsx'))
+    # df = pd.read_excel('data/testdata.xlsx')
+    # print(len(df))
+    # data_test = df2data(df)
+    # print(len(data_test), len(data_test[0][0]))
+    # 获取原始标注数据()
+    df = pd.read_excel('data/产品数据自己人标注的原始数据.xlsx')
+    with open('data/dev_data2.pkl', 'rb') as f:
+        dev_data = pickle.load(f)
+    print(len(set(df['docid'])))
+    print('')
+
+
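Editor's note: the switch above from the 4-tag {O,B,I,E} scheme to 7 tags splits products (B/I/E-pro) from failure reasons (B/I/E-rea), and result_to_json() now decodes reason spans by matching "45*6" over the stringified tag ids. A small worked example under that scheme (tags below are hand-set for illustration):

    import re

    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3,
                 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}

    line = "采购打印机因故废标"
    tags = [0, 0, 1, 2, 3, 0, 0, 4, 6]       # 打印机 -> pro, 废标 -> rea
    tag_str = ''.join(str(t) for t in tags)  # "001230046"
    for it in re.finditer("45*6", tag_str):  # reason spans, as in result_to_json
        print(line[it.start():it.end()])     # -> 废标
    for it in re.finditer("12*3", tag_str):  # product spans (old-style decode)
        print(line[it.start():it.end()])     # -> 打印机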

+ 138 - 35
BiddingKG/dl/product/main.py

@@ -3,38 +3,54 @@
 # @Author  : bidikeji
 # @Time    : 2021/1/13 0013 14:03 
 from BiddingKG.dl.product.product_model import Product_Model
-from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json
+from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json, df2data,dfsearchlb
+from BiddingKG.dl.product.data_process import data_precess
 import numpy as np
+import pandas as pd
 import tensorflow as tf
 import random
 import pickle
 import os
+import glob
+os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
 
 def train():
     # all_data = get_label_data()
     # random.shuffle(all_data)
     # train_data = all_data[:int(len(all_data)*0.85)]
     # dev_data = all_data[int(len(all_data)*0.85):]
-    # with open('data/train_data2.pkl', 'wb') as f:
-    #     pickle.dump(train_data, f)
-    # with open('data/dev_data2.pkl', 'wb') as f:
-    #     pickle.dump(dev_data, f)
 
-    with open('data/train_data2.pkl', 'rb') as f:
-        train_data = pickle.load(f)
-    with open('data/dev_data2.pkl', 'rb') as f:
-        dev_data = pickle.load(f)
+    # df = pd.read_excel('data/所有产品标注数据筛选20211125.xlsx')
+    # df.reset_index(drop=True, inplace=True)
+    # np.random.seed(8)
+    # shuffle_ids = np.random.permutation(len(df))
+    # split_ids = int(len(df)*0.1)
+    # train_ids = shuffle_ids[split_ids:]
+    # dev_ids = shuffle_ids[:int(split_ids/2)]
+    # df_train = df.iloc[train_ids]
+    # df_dev = df.iloc[dev_ids]
+    # train_data = df2data(df_train)
+    # dev_data = df2data(df_dev)
 
-    train_manager = BatchManager(train_data, batch_size=128)
-    dev_manager = BatchManager(dev_data, batch_size=64)
+    # with open(os.path.dirname(__file__)+'/data/train_data2021-11-30.pkl', 'rb') as f:
+    #     train_data = pickle.load(f)
+    # with open(os.path.dirname(__file__)+'data/dev_data2021-11-30.pkl', 'rb') as f:
+    #     dev_data = pickle.load(f)
 
-    tf_config = tf.ConfigProto()
-    tf_config.gpu_options.allow_growth = True
+    train_data, dev_data = data_precess()
+
+    train_manager = BatchManager(train_data, batch_size=256)
+    dev_manager = BatchManager(dev_data, batch_size=256)
+
+    # tf_config = tf.ConfigProto()
+    # tf_config.gpu_options.allow_growth = True
+    tf_config = tf.ConfigProto(device_count={'gpu': 1})
     steps_per_epoch = train_manager.len_data
-    ckpt_path = "model"
+    ckpt_path = os.path.dirname(__file__)+'/'+"model"
     with tf.Session(config=tf_config) as sess:
         model = Product_Model()
         sess.run(tf.global_variables_initializer())
+        model.saver.restore(sess, os.path.join(ckpt_path, "ner2.ckpt"))
         # ckpt = tf.train.get_checkpoint_state(ckpt_path)
         # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
         #     model.saver.restore(sess, ckpt.model_checkpoint_path)
@@ -44,7 +60,7 @@ def train():
         loss = []
         mix_loss = 1000
         max_f1 = 0
-        for i in range(100):
+        for i in range(20):
             print('epochs:',i)
             # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
             # break
@@ -53,20 +69,21 @@ def train():
                 # step, batch_loss = model.run_step(sess, True, batch)
                 step, batch_loss = model.run_step(sess, 'train', batch)
                 loss.append(batch_loss)
-                if step % 10 == 0:
+                if step % 1000 == 0:
                     iteration = step // steps_per_epoch + 1
                     print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss)))
-            if i >= 50 or i%5==0:
+            if i >= 2 or i%5==0:
                 f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
                 print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
-                if max_f1 < f1:
-                    model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
-                    print("model save .bast f1 is %.4f" % f1)
+                # if max_f1 < f1:
+                #     model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
+                #     print("model save .bast f1 is %.4f" % f1)
+                #     max_f1 = f1
+                if evl_loss<mix_loss and max_f1 < f1:
+                    mix_loss = evl_loss
                     max_f1 = f1
-                    # if np.mean(loss)<mix_loss:
-                    #     mix_loss = np.mean(loss)
-                    #     model.saver.save(sess, os.path.join(ckpt_path, "ner.ckpt"))
-                    #     print("model saved, loss is:",mix_loss)
+                    model.saver.save(sess, os.path.join(ckpt_path, "ner1202_find_lb.ckpt")) #ner1130_find_lb.ckpt
+                    print("model saved, val_loss is:",mix_loss)
                 loss = []
 
 def evaluate_line():
@@ -74,15 +91,22 @@ def evaluate_line():
     with tf.Session() as sess:
         model = Product_Model()
         sess.run(tf.global_variables_initializer())
-        ckpt = tf.train.get_checkpoint_state(ckpt_path)
-        if ckpt and tf.train.checkpoint_exists(ckpt_path):
-            print('模型文件:',ckpt.model_checkpoint_path)
-            model.saver.restore(sess, ckpt.model_checkpoint_path)
-            print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
-            while True:
-                line = input("请输入测试句子:")
-                result = model.evaluate_line(sess, line)
-                print(result)
+        # model.saver.restore(sess, 'model/ner1215.ckpt')
+        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')
+        model.saver.restore(sess, 'model/ner_epoch10_f10.6875_loss1.5230.ckpt')
+        while True:
+            line = input("请输入测试句子:")
+            result = model.evaluate_line(sess, line)
+            print(result)
+        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
+        # if ckpt and tf.train.checkpoint_exists(ckpt_path):
+        #     print('模型文件:',ckpt.model_checkpoint_path)
+        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
+        #     print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
+        #     while True:
+        #         line = input("请输入测试句子:")
+        #         result = model.evaluate_line(sess, line)
+        #         print(result)
 def predict():
     pb_path = "model/product.pb"
     with tf.Graph().as_default():
@@ -111,7 +135,86 @@ def predict():
                     result = result_to_json(line, tags)
                     print(result)
 
+def predict_df():
+    ckpt_path = "model"
+    import json
+    with tf.Session() as sess:
+        model = Product_Model()
+        sess.run(tf.global_variables_initializer())
+        ckpt = tf.train.get_checkpoint_state(ckpt_path)
+        # model.saver.restore(sess, 'model/ner2.ckpt')
+        # model.saver.restore(sess, 'model/ner1201_find_lb.ckpt')  # f1:0.6972, precision:0.7403, recall:0.6588, evl_loss:1.2983  model saved, val_loss is: 1.32706
+        # model.saver.restore(sess, 'model/ner1208_find_lb.ckpt')  # f1:0.7038, precision:0.7634, recall:0.6528, evl_loss:1.3046 model saved, val_loss is: 1.29316
+        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')  # f1:0.70 ner1215
+        model.saver.restore(sess, 'model/ner_epoch4_f10.6952_loss1.2512.ckpt')  # f1:0.70 ner1215
+
+        print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
+        # df = pd.read_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
+        # df = pd.read_excel('../test/data/所有产品标注数据_补充筛选废标原因数据.xlsx')
+        # df = pd.read_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
+        df = pd.read_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+        # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
+        df.reset_index(drop=True, inplace=True)
+        rs = []
+        for i in df.index:
+            line = df.loc[i, 'text']
+            pos = df.loc[i, 'pos']
+            reason = df.loc[i, 'reasons_label']
+            if pos==0 or reason!='[]':
+                rs.append('')
+                continue
+            # if i > 200:
+            #     rs.append('')
+            #     continue
+            # line = df.loc[i, 'process_text']
+            result = model.evaluate_line(sess, line)
+            print(result[0][1])
+            rs.append(json.dumps(result[0][1], ensure_ascii=False))
+        # df['pred_new1202'] = pd.Series(rs)
+        df['reson_model'] = pd.Series(rs)
+        # df.to_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
+        # df.to_excel('../test/data/所有产品标注数据_补充筛选废标原因数据_predict.xlsx')
+        # df.to_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
+        df.to_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+        # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
+
 if __name__ == "__main__":
     # train()
-    # evaluate_line()
-    predict()
+    evaluate_line()
+    # predict()
+    # predict_df()
+    # import json
+    # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
+    # old_new = []
+    # new_old = []
+    # df['old-new'] = df.apply(lambda x:set([str(it) for it in json.loads(x['pred_old'])])-set([str(it) for it in json.loads(x['pred_new'])]), axis=1)
+    # df['new-old'] = df.apply(lambda x:set([str(it) for it in json.loads(x['pred_new'])])-set([str(it) for it in json.loads(x['pred_old'])]), axis=1)
+    # df['old=new'] = df.apply(lambda x: 1 if x['old-new']==x['new-old'] else 0, axis=1)
+    # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
+
+
+    # with open('data/dev_data2.pkl', 'rb') as f:
+    #     dev_data = pickle.load(f)
+    # import json
+    # df_dev = pd.read_excel('data/产品数据自己人标注的原始数据.xlsx')[:]
+    # def rows2lb(rows):
+    #     rows = json.loads(rows)
+    #     rows = list(set([it[0].split()[-1] for it in rows]))
+    #     return json.dumps(rows, ensure_ascii=False)
+    # df_dev['lbset'] = df_dev['rows'].apply(lambda x:rows2lb(x))
+    # dev_data = dfsearchlb(df_dev)
+    # dev_manager = BatchManager(dev_data, batch_size=64)
+    # # ckpt_path = "model/ner0305.ckpt" #f1:0.7304, precision:0.8092, recall:0.6656, evl_loss:2.2160
+    # # ckpt_path = "model/ner0316.ckpt" #f1:0.7220, precision:0.7854, recall:0.6681, evl_loss:2.2921
+    # # ckpt_path = "model/ner2.ckpt" # f1:0.8019, precision:0.8541, recall:0.7557, evl_loss:1.6286
+    # # ckpt_path = "model/ner1029.ckpt" #f1:0.6374, precision:0.6897, recall:0.5924, evl_loss:2.0840
+    # # ckpt_path = "model/ner1129.ckpt" #f1:0.6034, precision:0.6931, recall:0.5343, evl_loss:1.9704
+    # ckpt_path = "model/ner1129.ckpt" #f1:0.6034, precision:0.6931, recall:0.5343, evl_loss:1.9704
+    # with tf.Session() as sess:
+    #     model = Product_Model()
+    #     sess.run(tf.global_variables_initializer())
+    #     model.saver.restore(sess, ckpt_path)
+    #     print("从文件加载原来模型数据",ckpt_path)
+    #     f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
+    #     print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
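Editor's note: the training-loop change above saves a checkpoint only when validation loss improves and dev F1 improves, instead of on F1 alone. A minimal sketch of that selection rule (names are illustrative; save_fn stands in for model.saver.save):

    def make_checkpointer():
        best = {'loss': float('inf'), 'f1': 0.0}
        def maybe_save(evl_loss, f1, save_fn):
            # keep a checkpoint only when BOTH metrics improve, as in train()
            if evl_loss < best['loss'] and f1 > best['f1']:
                best['loss'], best['f1'] = evl_loss, f1
                save_fn()
                return True
            return False
        return maybe_save

    maybe_save = make_checkpointer()
    print(maybe_save(1.52, 0.6875, lambda: None))  # True
    print(maybe_save(1.60, 0.7000, lambda: None))  # False: loss got worse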
+

+ 17 - 13
BiddingKG/dl/product/product_model.py

@@ -6,16 +6,17 @@
 from BiddingKG.dl.product.data_util import matrix,vocab,input_from_line,result_to_json,get_ner
 import tensorflow as tf
 import numpy as np
-# from tensorflow.contrib.crf import crf_log_likelihood
-# from tensorflow.contrib.crf import viterbi_decode
-# from tensorflow.contrib.layers.python.layers import initializers
+from tensorflow.contrib.crf import crf_log_likelihood
+from tensorflow.contrib.crf import viterbi_decode
+from tensorflow.contrib.layers.python.layers import initializers
 
 # word_model = getModel_word()
 class Product_Model(object):
     def __init__(self):
         self.char_dim = 60
-        self.lstm_dim = 128
-        self.num_tags = 4
+        self.lstm_dim = 120#128 120
+        # self.num_tags = 4
+        self.num_tags = 7
         self.lr = 0.001
         self.clip = 5.0
         self.dropout_rate = 0.5
@@ -34,6 +35,7 @@ class Product_Model(object):
         self.char_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None],name='CharInputs')
         self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None],name='Targets')
         self.dropout = tf.placeholder(dtype=tf.float32, name='Dropout')
+        # self.lengths = tf.placeholder(dtype=tf.int32, shape=[None],name='lengths')
 
         used = tf.sign(tf.abs(self.char_inputs))
         length = tf.reduce_sum(used, reduction_indices=1)
@@ -207,14 +209,16 @@ class Product_Model(object):
                 # Recall.append(recall_temp)
                 # F1.append(f1_temp)
 
-                # for char, gold, pred in zip(string, gold, pred):
-                #     result.append(" ".join([char, gold, pred]))
-                # results.append(result)
-                # with open('evaluate_result.txt','w', encoding='utf-8') as f:
-                #     for rs in results:
-                #         for line in rs:
-                #             f.write(line+'\n')
-                #         f.write('\n')
+                if gold_ner!=pred_ner:
+                    for char, gold, pred in zip(string, gold, pred):
+                            result.append(" ".join([char, gold, pred]))
+                    # print(result)
+                    results.append(result)
+                with open('evaluate_result.txt','w', encoding='utf-8') as f:
+                    for rs in results:
+                        for line in rs:
+                            f.write(line+'\n')
+                        f.write('\n')
 
         # return sum(F1)/len(F1),sum(Precision)/len(Precision),sum(Recall)/len(Recall)
         precision = equal_num/(pred_num+1e-10)
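Editor's note: the re-enabled block above dumps every mismatched sentence as one "char gold pred" line per character for manual error analysis. Note it reopens evaluate_result.txt in 'w' mode inside the evaluation loop, so the file is rewritten per batch; a hedged sketch that writes once after evaluation instead:

    def dump_errors(results, path='evaluate_result.txt'):
        # results: list of sentences, each a list of "char gold pred" strings
        with open(path, 'w', encoding='utf-8') as f:
            for rs in results:
                for line in rs:
                    f.write(line + '\n')
                f.write('\n')  # blank line between sentences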