Bläddra i källkod

修正表格补全,新增产品数量单价品牌规格提取

lishimin 3 år sedan
förälder
incheckning
c75afb337e

+ 2 - 0
BiddingKG/dl/interface/Preprocessing.py

@@ -100,6 +100,8 @@ def tableToText(soup):
                                         tds1[indtd - 1].insert_after(copy.copy(td))
                                     else:
                                         tds1[0].insert_before(copy.copy(td))
+                                elif indtd-2>0 and len(tds1) > 0 and len(tds1) == indtd - 1:  # 修正某些表格最后一列没补全
+                                    tds1[indtd-2].insert_after(copy.copy(td))
     def getTable(tbody):
         #trs = tbody.findChildren('tr', recursive=False)
         trs = getTrs(tbody)

+ 7 - 1
BiddingKG/dl/interface/extract.py

@@ -67,6 +67,11 @@ def predict(doc_id,text,title=""):
     log("get product done of doc_id%s"%(doc_id))
     cost_time["product"] = time.time()-start_time
 
+    start_time = time.time()
+    product_attrs = predictor.getPredictor("product_attrs").predict(doc_id, text)
+    log("get product attributes done of doc_id%s"%(doc_id))
+    cost_time["product_attrs"] = time.time()-start_time
+
     start_time = time.time()
     predictor.getPredictor("roleRule").predict(list_articles,list_sentences, list_entitys,codeName)
     cost_time["rule"] = time.time()-start_time
@@ -99,7 +104,8 @@ def predict(doc_id,text,title=""):
 
     #print(prem)
     # data_res = Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic)[0]
-    data_res = Preprocessing.union_result(Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic), list_channel_dic)[0]
+    # data_res = Preprocessing.union_result(Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic), list_channel_dic)[0]
+    data_res = dict(codeName[0], **prem[0], **list_channel_dic[0], **product_attrs[0])
     data_res["cost_time"] = cost_time
     data_res["success"] = True
 

BIN
BiddingKG/dl/interface/header_set.pkl


+ 266 - 0
BiddingKG/dl/interface/predictor.py

@@ -20,6 +20,8 @@ import tensorflow as tf
 from BiddingKG.dl.product.data_util import decode, process_data
 from BiddingKG.dl.interface.Entitys import Entity
 from BiddingKG.dl.complaint.punish_predictor import Punish_Extract
