Browse Source

Merge branch 'master' of http://192.168.2.103:3000/luojiehua/BIDI_ML_INFO_EXTRACTION

znj 2 years ago
parent
commit
1ae62d4734

+ 20 - 17
BiddingKG/dl/interface/Preprocessing.py

@@ -8,7 +8,7 @@ import time
 import codecs
 
 from BiddingKG.dl.ratio.re_ratio import extract_ratio
-# from BiddingKG.dl.table_head.predict import predict
+from BiddingKG.dl.table_head.predict import predict
 
 sys.setrecursionlimit(1000000)
 sys.path.append(os.path.abspath("../.."))
@@ -422,19 +422,16 @@ def tableToText(soup):
     def set_head_model(inner_table):
         for i in range(len(inner_table)):
             for j in range(len(inner_table[i])):
-                inner_table[i][j] = inner_table[i][j][0]
+                # 删掉单格前后符号,以免影响表头预测
+                col = inner_table[i][j][0]
+                col = re.sub("^[^\u4e00-\u9fa5a-zA-Z0-9]+", "", col)
+                col = re.sub("[^\u4e00-\u9fa5a-zA-Z0-9]+$", "", col)
+                inner_table[i][j] = col
 
         # 模型预测表头
         predict_list = predict(inner_table)
-        with open(r"C:\Users\Administrator\Desktop\table_head_test.txt", "a") as f:
-            for i in range(len(predict_list)):
-                f.write(str(i) + " " + str(inner_table[i]) + "\n")
-                f.write(str(i) + " " + str(predict_list[i]) + "\n")
-            f.write("\n")
-
-        # print("table_list", inner_table)
-        # print("predict_list", predict_list)
 
+        # 组合结果
         for i in range(len(inner_table)):
             for j in range(len(inner_table[i])):
                 inner_table[i][j] = [inner_table[i][j], int(predict_list[i][j])]
@@ -1011,10 +1008,10 @@ def tableToText(soup):
         if len(inner_table)>0 and len(inner_table[0])>0:
             #inner_table,head_list = setHead_withRule(inner_table,pat_head,pat_value,3)
             #inner_table,head_list = setHead_inline(inner_table)
-            inner_table, head_list = setHead_initem(inner_table,pat_head)
-            # inner_table, head_list = set_head_model(inner_table)
+            # inner_table, head_list = setHead_initem(inner_table,pat_head)
+            inner_table, head_list = set_head_model(inner_table)
             # inner_table,head_list = setHead_incontext(inner_table,pat_head)
-            # print(inner_table)
+            # print("table_head", inner_table)
             # for begin in range(len(head_list[:-1])):
             #     for item in inner_table[head_list[begin]:head_list[begin+1]]:
             #         print(item)
@@ -1258,7 +1255,6 @@ def segment(soup,final=True):
         # if child.name in spaceList:
         #     child.insert_after(" ")
     text = str(soup.get_text())
-    # print('text',text)
     #替换英文冒号为中文冒号
     text = re.sub("(?<=[\u4e00-\u9fa5]):|:(?=[\u4e00-\u9fa5])",":",text)
     #替换为中文逗号
@@ -1833,11 +1829,10 @@ def article_limit(soup,limit_words=30000):
     for child in soup.find_all(recursive=True):
         if child.name == 'div' and 'class' in child.attrs:
             if "richTextFetch" in child['class']:
-                child.insert_before("##attachment##")
+                child.insert_before("##attachment##") # 句号分开,避免项目名称等提取
                 attachment_part = child
                 have_attachment = True
                 break
-
     if not have_attachment:
         # 无附件
         if len(re.sub(sub_space, "", soup.get_text())) > limit_words:
@@ -1909,7 +1904,6 @@ def attachment_filelink(soup):
         # print('格式化输出',soup.prettify())
         return soup
 
