
Adjust matching and search rules

luojiehua, 1 year ago
parent
commit
6260463df2

+ 39 - 5
BaseDataMaintenance/maintenance/product/productUtils.py

@@ -84,6 +84,7 @@ def get_embedding_search(coll,index_name,name,grade,vector,search_params,output_
                     _d[k] = _search.entity.get(k)
                 final_list.append(_d)
             final_list = remove_repeat_item(final_list,k="ots_name")
+            final_list.sort(key=lambda x:x.get("level",1))
             try:
                 db.set(_md5,json.dumps(final_list))
                 db.expire(_md5,2*60)
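The only functional change in the search path: after deduplication, hits are now ordered by the new `level` scalar (missing values default to 1), so level-1 dictionary names are considered before the demoted level-2 ones. A toy illustration of the ordering (the hit contents are made up):

```python
# Hypothetical hits: entries without a "level" default to 1, and the
# sort is stable, so level-1 names keep their relative order up front.
hits = [{"ots_name": "X射线摄影设备", "level": 2},
        {"ots_name": "DR"},                      # no level -> treated as 1
        {"ots_name": "摄影系统", "level": 1}]
hits.sort(key=lambda x: x.get("level", 1))
print([h["ots_name"] for h in hits])  # ['DR', '摄影系统', 'X射线摄影设备']
```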
@@ -244,10 +245,35 @@ def is_contain(source,target,min_len=2):
         return True
     return False
 
-
+def check_char(source,target,chat_pattern=re.compile("^[a-zA-Z0-9]+$"),find_pattern=re.compile("(?P<product>[a-zA-Z0-9]+)")):
+    if re.search(chat_pattern,source) is not None or re.search(chat_pattern,target) is not None:
+        a = set(re.findall(find_pattern,source))
+        b = set(re.findall(find_pattern,target))
+        if len(a&b)>0:
+            return True
+        else:
+            return False
 
 def check_product(source,target):
-    if is_contain(source,target,min_len=3):
+    _check = check_char(source,target)
+    if _check:
+        return True
+    else:
+        if _check==False:
+            return False
+
+    if is_contain(source,target,min_len=2):
+        return True
+    max_len = max(len(source),len(target))
+    min_len = min(len(source),len(target))
+    min_ratio = 92
+    if min_len<2:
+        return False
+    elif max_len<=5:
+        min_ratio=94
+    else:
+        min_ratio = 90
+    if is_similar(source,target,min_ratio):
         return True
     return False
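`check_char` introduces three-valued logic: when either string is a pure alphanumeric token it returns a definite verdict (True if the two sides share such a token, False otherwise); when the rule does not apply it implicitly returns None, and `check_product` falls through to the containment and fuzzy-ratio checks. A self-contained sketch of the combined rule, using simplified stand-ins for the repo's `is_contain` and Levenshtein-based `is_similar`, and renaming the commit's `chat_pattern` to `char_pattern` for readability:

```python
import re
from difflib import SequenceMatcher  # stand-in for the Levenshtein ratio

def check_char(source, target,
               char_pattern=re.compile("^[a-zA-Z0-9]+$"),
               find_pattern=re.compile("[a-zA-Z0-9]+")):
    # Tri-state: definite True/False when either side is purely
    # alphanumeric, None when this rule does not apply.
    if char_pattern.search(source) or char_pattern.search(target):
        return bool(set(find_pattern.findall(source)) &
                    set(find_pattern.findall(target)))
    return None

def is_similar(source, target, min_ratio):
    # Simplified stand-in, scaled to 0-100 like the repo's helper.
    return SequenceMatcher(None, source, target).ratio() * 100 >= min_ratio

def check_product(source, target):
    verdict = check_char(source, target)
    if verdict is not None:          # the commit's if/elif on True/False
        return verdict
    if source in target or target in source:  # simplified is_contain(min_len=2)
        return True
    if min(len(source), len(target)) < 2:
        return False
    min_ratio = 94 if max(len(source), len(target)) <= 5 else 90
    return is_similar(source, target, min_ratio)

# The shared token "DR" now matches the long form to its abbreviation,
# which is exactly the case exercised in __main__ below:
print(check_product("数字化医用X射线摄影系统(DR)", "DR"))  # True
```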
 
@@ -260,7 +286,7 @@ def check_brand(source,target):
     min_len = min(len(source),len(target))
 
     min_ratio = 92
-    if max_len<2:
+    if min_len<2:
         return False
     elif max_len<=5:
         min_ratio=94
@@ -269,7 +295,14 @@ def check_brand(source,target):
 
     source_c = "".join(get_chinese_string(source))
     target_c = "".join(get_chinese_string(target))
-    print(source_c,target_c)
+
+    _check = check_char(source,target)
+    if _check:
+        return True
+    else:
+        if _check==False:
+            return False
+
     if len(source_c)>=2 and len(target_c)>=2:
         if not(source_c in area_set or target_c in area_set):
             if is_similar(source_c,target_c,min_ratio):
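Two fixes land in `check_brand`: the length guard now keys off `min_len` instead of `max_len`, and a leftover debug `print` is dropped; the same `check_char` tri-state short-circuit is applied before the Chinese-substring similarity check. The guard change matters for short brands paired with long names:

```python
source, target = "海", "海尔集团"   # hypothetical one-character "brand"
max_len, min_len = max(len(source), len(target)), min(len(source), len(target))
print(max_len < 2)  # False -- the old guard would NOT reject this pair
print(min_len < 2)  # True  -- the new guard rejects it
```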
@@ -499,7 +532,8 @@ def clean_product_quantity(product_quantity):
     return ""
 
 if __name__ == '__main__':
-    print(check_brand('杭州郎基','杭州利华'))
+    # print(check_brand('杭州郎基','杭州利华'))
+    print(check_product("数字化医用X射线摄影系统(DR)","DR"))
     # print(re.split("[^\u4e00-\u9fff]",'128排RevolutionCTES彩色多普勒超声诊断仪VolusonE10'))
     # import Levenshtein
     # print(Levenshtein.ratio('助听器','助行器'))

+ 96 - 5
BaseDataMaintenance/maintenance/product/product_dict.py

@@ -65,7 +65,9 @@ class Product_Dict_Manager():
             FieldSchema(name="standard_name_id",dtype=DataType.VARCHAR,max_length=32),
             FieldSchema(name="embedding",dtype=DataType.FLOAT_VECTOR,dim=1024),
             FieldSchema(name="ots_parent_id",dtype=DataType.VARCHAR,max_length=32),
-            FieldSchema(name="ots_grade",dtype=DataType.INT64)
+            FieldSchema(name="ots_grade",dtype=DataType.INT64),
+            FieldSchema(name="remove_words",dtype=DataType.VARCHAR,max_length=3000),
+            FieldSchema(name="level",dtype=DataType.INT64),
         ]
 
         index_name = "embedding"
@@ -826,23 +828,29 @@ def clean_similar():
 
 
 
-def insert_new_record_to_milvus(Coll,name,grade,parent_id,standard_alias):
+def insert_new_record_to_milvus(Coll,name,grade,parent_id,standard_alias,remove_words="",level=1):
 
     n_name = get_milvus_standard_name(name)
     name_id = get_milvus_product_dict_id(n_name)
 
+
     vector = request_embedding(n_name)
 
     log("insert name %s grade %d"%(name,grade))
     if vector is not None and Coll is not None:
 
+        expr = " ots_id in ['%s']"%name_id
+        Coll.delete(expr)
         data = [[name_id],
                 [name],
                 [name],
                 [name_id],
                 [vector],
                 [parent_id],
-                [grade]]
+                [grade],
+                [remove_words],
+                [level]
+                ]
         insert_embedding(Coll,data)
 
         if standard_alias is not None and standard_alias!="":
@@ -854,6 +862,9 @@ def insert_new_record_to_milvus(Coll,name,grade,parent_id,standard_alias):
                 if _alias==name:
                     continue
                 _id = get_document_product_dict_standard_alias_id(_alias)
+
+                expr = " ots_id in ['%s']"%_id
+                Coll.delete(expr)
                 n_alias = get_milvus_standard_name(_alias)
                 vector = request_embedding(n_alias)
                 data = [[_id],
@@ -862,7 +873,10 @@ def insert_new_record_to_milvus(Coll,name,grade,parent_id,standard_alias):
                         [name_id],
                         [vector],
                         [parent_id],
-                        [grade]]
+                        [grade],
+                        [remove_words],
+                        [level]
+                        ]
                 insert_embedding(Coll,data)
         return True
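Together with the two new scalar fields, the delete-before-insert turns `insert_new_record_to_milvus` into an idempotent upsert: re-running an import no longer leaves duplicate vectors behind for the same `ots_id`, both for the main name and for each standard alias. A minimal pymilvus sketch of the pattern, with the field order taken from the schema fragment above and a hypothetical helper name:

```python
from pymilvus import Collection  # assumes an already-connected Milvus alias

def upsert_dict_record(coll: Collection, ots_id: str, ots_name: str,
                       standard_name: str, standard_name_id: str,
                       vector, parent_id: str, grade: int,
                       remove_words: str = "", level: int = 1):
    # Deleting a missing id is a no-op, so this is safe on first insert too.
    coll.delete(expr="ots_id in ['%s']" % ots_id)
    # Column-oriented insert in the declared field order.
    coll.insert([[ots_id], [ots_name], [standard_name], [standard_name_id],
                 [vector], [parent_id], [grade], [remove_words], [level]])
```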
 
@@ -916,8 +930,85 @@ def interface_deletes():
         print(s)
         dict_interface_delete(s,grade,ots_client)
 
+def clean_brands():
+    from queue import Queue as TQueue
+    task_queue = TQueue()
+    ots_client = getConnect_ots()
+
+    list_data = []
+
+    columns=[DOCUMENT_PRODUCT_DICT_NAME,DOCUMENT_PRODUCT_DICT_PARENT_ID,DOCUMENT_PRODUCT_DICT_GRADE]
+
+    bool_query = BoolQuery(must_queries=[
+        RangeQuery(DOCUMENT_PRODUCT_DICT_GRADE,4,4,True,True),
+    ])
+
+    rows,next_token,total_count,is_all_succeed = ots_client.search(Document_product_dict_table_name,Document_product_dict_table_name+"_index",
+                                                                   SearchQuery(bool_query,sort=Sort(sorters=[FieldSort(DOCUMENT_PRODUCT_DICT_IS_SYNCHONIZED)]),limit=100,get_total_count=True),
+                                                                   columns_to_get=ColumnsToGet(columns,ColumnReturnType.SPECIFIED))
+
+    list_dict = getRow_ots(rows)
+    for _d in list_dict:
+        list_data.append(_d)
+
+    while next_token:
+        rows,next_token,total_count,is_all_succeed = ots_client.search(Document_product_dict_table_name,Document_product_dict_table_name+"_index",
+                                                                       SearchQuery(bool_query,next_token=next_token,limit=100,get_total_count=True),
+                                                                       columns_to_get=ColumnsToGet(columns,ColumnReturnType.SPECIFIED))
+        list_dict = getRow_ots(rows)
+        for _d in list_dict:
+            list_data.append(_d)
+        # if len(list_data)>=1000:
+        #     break
+    log("product_dict embedding total_count:%d"%total_count)
+
+    set_key = set()
+    list_process_data = []
+    for _d in list_data:
+        name = _d.get(DOCUMENT_PRODUCT_DICT_NAME)
+        grade = _d.get(DOCUMENT_PRODUCT_DICT_GRADE)
+        _key = "%s-%d"%(name,grade)
+        if _key in set_key:
+            continue
+        set_key.add(_key)
+        task_queue.put(_d)
+        list_process_data.append(_d)
+    def _handle(item,result_queue):
+        name = item.get(DOCUMENT_PRODUCT_DICT_NAME)
+
+        if is_legal_brand(ots_client,name):
+            item["legal"] = 1
+        else:
+            bool_query = BoolQuery(must_queries=[
+                TermQuery("brand",name)
+            ])
+            rows,next_token,total_count,is_all_succeed = ots_client.search("document_product","document_product_index",
+                                                                           SearchQuery(bool_query,get_total_count=True))
+            if total_count>0:
+                item["legal"] = 1
+            else:
+                item["legal"] = 0
+    mt = MultiThreadHandler(task_queue,_handle,None,30)
+    mt.run()
+
+    list_legal = []
+    list_illegal = []
+    for _data in list_process_data:
+        name = _data.get(DOCUMENT_PRODUCT_DICT_NAME)
+        legal = _data["legal"]
+        if legal==1:
+            list_legal.append(name)
+        else:
+            list_illegal.append(name)
+    with open("../../test/legal_brand.txt", "w", encoding="utf8") as f:
+        for _name in list_legal:
+            f.write("%s\n"%(_name))
+    with open("../../test/illegal_brand.txt", "w", encoding="utf8") as f:
+        for _name in list_illegal:
+            f.write("%s\n"%(_name))
 
 if __name__ == '__main__':
     # start_embedding_product_dict()
     # interface_deletes()
-    clean_similar()
+    # clean_similar()
+    clean_brands()
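The new `clean_brands` pages through every grade-4 (brand) dictionary row, then checks each name on 30 worker threads: a brand counts as legal if `is_legal_brand` accepts it or a `TermQuery` finds it already in use on `document_product`; the verdicts are written out to `../../test/legal_brand.txt` and `illegal_brand.txt`. A condensed, runnable stand-in for the fan-out, using `concurrent.futures` instead of the repo's `MultiThreadHandler` and stubbed predicates:

```python
from concurrent.futures import ThreadPoolExecutor

def is_legal_brand_stub(name: str) -> bool:
    # Placeholder for is_legal_brand(ots_client, name).
    return len(name) >= 2

def occurs_in_products_stub(name: str) -> bool:
    # Placeholder for the TermQuery("brand", name) total_count > 0 probe.
    return name in {"海尔", "西门子"}

def classify(name: str) -> bool:
    return is_legal_brand_stub(name) or occurs_in_products_stub(name)

names = ["海尔", "西门子", "测", "某某"]
with ThreadPoolExecutor(max_workers=30) as pool:
    verdicts = list(pool.map(classify, names))

legal = [n for n, ok in zip(names, verdicts) if ok]
illegal = [n for n, ok in zip(names, verdicts) if not ok]
print(legal)    # ['海尔', '西门子', '某某']
print(illegal)  # ['测']
```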

+ 26 - 42
BaseDataMaintenance/maintenance/product/products.py

@@ -24,7 +24,7 @@ from BaseDataMaintenance.maintenance.product.product_dict import Product_Dict_Ma
 from apscheduler.schedulers.blocking import BlockingScheduler
 
 from BaseDataMaintenance.maintenance.product.make_brand_pattern import *
-from BaseDataMaintenance.maintenance.product.product_dict import IS_SYNCHONIZED
+from BaseDataMaintenance.maintenance.product.product_dict import *
 import logging
 
 root = logging.getLogger()
@@ -981,10 +981,10 @@ def test_check_brand():
         else:
             brand = _d.get("brand")
             list_illegal_brand.append(brand)
-    with open("legal_brand.txt","w",encoding="utf8") as f:
+    with open("../../test/legal_brand.txt", "w", encoding="utf8") as f:
         for b in list_legal_brand:
             f.write(b+"\n")
-    with open("illegal_brand.txt","w",encoding="utf8") as f:
+    with open("../../test/illegal_brand.txt", "w", encoding="utf8") as f:
         for b in list_illegal_brand:
             f.write(b+"\n")
 
@@ -1017,6 +1017,16 @@ def test_match():
     start_time = time.time()
     # final_list = get_embedding_search(Coll,embedding_index_name,a,_GRADE,vector,pm.search_params,output_fields,limit=5)
     final_list = get_intellect_search(Coll,embedding_index_name,a,_GRADE,pm.search_params,output_fields,limit=10)
+    for _search in final_list:
+        ots_id = _search.get("standard_name_id")
+        ots_name = _search.get("ots_name")
+        standard_name = _search.get("standard_name")
+        ots_parent_id = _search.get("ots_parent_id")
+        if is_similar(a,ots_name) or check_product(a,ots_name):
+            print("similar",a,ots_name)
+        else:
+            print("not similar",a,ots_name)
+
     print("cost",time.time()-start_time)
     print(final_list)
 
@@ -1057,7 +1067,6 @@ def rebuild_milvus():
 
     log("rebuild milvus %d counts"%(task_queue.qsize()))
     def insert_into_milvus(item,result_queue):
-
         name = item.get(DOCUMENT_PRODUCT_DICT_NAME,"")
         grade = item.get(DOCUMENT_PRODUCT_DICT_GRADE)
 
@@ -1065,49 +1074,24 @@ def rebuild_milvus():
             name = clean_product_specs(name)
             if len(name)<2:
                 return
+        if len(name)<2:
+            return
 
-        n_name = get_milvus_standard_name(name)
-        name_id = get_milvus_product_dict_id(n_name)
 
-        vector = request_embedding(n_name)
         parent_id = item.get(DOCUMENT_PRODUCT_DICT_PARENT_ID,"")
 
         Coll,_ = pdm.get_collection(grade)
         standard_alias = item.get(DOCUMENT_PRODUCT_DICT_STANDARD_ALIAS,"")
 
-
-
         log("insert name %s grade %d"%(name,grade))
-        if vector is not None and Coll is not None:
-
-            data = [[name_id],
-                    [name],
-                    [name],
-                    [name_id],
-                    [vector],
-                    [parent_id],
-                    [grade]]
-            insert_embedding(Coll,data)
-
-            if standard_alias is not None and standard_alias!="":
-                list_alias = standard_alias.split(DOCUMENT_PRODUCT_DICT_STANDARD_ALIAS_SEPARATOR)
-                for _alias in list_alias:
-                    _alias = _alias.strip()
-                    if len(_alias)==0:
-                        continue
-                    if _alias==name:
-                        continue
-                    _id = get_document_product_dict_standard_alias_id(_alias)
-                    n_alias = get_milvus_standard_name(_alias)
-                    vector = request_embedding(n_alias)
-                    data = [[_id],
-                            [_alias],
-                            [name],
-                            [name_id],
-                            [vector],
-                            [parent_id],
-                            [grade]]
-                    insert_embedding(Coll,data)
+        remove_words = item.get(DOCUMENT_PRODUCT_DICT_REMOVE_WORDS,"")
+        level = item.get(DOCUMENT_PRODUCT_DICT_LEVEL)
+        if level is None:
+            if re.search("装置|设备",name) is not None:
+                level = 2
+            else:
+                level = 1
+        insert_new_record_to_milvus(Coll,name,grade,parent_id,standard_alias,remove_words,level)
 
     def start_thread():
         mt = MultiThreadHandler(task_queue,insert_into_milvus,None,5)
@@ -1159,7 +1143,7 @@ def move_document_product():
 
 current_path = os.path.dirname(__file__)
 def delete_brands():
-    filename = os.path.join(current_path,"search_similar2_1.xlsx_brand_move.txt")
+    filename = os.path.join(current_path,"illegal_brand.txt")
 
     ots_client = getConnect_ots()
     list_brand = []
@@ -1280,8 +1264,8 @@ def test():
     # pm.test()
     # fix_product_data()
     # test_check_brand()
-    test_match()
-    # rebuild_milvus()
+    # test_match()
+    rebuild_milvus()
 
     # move_document_product()
     # delete_brands()
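`rebuild_milvus` now delegates row construction to `insert_new_record_to_milvus` and, when a dictionary row carries no explicit `level`, derives one from the name: generic apparatus/equipment terms (装置/设备) are demoted to level 2, everything else defaults to 1, which is the order `get_embedding_search` sorts by after this commit. The heuristic in isolation:

```python
import re

def default_level(name: str) -> int:
    # From rebuild_milvus: names containing 装置 (apparatus) or 设备
    # (equipment) sort after the more specific level-1 names in search.
    return 2 if re.search("装置|设备", name) else 1

print(default_level("X射线摄影装置"))  # 2
print(default_level("DR"))             # 1
```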

+ 3 - 0
BaseDataMaintenance/model/ots/document_product_dict.py

@@ -13,6 +13,9 @@ DOCUMENT_PRODUCT_DICT_IS_SYNCHONIZED = "is_synchonized"
 
 DOCUMENT_PRODUCT_DICT_STANDARD_ALIAS = "standard_alias"
 
+DOCUMENT_PRODUCT_DICT_REMOVE_WORDS = "remove_words"
+DOCUMENT_PRODUCT_DICT_LEVEL = "level"
+
 DOCUMENT_PRODUCT_DICT_STANDARD_ALIAS_SEPARATOR = "|"