+from bs4 import BeautifulSoup
+import copy
 
 from threading import RLock
 dict_predictor = {"codeName":{"predictor":None,"Lock":RLock()},
@@ -30,6 +32,7 @@ dict_predictor = {"codeName":{"predictor":None,"Lock":RLock()},
                   "time":{"predictor":None,"Lock":RLock()},
                   "punish":{"predictor":None,"Lock":RLock()},
                   "product":{"predictor":None,"Lock":RLock()},
+                "product_attrs":{"predictor":None,"Lock":RLock()},
                   "channel": {"predictor": None, "Lock": RLock()}}
 
 
@@ -53,6 +56,8 @@ def getPredictor(_type):
                     dict_predictor[_type]["predictor"] = Punish_Extract()
                 if _type=="product":
                     dict_predictor[_type]["predictor"] = ProductPredictor()
+                if _type=="product_attrs":
+                    dict_predictor[_type]["predictor"] = ProductAttributesPredictor()
                 if _type == "channel":
                     dict_predictor[_type]["predictor"] = DocChannel()
             return dict_predictor[_type]["predictor"]
@@ -1495,6 +1500,267 @@ class ProductPredictor():
                     result.append(item) # 修正bug
                 return result
 
+# 产品数量单价品牌规格提取
+class ProductAttributesPredictor():
+    def __init__(self,):
+        self.p1 = '(设备|货物|商品|产品|物品|货品|材料|物资|物料|物件|耗材|备件|食材|食品|品目|标的|标的物|标项|资产|拍卖物|仪器|器材|器械|药械|药品|药材|采购品?|项目|招标|工程|服务)[\))]?(名称|内容|描述)'
+        self.p2 = '设备|货物|商品|产品|物品|货品|材料|物资|物料|物件|耗材|备件|食材|食品|品目|标的|标的物|资产|拍卖物|仪器|器材|器械|药械|药品|药材|采购品|项目|品名|菜名|内容|名称'
+        with open('E:\公告金额/header_set.pkl', 'rb') as f:
+            self.header_set = pickle.load(f)
+    def isTrueTable(self, table):
+        '''真假表格规则:
+        1、包含<caption>或<th>标签为真
+        2、包含大量链接、表单、图片或嵌套表格为假
+        3、表格尺寸太小为假
+        4、外层<table>嵌套子<table>,一般子为真,外为假'''
+        if table.find_all(['caption', 'th']) != []:
+            return True
+        elif len(table.find_all(['form', 'a', 'img'])) > 5:
+            return False
+        elif len(table.find_all(['tr'])) < 2:
+            return False
+        elif len(table.find_all(['table'])) >= 1:
+            return False
+        else:
+            return True
+
+    def getTrs(self, tbody):
+        # 获取所有的tr
+        trs = []
+        objs = tbody.find_all(recursive=False)
+        for obj in objs:
+            if obj.name == "tr":
+                trs.append(obj)
+            if obj.name == "tbody":
+                for tr in obj.find_all("tr", recursive=False):
+                    trs.append(tr)
+        return trs
+
+    def getTable(self, tbody):
+        trs = self.getTrs(tbody)
+        inner_table = []
+        if len(trs) < 2:
+            return inner_table
+        for tr in trs:
+            tr_line = []
+            tds = tr.findChildren(['td', 'th'], recursive=False)
+            if len(tds) < 3:
+                continue
+            for td in tds:
+                td_text = re.sub('\s', '', td.get_text())
+                tr_line.append(td_text)
+            inner_table.append(tr_line)
+        return inner_table
+
+    def fixSpan(self, tbody):
+        # 处理colspan, rowspan信息补全问题
+        trs = self.getTrs(tbody)
+        ths_len = 0
+        ths = list()
+        trs_set = set()
+        # 修改为先进行列补全再进行行补全,否则可能会出现表格解析混乱
+        # 遍历每一个tr
+
+        for indtr, tr in enumerate(trs):
+            ths_tmp = tr.findChildren('th', recursive=False)
+            # 不补全含有表格的tr
+            if len(tr.findChildren('table')) > 0:
+                continue
+            if len(ths_tmp) > 0:
+                ths_len = ths_len + len(ths_tmp)
+                for th in ths_tmp:
+                    ths.append(th)
+                trs_set.add(tr)
+            # 遍历每行中的element
+            tds = tr.findChildren(recursive=False)
+            if len(tds) < 3:
+                continue  # 列数太少的不补全
+            for indtd, td in enumerate(tds):
+                # 若有colspan 则补全同一行下一个位置
+                if 'colspan' in td.attrs and str(re.sub("[^0-9]", "", str(td['colspan']))) != "":
+                    col = int(re.sub("[^0-9]", "", str(td['colspan'])))
+                    if col < 10 and len(td.get_text()) < 500:
+                        td['colspan'] = 1
+                        for i in range(1, col, 1):
+                            td.insert_after(copy.copy(td))
+        for indtr, tr in enumerate(trs):
+            ths_tmp = tr.findChildren('th', recursive=False)
+            # 不补全含有表格的tr
+            if len(tr.findChildren('table')) > 0:
+                continue
+            if len(ths_tmp) > 0:
+                ths_len = ths_len + len(ths_tmp)
+                for th in ths_tmp:
+                    ths.append(th)
+                trs_set.add(tr)
+            # 遍历每行中的element
+            tds = tr.findChildren(recursive=False)
+            same_span = 0
+            if len(tds) > 1 and 'rowspan' in tds[0].attrs:
+                span0 = tds[0].attrs['rowspan']
+                for td in tds:
+                    if 'rowspan' in td.attrs and td.attrs['rowspan'] == span0:
+                        same_span += 1
+            if same_span == len(tds):
+                continue
+
+            for indtd, td in enumerate(tds):
+                # 若有rowspan 则补全下一行同样位置
+                if 'rowspan' in td.attrs and str(re.sub("[^0-9]", "", str(td['rowspan']))) != "":
+                    row = int(re.sub("[^0-9]", "", str(td['rowspan'])))
+                    td['rowspan'] = 1
+                    for i in range(1, row, 1):
+                        # 获取下一行的所有td, 在对应的位置插入
+                        if indtr + i < len(trs):
+                            tds1 = trs[indtr + i].findChildren(['td', 'th'], recursive=False)
+                            if len(tds1) >= (indtd) and len(tds1) > 0:
+                                if indtd > 0:
+                                    tds1[indtd - 1].insert_after(copy.copy(td))
+                                else:
+                                    tds1[0].insert_before(copy.copy(td))
+                            elif len(tds1) > 0 and len(tds1) == indtd - 1:
+                                tds1[indtd - 2].insert_after(copy.copy(td))
+
+    def find_header(self, items, p1, p2):
+        '''
+        inner_table 每行正则检查是否为表头,是则返回表头所在列序号,及表头内容
+        :param items: 列表,内容为每个td 文本内容
+        :param p1: 优先表头正则
+        :param p2: 第二表头正则
+        :return: 表头所在列序号,是否表头,表头内容
+        '''
+        flag = False
+        header_dic = {'名称': '', '数量': '', '单价': '', '品牌': '', '规格': ''}
+        product = ""  # 产品
+        quantity = ""  # 数量
+        unitPrice = ""  # 单价
+        brand = ""  # 品牌
+        specs = ""  # 规格
+        for i in range(min(4, len(items))):
+            it = items[i]
+            if len(it) < 15 and re.search(p1, it) != None:
+                flag = True
+                product = it
+                header_dic['名称'] = i
+                break
+        if not flag:
+            for i in range(min(4, len(items))):
+                it = items[i]
+                if len(it) < 15 and re.search(p2, it) and re.search(
+                        '编号|编码|号|情况|报名|单位|位置|地址|数量|单价|价格|金额|品牌|规格类型|型号|公司|中标人|企业|供应商|候选人', it) == None:
+                    flag = True
+                    product = it
+                    header_dic['名称'] = i
+                    break
+        if flag:
+            for j in range(i + 1, len(items)):
+                if len(items[j]) > 20 and len(re.sub('[\((].*[)\)]|[^\u4e00-\u9fa5]', '', items[j])) > 10:
+                    continue
+                if re.search('数量', items[j]):
+                    header_dic['数量'] = j
+                    quantity = items[j]
+                elif re.search('单价', items[j]):
+                    header_dic['单价'] = j
+                    unitPrice = items[j]
+                elif re.search('品牌', items[j]):
+                    header_dic['品牌'] = j
+                    brand = items[j]
+                elif re.search('规格', items[j]):
+                    header_dic['规格'] = j
+                    specs = items[j]
+            if header_dic.get('名称', "") != "" and (header_dic.get('数量', "") != "" or header_dic.get('单价', "") != ""
+                                                   or header_dic.get('品牌', "") != "" or header_dic.get('规格',
+                                                                                                       "") != ""):
+                return header_dic, flag, (product, quantity, unitPrice, brand, specs)
+
+        flag = False
+        return header_dic, flag, (product, quantity, unitPrice, brand, specs)
+
+    def predict(self, docid='', html=''):
+        '''
+        正则寻找table表格内 产品相关信息
+        :param html:公告HTML原文
+        :return:公告表格内 产品、数量、单价、品牌、规格 ,表头,表头列等信息
+        '''
+
+
+        soup = BeautifulSoup(html, 'lxml')
+        tables = soup.find_all(['table'])
+        headers = []
+        header_col = []
+        product_link = []
+        for table in tables:
+            if not self.isTrueTable(table):
+                continue
+            self.fixSpan(table)
+            inner_table = self.getTable(table)
+            i = 0
+            found_header = False
+            header_colnum = 0
+            while i < (len(inner_table)):
+                tds = inner_table[i]
+                not_empty = [it for it in tds if it != ""]
+                if len(set(not_empty)) < len(not_empty) * 0.5:
+                    i += 1
+                    continue
+                product = ""  # 产品
+                quantity = ""  # 数量
+                unitPrice = ""  # 单价
+                brand = ""  # 品牌
+                specs = ""  # 规格
+                if len(set(tds) & self.header_set) > len(tds) * 0.2:
+                    header_dic, found_header, header_list = self.find_header(tds, self.p1, self.p2)
+                    if found_header:
+                        headers.append('_'.join(header_list))
+                        header_colnum = len(tds)
+                        header_col.append('_'.join(tds))
+                    i += 1
+                    continue
+                elif found_header:
+                    if len(tds) != header_colnum:  # 表头、属性列数不一致跳过
+                        i += 1
+                        continue
+                    id1 = header_dic.get('名称', "")
+                    id2 = header_dic.get('数量', "")
+                    id3 = header_dic.get('单价', "")
+                    id4 = header_dic.get('品牌', "")
+                    id5 = header_dic.get('规格', "")
+                    if re.search('[a-zA-Z\u4e00-\u9fa5]', tds[id1]) and tds[id1] not in self.header_set and \
+                            re.search('备注|汇总|合计|总价|价格|金额|公司|附件|详见|无$|xxx', tds[id1]) == None:
+                        product = tds[id1]
+                        if id2 != "":
+                            if re.search('\d+|[壹贰叁肆伍陆柒捌玖拾一二三四五六七八九十]', tds[id2]):
+                                quantity = tds[id2]
+                            else:
+                                quantity = ""
+                        if id3 != "":
+                            if re.search('\d+|[零壹贰叁肆伍陆柒捌玖拾佰仟萬億十百千万亿元角分]{3,}', tds[id3]):
+                                unitPrice = tds[id3]
+                                if '万元' in header_list[2] and '万元' not in unitPrice:
+                                    unitPrice += '万元'
+                            else:
+                                unitPrice = ""
+                        if id4 != "":
+                            if re.search('\w', tds[id4]):
+                                brand = tds[id4]
+                            else:
+                                brand = ""
+                        if id5 != "":
+                            if re.search('\w', tds[id5]):
+                                specs = tds[id5]
+                            else:
+                                specs = ""
+                        if quantity != "" or unitPrice != "" or brand != "" or specs != "":
+                            # link = "{0}\t{1}\t{2}\t{3}\t{4}".format(product, quantity, unitPrice, brand, specs)
+                            link = {'product': product, 'quantity': quantity, 'unitPrice': unitPrice,
+                                                      'brand': brand[:50], 'speces': specs[:100]}
+                            if link not in product_link:
+                                product_link.append(link)
+                    i += 1
+                else:
+                    i += 1
+        return [{'product_attrs':{'data':product_link, 'header':headers, 'header_col':header_col}}]
+
 # docchannel类型提取
 class DocChannel():
   def __init__(self, life_model='/channel_savedmodel/channel.pb', type_model='/channel_savedmodel/doctype.pb'):