-
 def get_preprocessed_article(articles,cost_time = dict(),useselffool=True):
     '''
     :param articles: 待处理的article source html
@@ -1962,10 +1956,14 @@ def get_preprocessed_article(articles,cost_time = dict(),useselffool=True):
         # print('article_processed')
         article_processed = tableToText(article_processed)
         article_processed = segment(article_processed)
+
+        article_processed = article_processed.replace('(', '(').replace(')', ')')  #2022/8/10 统一为中文括号
         article_processed = article_processed.replace('.','.') # 2021/12/01 修正OCR识别PDF小数点错误问题
         article_processed = article_processed.replace('报价限价', '招标限价') #2021/12/17 由于报价限价预测为中投标金额所以修改
         article_processed = article_processed.replace('成交工程价款', '成交工程价')  # 2021/12/21 修正为中标价
         article_processed = re.sub('任务(?=编号[::])', '项目',article_processed)  # 2022/08/10 修正为项目编号
+        article_processed = article_processed.replace('招标(建设)单位', '招标单位')  #2022/8/10 修正预测不到表达
+        article_processed = re.sub('(招标|采购)人(概况|信息)[,。]', '采购人信息:', article_processed)  # 2022/8/10统一表达
         # 修复OCR金额中“,”、“。”识别错误
         article_processed_list = article_processed.split("##attachment##")
         if len(article_processed_list)>1:
@@ -2017,6 +2015,7 @@ def get_preprocessed_article(articles,cost_time = dict(),useselffool=True):
         if key_preprocess not in cost_time:
             cost_time[key_preprocess] = 0
         cost_time[key_preprocess] += round(time.time()-start_time,2)
+
         #article_processed = article[1]
         _article = Article(doc_id,article_processed,sourceContent,_send_doc_id,_title,
                            bidway=bidway)
@@ -2043,6 +2042,10 @@ def get_preprocessed_sentences(list_articles,useselffool=True,cost_time=dict()):
         key_preprocess = "tableToText"
         start_time = time.time()
         article_processed = article.content
+
+        if len(_title)<100 and _title not in article_processed: # 把标题放到正文
+            article_processed = _title + article_processed
+
         attachment_begin_index = -1
 
         if key_preprocess not in cost_time:

+ 5 - 2
BiddingKG/dl/interface/getAttributes.py

@@ -360,11 +360,14 @@ def get_dict_entity_prob(list_entity,on_value=0.5):
                 if role_prob>=on_value and str(entity.label)!="5":
                     _key_prob = _key+"$text$"+entity.entity_text
                     if in_attachment == True:
+                        role_prob = 0.8 if role_prob>0.8 else role_prob   #附件的概率修改低点
                         if entity.entity_text in identified_role:
                             continue
                     if _key_prob in dict_pack_entity_prob:
-                        if role_prob>dict_pack_entity_prob[_key_prob][1]:
-                            dict_pack_entity_prob[_key_prob] = [entity.entity_text,role_prob]
+                        new_prob = role_prob+dict_pack_entity_prob[_key_prob][1]
+                        dict_pack_entity_prob[_key_prob] = [entity.entity_text, new_prob] #公司同角色多次出现概率累计
+                        # if role_prob>dict_pack_entity_prob[_key_prob][1]:
+                        #     dict_pack_entity_prob[_key_prob] = [entity.entity_text,role_prob]
                     else:
                         dict_pack_entity_prob[_key_prob] = [entity.entity_text,role_prob]
     return dict_pack_entity_prob

+ 37 - 18
BiddingKG/dl/interface/predictor.py

@@ -1111,27 +1111,31 @@ class RoleRulePredictor():
     
     def __init__(self):
         # (?P<tenderee_left_w1> 正则组名 后面的 w1 为概率权重关键词
-        self.pattern_tenderee_left = "(?P<tenderee_left>((项目|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|挂牌|出租|出让|买受|选取|抽取|抽选|出售|标卖|比价|处置)" \
-                                "(人|公司|单位|组织|用户|业主|主体|方|部门)|文章来源|委托机构|产权所有人|需求?方|买方|业主|(业主|采购人|招标人)联系方式[,:]公司名称:|权属人|甲方当事人|询价书企业|比选发起人|项目单位[,:]单位名称|结算单位)"\
-                                "[))]?(信息[,:])?((公司|单位)?名称)?([((](全称|盖章)[))])?(是|为|:|:|\s*)+$)"
-        self.pattern_tenderee_left_w1 = "(?P<tenderee_left_w1>((遴选|采购|招标|竞价|议价|比选|委托|询比?价|评选|谈判|邀标|邀请|洽谈|约谈)" \
+        self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|寻源|采购|招标|竞价|议价|比选|委托|询比?价|比价|评选|谈判|邀标|邀请|洽谈|约谈|选取|抽取|抽选|项目|需求|最终|建设|业主|甲|转让|招租|议标|合同主体|挂牌|出租|出让|买受|出售|标卖|处置)" \
+                                    "(人|方|单位|组织|用户|业主|主体|部门|公司)|文章来源|委托机构|产权所有人|需求?方|买方|业主|权属人|甲方当事人|询价书企业|比选发起人|结算单位)"\
+                                    "[))]?(信息|联系方式|概况)?[,,::]?([((](1|2|1.1|1.2)[))])?((公司|单位)?名称)?([((](全称|盖章)[))])?(是|为|:|:|\s*)+$)"
+        self.pattern_tenderee_left_w0 = "(?P<tenderee_left>(,|。|^)(项目)?((遴选|寻源|采购|招标|竞价|议价|比选|委托|询比?价|比价|评选|谈判|邀标|邀请|洽谈|约谈|选取|抽取|抽选|项目|需求|最终|建设|业主|甲|转让|招租|议标|合同主体|挂牌|出租|出让|买受|出售|标卖|处置)" \
+                                        "(人|方|单位|组织|用户|业主|主体|部门|公司)|文章来源|委托机构|产权所有人|需求?方|买方|业主|权属人|甲方当事人|询价书企业|比选发起人|结算单位)"\
+                                        "[))]?(信息|联系方式|概况)?[,,。::]?([((]?(1|2|1.1|1.2)[))]?)?((公司|单位)?名称)?([((](全称|盖章)[))])?(是|为|:|:|,|\s*)+$)"
+        self.pattern_tenderee_left_w1 = "(?P<tenderee_left_w1>(,|。|^)(项目)?((遴选|寻源|采购|招标|竞价|议价|比选|委托|询比?价|比价|评选|谈判|邀标|邀请|洽谈|约谈|选取|抽取|抽选)" \
                                      "(人|公司|单位|组织|用户|业主|主体|方|部门))" \
-                                     "(信息[,:])?((公司|单位)?名称)?([((](全称|盖章)[))])?(是|为|:|:|\s*)+$)"
-        self.pattern_tenderee_center = "(?P<tenderee_center>(受.{5,20}委托))"
-        self.pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)[,\"”]*[))])|^委托|^现委托|^的\w{2,10}正在进行)"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
+                                     "(是|为|:|:|\s*)+$)"
+        self.pattern_tenderee_center = "(?P<tenderee_center>(受.{5,20}委托|现将[\w()()]{5,20}[\d年月季度至]+采购意向))"
+        self.pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)[,\"”]*[))]|^委托|^将于[\d年月日,::]+进行|^现委托|^的\w{2,10}正在进行|[\d年月季度至]+采购意向))"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
         self.pattern_tendereeORagency_right = "(?P<tendereeORagency_right>(^拟对|^现?就|^现对))"
         self.pattern_agency_left = "(?P<agency_left>(代理(?:人|机构|公司|单位|组织)|专业采购机构|集中采购机构|招标组织机构|集采机构|[招议))]+标机构)(名称)?(.{,4}名,?称|全称|是|为|:|:|[,,]?\s*)$|(受.{5,20}委托,?$))"
         self.pattern_agency_right = "(?P<agency_right>^([((](以下简称)?[,\"“]*(代理)(人|单位|机构)[,\"”]*[))])|^受.{5,20}委托|^受委?托,)"  # |^受托  会与 受托生产等冲突,代理表达一般会在后面有逗号
         # 2020//11/24 大网站规则 中标关键词添加 选定单位|指定的中介服务机构
-        self.pattern_winTenderer_left = "(?P<winTenderer_left>(乙|承做|施工|供货|承包|承建|竞得|受让|签约)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)[::是为]+$|" \
+        self.pattern_winTenderer_left = "(?P<winTenderer_left>(乙|承做|施工|供货|承包|承建|承租|竞得|受让|签约)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)[::是为]+$|" \
                                         "(选定单位|指定的中介服务机构|实施主体|承制单位|供方)[::是为]+$|((评审结果|名次|排名|中标结果)[::]*第?[一1]名?)[::是为]+$|" \
                                         "单一来源(采购)?(供应商|供货商|服务商|方式向)$|((中标|成交)(结果|信息))[::是为]+$|(供应|供货|承销|服务|实施)(机构|单位|商|方)(名称)?[::是为]+$)"
-        self.pattern_winTenderer_left_w1 = "(?P<winTenderer_left_w1>(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)(名称)?([((]按综合排名排序[))])?[::是为]+$)" #取消逗号 并拒绝执行改进计划的供应商,华新水泥将可能终止与其合作关系
+        self.pattern_winTenderer_left_w0 = "(?P<winTenderer_left_w1>(,|。|^)((中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|客户|方|公司|厂商|商)|第?[一1]名)(名称)?[,,]?([((]按综合排名排序[))])?[::,,]$)" #解决表头识别不到加逗号情况,需前面为,。空
+        self.pattern_winTenderer_left_w1 = "(?P<winTenderer_left_w1>(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|客户|方|公司|厂商|商)(名称)?([((]按综合排名排序[))])?[::是为]+$)" #取消逗号 并拒绝执行改进计划的供应商,华新水泥将可能终止与其合作关系
         # self.pattern_winTenderer_center = "(?P<winTenderer_center>第[一1].{,20}[是为]((中标|中选|中价|成交|施工)(人|单位|机构|供应商|公司)|供应商)[::是为])"
         # self.pattern_winTenderer_right = "(?P<winTenderer_right>(^[是为\(]((采购(供应商|供货商|服务商)|(第[一1]|预)?(拟?(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司|厂商)))))|^(报价|价格)最低,确定为本项目成交供应商)"
         self.pattern_winTenderer_right = "(?P<winTenderer_right>(^[是为]((采购|中标)(供应商|供货商|服务商)|(第[一1]|预)?(拟?(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司|厂商)))|" \
-                                        "^(报价|价格)最低,确定为本项目成交供应商|^:贵公司参与|^:?你方于|^中标。|^成为[\w、()()]+项目的成交供应商))"
-        self.pattern_winTenderer_whole = "(?P<winTenderer_center>贵公司.{,15}以.{,15}中标|最终由.{,15}竞买成功|经.{,15}决定[以由].{,15}公司中标|谈判结果:由.{5,20}供货)|中标通知书.{,15}你方"   # 2020//11/24 大网站规则 中标关键词添加 谈判结果:由.{5,20}供货
+                                        "^(报价|价格)最低,确定为本项目成交供应商|^:贵公司参与|^:?你方于|^中标。|^成为[\w、()()]+项目的成交供应商|^[((]中标人名称[))]))"
+        self.pattern_winTenderer_whole = "(?P<winTenderer_center>贵公司.{,15}以.{,15}中标|最终由.{,15}竞买成功|经.{,15}决定[以由].{,15}公司中标|(谈判结果:|确定)由.{5,20}(向我单位)?供货)|中标通知书.{,15}你方"   # 2020//11/24 大网站规则 中标关键词添加 谈判结果:由.{5,20}供货
 
         # self.pattern_winTenderer_location = "(中标|中选|中价|乙|成交|承做|施工|供货|承包|竞得|受让)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)|(供应商|供货商|服务商)[::]?$|(第[一1](名|((中标|中选|中价|成交)?(候选)?(人|单位|机构|供应商))))(是|为|:|:|\s*$)|((评审结果|名次|排名)[::]第?[一1]名?)|(单一来源(采购)?方式向.?$)"
 
@@ -1141,15 +1145,17 @@ class RoleRulePredictor():
         self.pattern_thirdTenderer_left = "(?P<thirdTenderer_left>(第[三3](名|((中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司))))(名称)?[::是为]+$|((评审结果|名次|排名)[::]第?[三3]名?,?投标商名称[::]+$))"
         self.pattern_thirdTenderer_right = "(?P<thirdTenderer_right>^[是为\(]第[三3](名|(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司)))"
 
-        self.pattern_whole = [self.pattern_tenderee_left,
-                              self.pattern_tenderee_left_w1,
+        self.pattern_whole = [self.pattern_tenderee_left_w1,
+                              self.pattern_tenderee_left,
+                              self.pattern_tenderee_left_w0,
                               self.pattern_tenderee_center,
                               self.pattern_tenderee_right,
                               self.pattern_tendereeORagency_right,
                               self.pattern_agency_left,
                               self.pattern_agency_right,
-                              self.pattern_winTenderer_left,
                               self.pattern_winTenderer_left_w1,
+                              self.pattern_winTenderer_left,
+                              self.pattern_winTenderer_left_w0,
                               self.pattern_winTenderer_whole,
                               self.pattern_winTenderer_right,
                               self.pattern_secondTenderer_left,
@@ -1161,7 +1167,7 @@ class RoleRulePredictor():
         self.SET_NOT_TENDERER = set(["人民政府","人民法院","中华人民共和国","人民检察院","评标委员会","中国政府","中国海关","中华人民共和国政府"])
         
         self.pattern_money_tenderee = re.compile("投标最高限价|采购计划金额|项目预算|招标金额|采购金额|项目金额|建安费用|采购(单位|人)委托价|限价|拦标价|预算金额")
-        self.pattern_money_tenderer = re.compile("((合同|成交|中标|应付款|交易|投标|验收)[)\)]?(总?金额|结果|[单报]?价))|总价|标的基本情况")
+        self.pattern_money_tenderer = re.compile("((合同|成交|中标|应付款|交易|投标|验收|订单)[)\)]?(总?金额|结果|[单报]?价))|总价|标的基本情况")
         self.pattern_money_tenderer_whole = re.compile("(以金额.*中标)|中标供应商.*单价|以.*元中标")
         self.pattern_money_other = re.compile("代理费|服务费")
         self.pattern_pack = "(([^承](包|标[段号的包]|分?包|包组)编?号?|项目)[::]?[\((]?[0-9A-Za-z一二三四五六七八九十]{1,4})[^至]?|(第?[0-9A-Za-z一二三四五六七八九十]{1,4}(包号|标[段号的包]|分?包))|[0-9]个(包|标[段号的包]|分?包|包组)"
@@ -1264,9 +1270,20 @@ class RoleRulePredictor():
                                                    word_flag=True, use_text=False)
                                 # _flag = False
 
+                                # 添加中标通知书类型特殊处理
+                                try:
+                                    if s_index == 0 and re.search('中标通知书.{,30}[,:]%s:'%p_entity.entity_text.replace('(', '').replace(')', ''),
+                                                                  list_sentence[s_index].sentence_text.replace('(', '').replace(')', '')[:100]):
+                                        p_entity.label = 2
+                                        p_entity.values[2] = 0.5
+                                        # log('正则召回实体: %s, %s, %s, %d, %.4f, %s'%(_group,  _v_group, p_entity.entity_text, p_entity.label, p_entity.values[p_entity.label], list_spans[_i_span]))
+                                        break
+                                except Exception as e:
+                                    print('正则报错:', e)
+
                                 # 使用正则+距离解决冲突
                                 # 2021/6/11update center: spans[1] --> spans[0][-30:]+spans[1]
-                                list_spans = [spans[0][-30:], spans[0][-10:] + spans[1] + spans[2][:10], spans[2]] # 实体左、中、右 信息
+                                list_spans = [spans[0][-30:], spans[0][-10:] + spans[1] + spans[2][:25], spans[2]] # 实体左、中、右 信息
                                 for _i_span in range(len(list_spans)):
                                     _flag = False
                                     _prob_weight = 1
@@ -1408,15 +1425,17 @@ class RoleRuleFinalAdd():
         sear_ent = re.search('[,。;]([\u4e00-\u9fa5()()]{5,20}(,?[\u4e00-\u9fa5]{,6}(分公司|部))?),?\s*[0-9零一二三四五六七八九十〇]{2,4}[年\-/][0-9零一二三四五六七八九十]{1,2}[月\-/][0-9零一二三四五六七八九十]{1,2}日?', text_end)
         sear_ent2 = re.search('(户名|开户名称)[::]([\u4e00-\u9fa5()()]{5,20})[,。]', list_articles[0].content[:5000])
         sear_ent3 = re.search('(报名咨询|[收送交]货地点)[,:]([\u4e00-\u9fa5()()]{5,20})[0-9\-]*[,。]', list_articles[0].content[:5000])
-        sear_ent4 = re.search('(发布(?:人|单位|机构))[::]([\u4e00-\u9fa5()()]{5,20})[,。]', list_articles[0].content[:5000])
+        sear_ent4 = re.search('(发布(?:人|单位|机构|企业)|项目业主)[::]([\u4e00-\u9fa5()()]{5,20})[,。]', list_articles[0].content[:5000])
         sear_list = [sear_ent4 , sear_ent3 , sear_ent2 , sear_ent]
 
         tenderee_notfound = True
         agency_notfound = True
+        tenderee_list = []
         ents = []
         for ent in list_entitys[0]:
             if ent.entity_type in ['org', 'company']:
                 if ent.label == 0:
+                    tenderee_list.append(ent.entity_text)
                     tenderee_notfound = False
                 elif ent.label == 1:
                     agency_notfound = False
@@ -1457,7 +1476,7 @@ class RoleRuleFinalAdd():
                             tenderee_notfound = False
                             # log('正则最后补充实体: %s'%(ent_re))
                             break
-                elif agency_notfound == True and re.search('(采购|招标|投标|交易|代理|拍卖|咨询|顾问|管理)', ent_re):
+                elif agency_notfound == True and re.search('(采购|招标|投标|交易|代理|拍卖|咨询|顾问|管理)', ent_re) and ent_re not in tenderee_list:
                     n = 0
                     for i in range(len(ents) - 1, -1, -1):
                         if not ents[i].in_attachment:

+ 2 - 3
BiddingKG/dl/table_head/models/model.py

@@ -1,10 +1,9 @@
 import sys
 import os
 import numpy as np
+sys.path.append(os.path.abspath(os.path.dirname(__file__)))
 from keras.layers import Lambda, Dense, Reshape, Bidirectional, LSTM, Conv2D, BatchNormalization, LeakyReLU, Masking
 from keras.preprocessing.sequence import pad_sequences
-sys.path.append(os.path.dirname(__file__))
-
 from models.layer_utils import BatchReshape1, BatchReshape2, MyPadding, MySplit, BatchReshape3, \
     BatchReshape4, BatchReshape5, BatchReshape6
 from keras import layers, models, Sequential
@@ -70,7 +69,7 @@ def model_1(input_shape, output_shape):
     model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
                          outputs=output)
 
-    model.summary()
+    # model.summary()
     return model
 
 

+ 3 - 2
BiddingKG/dl/table_head/post_process.py

@@ -1,7 +1,8 @@
 
 
-def table_post_process(table_text_list, predict_result, threshold=0.5):
-    predict_result = predict_result.tolist()
+def table_post_process(table_text_list, predict_result, threshold=0.5, is_list=False):
+    if not is_list:
+        predict_result = predict_result.tolist()
     predict_list = []
     for i in range(0, len(predict_result)):
         predict = predict_result[i][0]

+ 35 - 14
BiddingKG/dl/table_head/pre_process.py

@@ -440,26 +440,47 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
                   {'output': Y}
 
     else:
-        while True:
-            new_data_list = []
-            for j in range(batch_size):
-                if i >= data_num:
-                    i = 0
-
-                # 中文字符映射为Embedding
-                data = data_list[i]
-                data = embedding_word(data, output_shape)
-                if data.shape == output_shape:
-                    new_data_list.append(data)
-                i += 1
+        new_data_list = []
+        for j in range(len(data_list)):
+            # 中文字符映射为Embedding
+            data = data_list[i]
+            data = embedding_word(data, output_shape)
+            if data.shape == output_shape:
+                new_data_list.append(data)
+            i += 1
 
-            new_data_list = np.array(new_data_list)
-            X = new_data_list
+        for j in range(0, len(data_list), batch_size):
+            sub_data_list = np.array(new_data_list[j: j+batch_size])
+            X = sub_data_list
             X = np.transpose(X, (1, 0, 2, 3))
+            # print(X)
+            # return X
             yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                    'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
 
 
+def my_data_loader_predict(data_list, data_label_list, batch_size):
+    data_num = len(data_list)
+
+    # 定义Embedding输出
+    output_shape = (6, 20, 60)
+
+    i = 0
+    new_data_list = []
+    for j in range(len(data_list)):
+        # 中文字符映射为Embedding
+        data = data_list[i]
+        data = embedding_word(data, output_shape)
+        if data.shape == output_shape:
+            new_data_list.append(data)
+        i += 1
+
+    sub_data_list = np.array(new_data_list)
+    X = sub_data_list
+    X = np.transpose(X, (1, 0, 2, 3))
+    return X
+
+
 def my_data_loader_2(table_list, table_label_list, batch_size, is_train=True):
     pad_len = 0
 

File diff suppressed because it is too large
+ 231 - 15
BiddingKG/dl/table_head/predict.py


Some files were not shown because too many files changed in this diff