Explorar o código

优化地区提取,省或市提取不到的,补充标题、内容再次提取;

lsm %!s(int64=2) %!d(string=hai) anos
pai
achega
70bf0b8f14
Modificáronse 2 ficheiros con 98 adicións e 77 borrados
  1. 1 1
      BiddingKG/dl/interface/extract.py
  2. 97 76
      BiddingKG/dl/interface/predictor.py

+ 1 - 1
BiddingKG/dl/interface/extract.py

@@ -236,7 +236,7 @@ def predict(doc_id,text,title="",page_time="",web_source_no='',web_source_name="
 
     '''地区获取'''
     start_time = time.time()
-    district = predictor.getPredictor('district').predict(project_name=codeName[0]['name'], prem=prem, web_source_name=web_source_name)
+    district = predictor.getPredictor('district').predict(project_name=codeName[0]['name'], prem=prem,title=title, list_articles=list_articles, web_source_name=web_source_name)
     cost_time["district"] = round(time.time() - start_time, 2)
 
     '''限制行业最高金额'''

+ 97 - 76
BiddingKG/dl/interface/predictor.py

@@ -3925,7 +3925,16 @@ class DistrictPredictor():
             self.short2id = short2id
             self.full2id = full2id
 
-    def predict(self, project_name, prem, web_source_name = ""):
+    def predict(self, project_name, prem, title, list_articles, web_source_name = ""):
+        '''
+        先匹配 project_name+tenderee+tenderee_address, 如果缺少省或市 再匹配 title+content
+        :param project_name:
+        :param prem:
+        :param title:
+        :param list_articles:
+        :param web_source_name:
+        :return:
+        '''
         def get_ree_addr(prem):
             tenderee = ""
             tenderee_address = ""
@@ -3938,85 +3947,97 @@ class DistrictPredictor():
             except Exception as e:
                 print('解析prem 获取招标人、及地址出错')
             return tenderee, tenderee_address
-        tenderee, tenderee_address = get_ree_addr(prem)
-        project_name = str(project_name).replace(str(tenderee), '')
-        text = "{} {} {}".format(project_name, tenderee, tenderee_address)
-        web_source_name = str(web_source_name)  # 修复某些不是字符串类型造成报错
-        text = re.sub('复合肥|铁路|公路|新会计', ' ', text)  #预防提取错 合肥 路南 新会 等地区
-        score_l = []
-        id_set = set()
-
-        if re.search(self.short_name, text):
-            for it in re.finditer(self.full_name, text):
-                name = it.group(0)
-                score = len(name) / len(text)
-                for _id in self.full2id[name]:
-                    area = self.dist_dic[_id]['area'] + [''] * (3 - len(self.dist_dic[_id]['area']))
-                    # score_l.append([_id, score] + area)
-                    w = self.dist_dic[_id]['权重']
-                    score_l.append([_id, score+w]+ area)
-
-            flag = 0
-            for it in re.finditer(self.short_name, text):
-                if it.end() < len(text) and re.search('^(村|镇|街|路|江|河|湖|北路|南路|东路|大道|社区)', text[it.end():]) == None:
+        def get_area(text, web_source_name):
+            score_l = []
+            id_set = set()
+
+            if re.search(self.short_name, text):
+                for it in re.finditer(self.full_name, text):
                     name = it.group(0)
-                    score = (it.start() + len(name)) / len(text)
-                    for _id in self.short2id[name]:
-                        score2 = 0
+                    score = len(name) / len(text)
+                    for _id in self.full2id[name]:
+                        area = self.dist_dic[_id]['area'] + [''] * (3 - len(self.dist_dic[_id]['area']))
+                        # score_l.append([_id, score] + area)
                         w = self.dist_dic[_id]['权重']
-                        _type = self.dist_dic[_id]['类型']
+                        score_l.append([_id, score + w] + area)
+
+                flag = 0
+                for it in re.finditer(self.short_name, text):
+                    if it.end() < len(text) and re.search('^(村|镇|街|路|江|河|湖|北路|南路|东路|大道|社区)', text[it.end():]) == None:
+                        name = it.group(0)
+                        score = (it.start() + len(name)) / len(text)
+                        for _id in self.short2id[name]:
+                            score2 = 0
+                            w = self.dist_dic[_id]['权重']
+                            _type = self.dist_dic[_id]['类型']
+                            area = self.dist_dic[_id]['area'] + [''] * (3 - len(self.dist_dic[_id]['area']))
+                            if area[0] in ['2', '16', '20', '30']:
+                                _type += 10
+                            score2 += w
+                            if _id not in id_set:
+                                if _type == 20:
+                                    type_w = 3
+                                elif _type == 30:
+                                    type_w = 2
+                                else:
+                                    type_w = 1
+                                id_set.add(_id)
+                                score2 += w * type_w
+                            score_l.append([_id, score * w + score2] + area)
+
+                if flag == 1:
+                    pass
+                #         print('score', score)
+            if re.search('公司', web_source_name) == None:
+                for it in re.finditer(self.short_name, web_source_name):
+                    name = it.group(0)
+                    for _id in self.short2id[name]:
                         area = self.dist_dic[_id]['area'] + [''] * (3 - len(self.dist_dic[_id]['area']))
-                        if area[0] in ['2', '16', '20', '30']:
-                            _type += 10
-                        score2 += w
-                        if _id not in id_set:
-                            if _type == 20:
-                                type_w = 3
-                            elif _type == 30:
-                                type_w = 2
-                            else:
-                                type_w = 1
-                            id_set.add(_id)
-                            score2 += w * type_w
-                        score_l.append([_id, score * w + score2] + area)
-
-            if flag == 1:
-                pass
-            #         print('score', score)
-        if re.search('公司', web_source_name) == None:
-            for it in re.finditer(self.short_name, web_source_name):
-                name = it.group(0)
-                for _id in self.short2id[name]:
-                    area = self.dist_dic[_id]['area'] + [''] * (3 - len(self.dist_dic[_id]['area']))
-                    w = self.dist_dic[_id]['权重']
-                    score = w * 0.2
-                    score_l.append([_id, score] + area)
-        area_dic = {'area': '全国', 'province': '全国', 'city': '未知', 'district': '未知'}
-        if len(score_l) == 0:
-            return {'district':area_dic}
-        else:
-            df = pd.DataFrame(score_l, columns=['id', 'score', 'province', 'city', 'district'])
-            df_pro = df.groupby('province').sum().sort_values(by=['score'], ascending=False)
-            pro_id = df_pro.index[0]
-            # if df_pro.loc[pro_id, 'score'] < 0.1:  # 省级评分小于0.1的不要
-            #     print('评分低于0.1', df_pro.loc[pro_id, 'score'], self.dist_dic[pro_id]['地区'])
-            #     return area_dic
-            area_dic['province'] = self.dist_dic[pro_id]['地区']
-            area_dic['area'] = self.dist_dic[pro_id]['大区']
-            df = df[df['city'] != ""]
-            df = df[df['province'] == pro_id]
-            if len(df) > 0:
-                df_city = df.groupby('city').sum().sort_values(by=['score'], ascending=False)
-                city_id = df_city.index[0]
-                area_dic['city'] = self.dist_dic[city_id]['地区']
-                df = df[df['district'] != ""]
-                df = df[df['city'] == city_id]
+                        w = self.dist_dic[_id]['权重']
+                        score = w * 0.2
+                        score_l.append([_id, score] + area)
+            area_dic = {'area': '全国', 'province': '全国', 'city': '未知', 'district': '未知'}
+            if len(score_l) == 0:
+                return {'district': area_dic}
+            else:
+                df = pd.DataFrame(score_l, columns=['id', 'score', 'province', 'city', 'district'])
+                df_pro = df.groupby('province').sum().sort_values(by=['score'], ascending=False)
+                pro_id = df_pro.index[0]
+                # if df_pro.loc[pro_id, 'score'] < 0.1:  # 省级评分小于0.1的不要
+                #     print('评分低于0.1', df_pro.loc[pro_id, 'score'], self.dist_dic[pro_id]['地区'])
+                #     return area_dic
+                area_dic['province'] = self.dist_dic[pro_id]['地区']
+                area_dic['area'] = self.dist_dic[pro_id]['大区']
+                df = df[df['city'] != ""]
+                df = df[df['province'] == pro_id]
                 if len(df) > 0:
-                    df_dist = df.groupby('district').sum().sort_values(by=['score'], ascending=False)
-                    dist_id = df_dist.index[0]
-                    area_dic['district'] = self.dist_dic[dist_id]['地区']
-            # print(area_dic)
-            return {'district':area_dic}
+                    df_city = df.groupby('city').sum().sort_values(by=['score'], ascending=False)
+                    city_id = df_city.index[0]
+                    area_dic['city'] = self.dist_dic[city_id]['地区']
+                    df = df[df['district'] != ""]
+                    df = df[df['city'] == city_id]
+                    if len(df) > 0:
+                        df_dist = df.groupby('district').sum().sort_values(by=['score'], ascending=False)
+                        dist_id = df_dist.index[0]
+                        area_dic['district'] = self.dist_dic[dist_id]['地区']
+                # print(area_dic)
+                return {'district': area_dic}
+
+        tenderee, tenderee_address = get_ree_addr(prem)
+        project_name = str(project_name).replace(str(tenderee), '')
+        text1 = "{} {} {}".format(project_name, tenderee, tenderee_address)
+        web_source_name = str(web_source_name)  # 修复某些不是字符串类型造成报错
+        text1 = re.sub('复合肥|铁路|公路|新会计', ' ', text1)  #预防提取错 合肥 路南 新会 等地区
+        rs = get_area(text1, web_source_name)
+        if rs['district']['province'] == '全国' or rs['district']['city'] == '未知':
+            text2 = title + list_articles[0].content if len(list_articles[0].content)<2000 else title + list_articles[0].content[:1000] + list_articles[0].content[-1000:]
+            text2 = re.sub('复合肥|铁路|公路|新会计', ' ', text2)
+            rs2 = get_area(text2, web_source_name)
+            if rs['district']['province'] == '全国' and rs2['district']['province'] != '全国':
+                rs = rs2
+            elif rs['district']['province'] == rs2['district']['province'] and rs2['district']['city'] != '未知':
+                rs = rs2
+        return rs
 
 
 def getSavedModel():