12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188 |
- import base64
- import copy
- import json
- import logging
- import math
- import random
- import re
- import traceback
- from glob import glob
- import cv2
- from sklearn.cluster import AffinityPropagation, DBSCAN
- # from tensorflow_version.table_head_predict import predict
- from botr.utils import request_post, line_iou, pil_resize, get_best_predict_size2, line_overlap
- import jieba
- import numpy as np
- from matplotlib import pyplot as plt
def _plot(_line_list, mode=1):
    """Draw every segment in ``_line_list`` with matplotlib and show the figure.

    Args:
        _line_list: Iterable of segments; the item layout depends on ``mode``.
        mode: How to read the end points of each item.
            1 - the item is an object whose ``__dict__`` has a ``"bbox"``
                entry holding ``x0, y0, x1, y1``.
            2 - the item itself is a flat ``x0, y0, x1, y1`` sequence.
            3 - the item is a pair of points: ``item[0]`` is ``(x0, y0)``
                and ``item[1]`` is ``(x1, y1)``.

    Raises:
        ValueError: If ``mode`` is not 1, 2 or 3 and there is at least one
            item to draw. Previously the coordinates were simply left
            unbound, which surfaced as an opaque ``UnboundLocalError``.
    """
    for _line in _line_list:
        if mode == 1:
            x0, y0, x1, y1 = _line.__dict__.get("bbox")
        elif mode == 2:
            x0, y0, x1, y1 = _line
        elif mode == 3:
            x0, y0 = _line[0]
            x1, y1 = _line[1]
        else:
            raise ValueError("unsupported mode %r, expected 1, 2 or 3" % (mode,))
        plt.plot([x0, x1], [y0, y1])
    plt.show()
    return
def get_table_by_rule2(img, text_list, bbox_list, table_location, is_test=0):
    """Derive table lines and cells for one table region from OCR boxes.

    Args:
        img: Image the boxes were detected on; passed through to the helper
            functions and drawn on when ``is_test`` is truthy.
        text_list: Recognised text, ``text_list[i]`` belonging to
            ``bbox_list[i]``.
        bbox_list: OCR boxes. Presumably four corner points each with
            ``bbox[0]`` top-left and ``bbox[2]`` bottom-right (that is how
            they are indexed here) - TODO confirm against the OCR output.
        table_location: Four values indexed ``[0]..[3]``; ``[1]`` and ``[3]``
            are replaced by box y-coordinates when an outer row is trimmed,
            so presumably ``[x0, y0, x1, y1]``.
        is_test: When truthy, intermediate results are plotted / shown in
            blocking GUI windows.

    Returns:
        ``(line_list, cell_list, table_location)`` where ``line_list`` holds
        ``[x0, y0, x1, y1]`` for every row and column line of the first
        area, ``cell_list`` is the first area's cell list and
        ``table_location`` is the (possibly trimmed) location. Three empty
        lists are returned when no rows or no lines could be found.
    """
    # Without any box there is nothing to lay out.
    if len(bbox_list) == 0:
        return [], [], []

    # Pre-process the boxes: shrink them.
    bbox_list = shrink_bbox(img, bbox_list)

    # Map the string form of every bbox to its text.
    bbox_text_dict = {str(bbox_list[i]): text for i, text in enumerate(text_list)}

    # All boxes arranged by row.
    row_list = get_table_rows(bbox_list, bbox_text_dict)
    if len(row_list) == 0:
        return [], [], []

    # Drop a first / last row that consists of a single bbox and move the
    # table border to the far edge of that bbox.
    if len(row_list[0]) == 1:
        table_location = [table_location[0], row_list[0][0][2][1],
                          table_location[2], table_location[3]]
        row_list = row_list[1:]
        # The only row was just removed: indexing row_list[-1] below would
        # raise IndexError, so report "no table" instead.
        if len(row_list) == 0:
            return [], [], []
    if len(row_list[-1]) == 1:
        table_location = [table_location[0], table_location[1],
                          table_location[2], row_list[-1][0][0][1]]
        row_list = row_list[:-1]

    # Table area and the row-wise boxes inside it (a single area here).
    table_location_list = [[[int(table_location[0]), int(table_location[1])],
                            [int(table_location[2]), int(table_location[3])]]]
    area_row_list = [row_list]
    area_row_list = merge_row_bbox_list(area_row_list)

    # All boxes arranged by column.
    area_col_list = get_table_cols(bbox_list, table_location_list)

    # Row lines and column lines.
    area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)

    if is_test:
        _plot(area_row_lines[0] + area_col_lines[0], mode=3)

    # Validate the column lines.
    area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)

    # Validate the row lines.
    area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)

    if is_test:
        _plot(area_row_lines[0] + area_col_lines[0], mode=3)

    # Turn the lines into boxes / cells arranged by row and column.
    area_table_bbox_list, area_table_cell_list = get_table_bbox_list(
        img, area_row_lines, area_col_lines, table_location_list, bbox_list)

    if is_test:
        # Debug view: outline every cell on the image (this draws on img).
        for a in area_table_cell_list:
            for r in a:
                for c in r:
                    cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1)
        cv2.imshow('table_cell', img)
        cv2.waitKey(0)

    # Show the overall result.
    if is_test:
        show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list)

    if not area_row_lines or not area_col_lines:
        return [], [], []

    # Flatten each ((x0, y0), (x1, y1)) line into [x0, y0, x1, y1].
    line_list = [[x[0][0], x[0][1], x[1][0], x[1][1]] for x in area_row_lines[0] + area_col_lines[0]]
    cell_list = area_table_cell_list[0]

    return line_list, cell_list, table_location
def _spans_overlap(a_start, a_end, b_start, b_end):
    """Return True when an end point of one closed span lies inside the other."""
    return (a_start <= b_start <= a_end
            or a_start <= b_end <= a_end
            or b_start <= a_start <= b_end
            or b_start <= a_end <= b_end)


def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=1):
    """Experimental rule-based row / column line estimation for one table.

    Starting from the bbox closest to the table's top-left corner, the
    function builds the first row and first column, derives candidate column
    and row lines from them, moves column lines off dark pixels, decides for
    boxes lying between two rows which row they belong to, and finally adds
    the table outline.

    Args:
        img: Image array; sliced as ``img[y0:y1, x0:x1, :]`` and drawn on
            with ``cv2.line`` when ``is_test`` is truthy.
        text_list: Recognised text, ``text_list[i]`` belonging to
            ``bbox_list[i]``.
        bbox_list: OCR boxes, indexed here as ``bbox[0] = (x0, y0)`` and
            ``bbox[2] = (x1, y1)`` - presumably top-left / bottom-right
            corners, TODO confirm.
        table_location: ``[x0, y0, x1, y1]`` of the table, as used for the
            outline lines.
        is_test: When truthy (the default) the lines are drawn onto ``img``
            and shown in a blocking OpenCV window.

    Returns:
        Always three empty lists; the computed lines are only visualised.
        The progress ``print`` output is kept as-is.
    """
    # No boxes: nothing to anchor on (bbox_list[0] used to raise IndexError).
    if len(bbox_list) == 0:
        return [], [], []

    # Pre-process the boxes: shrink them.
    bbox_list = shrink_bbox(img, bbox_list)
    if len(bbox_list) == 0:
        return [], [], []

    # Map the string form of every bbox to its text.
    bbox_text_dict = {str(bbox_list[i]): text for i, text in enumerate(text_list)}

    # Lock onto the bbox whose first corner is closest (Manhattan distance)
    # to the table's top-left corner.
    table_left_up_point = [table_location[0], table_location[1]]
    first_bbox = min(
        bbox_list,
        key=lambda b: abs(b[0][0] - table_left_up_point[0]) + abs(b[0][1] - table_left_up_point[1]),
    )

    # Pre-process first_bbox: try to split it.
    new_bbox_list, bbox_text_dict = split_bbox(img, first_bbox, bbox_text_dict)
    if new_bbox_list:
        if first_bbox in bbox_list:
            bbox_list.remove(first_bbox)
        bbox_list += new_bbox_list
        new_bbox_list.sort(key=lambda x: (x[0][0]))
        first_bbox = new_bbox_list[0]

    # First row: boxes that overlap first_bbox vertically or end above it.
    first_row = []
    bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
    for bbox in bbox_list:
        if _spans_overlap(first_bbox[0][1], first_bbox[2][1], bbox[0][1], bbox[2][1]) \
                or bbox[2][1] <= first_bbox[0][1]:
            first_row.append(bbox)

    # Split the first row into columns by horizontal overlap.
    first_row.sort(key=lambda x: (x[0][0], x[0][1]))
    first_row_col = []
    used_bbox = []
    for bbox in first_row:
        if bbox in used_bbox:
            continue
        temp_col = []
        for bbox1 in first_row:
            if bbox1 in used_bbox:
                continue
            if _spans_overlap(bbox1[0][0], bbox1[2][0], bbox[0][0], bbox[2][0]):
                temp_col.append(bbox1)
                used_bbox.append(bbox1)
        first_row_col.append(temp_col)

    # First column: boxes that overlap first_bbox horizontally or end left of it.
    first_col = []
    bbox_list.sort(key=lambda x: (x[0][0], x[0][1]))
    for bbox in bbox_list:
        if _spans_overlap(first_bbox[0][0], first_bbox[2][0], bbox[0][0], bbox[2][0]) \
                or bbox[2][0] <= first_bbox[0][0]:
            first_col.append(bbox)

    # Split the first column into rows by vertical overlap with the box that
    # opened the current row.
    first_col.sort(key=lambda x: (x[0][1], x[0][0]))
    first_col_row = []
    current_bbox = first_col[0]
    temp_row = []
    for bbox in first_col:
        if _spans_overlap(current_bbox[0][1], current_bbox[2][1], bbox[0][1], bbox[2][1]):
            temp_row.append(bbox)
        else:
            if temp_row:
                temp_row.sort(key=lambda x: x[0][1])
                first_col_row.append(temp_row)
            temp_row = [bbox]
            current_bbox = bbox
    if temp_row:
        temp_row.sort(key=lambda x: x[0][1])
        first_col_row.append(temp_row)

    print('len(first_row)', len(first_row))
    print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
    print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
    print('len(first_col)', len(first_col))
    print('len(first_row_col)', len(first_row_col))
    print('len(first_col_row)', len(first_col_row))

    # Column lines: two per column, at its left-most and right-most x.
    col_line_list = []
    for col in first_row_col:
        min_w, max_w = 1000000, 0
        print('col', [bbox_text_dict.get(str(x)) for x in col])
        for bbox in col:
            if bbox[0][0] < min_w:
                min_w = bbox[0][0]
            if bbox[2][0] > max_w:
                max_w = bbox[2][0]
        col_line_list.append([min_w, table_location[1], min_w, table_location[3]])
        col_line_list.append([max_w, table_location[1], max_w, table_location[3]])

    # Row lines: two per row, at its top-most and bottom-most y.
    row_line_list = []
    for row in first_col_row:
        min_h, max_h = 1000000, 0
        for bbox in row:
            if bbox[0][1] < min_h:
                min_h = bbox[0][1]
            if bbox[2][1] > max_h:
                max_h = bbox[2][1]
        row_line_list.append([table_location[0], min_h, table_location[2], min_h])
        row_line_list.append([table_location[0], max_h, table_location[2], max_h])

    print('len(col_line_list)', len(col_line_list))
    print('col_line_list', col_line_list)
    print('len(row_line_list)', len(row_line_list))

    # Between two neighbouring columns keep one line that does not run over
    # dark pixels; if neither candidate qualifies, search from right to left.
    max_black_cnt = 10
    black_threshold = 150
    temp_list = []
    for i in range(1, len(col_line_list), 2):
        # Right border of the previous column.
        line1 = [int(x) for x in col_line_list[i]]

        # Left border of the next column.
        if i + 1 >= len(col_line_list):
            break
        line2 = [int(x) for x in col_line_list[i + 1]]

        black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold)
        print('col black_cnt2', black_cnt2)
        if black_cnt2 <= max_black_cnt:
            temp_list.append(line2)
        else:
            black_cnt1 = count_black(img[line1[1]:line1[3], line1[0]:line1[2]+1, :], threshold=black_threshold)
            print('col black_cnt1', black_cnt1)
            if black_cnt1 <= max_black_cnt:
                temp_list.append(line1)
            else:
                # Neither border is clean: walk from the right one towards the left one.
                for j in range(line2[0], line1[0], -1):
                    black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
                    print('col black_cnt', black_cnt)
                    if black_cnt <= max_black_cnt:
                        temp_list.append([j, line2[1], j, line2[3]])
                        break
    col_line_list = temp_list

    # Assign boxes to columns using the column lines. The closing sentinel
    # is the image's right edge, i.e. its width img.shape[1] (an image array
    # is indexed rows first; the height img.shape[0] was used here before).
    last_line = [0, 0, 0, 0]
    col_bbox_list = []
    for line in col_line_list + [[img.shape[1], 0, img.shape[1], 0]]:
        col = []
        for bbox in bbox_list:
            iou = line_iou([[last_line[0], 0], [line[0], 0]], [[bbox[0][0], 0], [bbox[2][0], 0]], axis=0)
            if iou >= 0.6:
                col.append(bbox)
        col.sort(key=lambda x: x[0][1])
        col_bbox_list.append(col)
        last_line = line

    # Row lines: place one line between each pair of neighbouring rows.
    threshold = 5
    ratio = 1.5
    temp_list = []
    for i in range(1, len(row_line_list), 2):
        # Bottom border of the previous row.
        line1 = [int(x) for x in row_line_list[i]]

        # Top border of the next row.
        if i + 1 >= len(row_line_list):
            break
        line2 = [int(x) for x in row_line_list[i + 1]]

        # Boxes lying between the two borders (with a small tolerance).
        sub_bbox_list = []
        for bbox in bbox_list:
            if line1[1] - threshold <= bbox[0][1] <= bbox[2][1] <= line2[1]+threshold:
                sub_bbox_list.append(bbox)

        # Decide from the vertical distances inside its column whether each
        # such bbox belongs to the upper or the lower row.
        line1_bbox_list = []
        line2_bbox_list = []
        if sub_bbox_list:
            sub_bbox_list.sort(key=lambda x: x[0][1])
            min_h = sub_bbox_list[0][0][1] - 1
            max_h = sub_bbox_list[-1][2][1] + 1
            for bbox in sub_bbox_list:
                # Find the column the bbox belongs to.
                current_col = None
                for col in col_bbox_list:
                    if bbox in col:
                        current_col = copy.deepcopy(col)
                        break

                if current_col:
                    # Insert the two borders as zero-height pseudo boxes so
                    # they can be ordered together with the real boxes.
                    line1_bbox = [[0, min_h], [], [0, min_h], []]
                    line2_bbox = [[0, max_h], [], [0, max_h], []]
                    current_col += [line1_bbox, line2_bbox]
                    current_col.sort(key=lambda x: x[0][1])

                    bbox_index = current_col.index(bbox)
                    line1_bbox_index = current_col.index(line1_bbox)
                    line2_bbox_index = current_col.index(line2_bbox)
                    print('current_col', [bbox_text_dict.get(str(x)) for x in current_col])
                    print('line1_bbox_index, bbox_index, line2_bbox_index', line1_bbox_index, bbox_index, line2_bbox_index)

                    # Centre-to-centre distances; each loop keeps only the
                    # value of its last iteration.
                    distance1 = 10000
                    for index in range(line1_bbox_index, bbox_index):
                        h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
                        h2 = (current_col[index+1][0][1] + current_col[index+1][2][1]) / 2
                        distance1 = abs(h1 - h2)
                    distance2 = 10000
                    for index in range(line2_bbox_index, bbox_index, -1):
                        h1 = (current_col[index][0][1] + current_col[index][2][1]) / 2
                        h2 = (current_col[index-1][0][1] + current_col[index-1][2][1]) / 2
                        distance2 = abs(h1 - h2)

                    print(bbox_text_dict.get(str(bbox)), distance1, distance2)
                    if distance1 >= distance2 * ratio or distance1 >= distance2 + 8:
                        # Belongs to the lower row.
                        line2_bbox_list.append(bbox)
                    elif distance2 >= distance1 * ratio or distance2 >= distance1 + 8:
                        # Belongs to the upper row.
                        line1_bbox_list.append(bbox)
                    else:
                        print('距离不明确,需要nsp模型介入判断')

        # Move each border to the outermost bbox that was assigned to its row.
        if line1_bbox_list:
            print('line1_bbox_list', [bbox_text_dict.get(str(x)) for x in line1_bbox_list])
            line1_bbox_list.sort(key=lambda x: x[0][1])
            b = line1_bbox_list[-1]
            line1 = [line1[0], b[2][1], line1[2], b[2][1]]
        if line2_bbox_list:
            print('line2_bbox_list', [bbox_text_dict.get(str(x)) for x in line2_bbox_list])
            line2_bbox_list.sort(key=lambda x: x[0][1])
            b = line2_bbox_list[0]
            line2 = [line2[0], b[0][1], line2[2], b[0][1]]

        # The separating line runs midway between the two borders.
        _line = [line1[0], (line1[1]+line2[1])/2, line1[2], (line1[3]+line2[3])/2]
        _line = [int(x) for x in _line]
        temp_list.append(_line)
    row_line_list = temp_list

    # Add the table outline.
    row_line_list.append([table_location[0], table_location[1], table_location[2], table_location[1]])
    row_line_list.append([table_location[0], table_location[3], table_location[2], table_location[3]])
    col_line_list.append([table_location[0], table_location[1], table_location[0], table_location[3]])
    col_line_list.append([table_location[2], table_location[1], table_location[2], table_location[3]])

    # Turn the lines into boxes arranged by row and column (result unused here).
    area_table_bbox_list, area_table_cell_list = get_table_bbox_list(
        img, [row_line_list], [col_line_list], [table_location], bbox_list)

    # Show the lines: column lines and row lines in different colours.
    if is_test:
        for line in col_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
        for line in row_line_list:
            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL)
        cv2.imshow('img', cv2.resize(img, (768, 1024)))
        cv2.waitKey(0)

    return [], [], []
def split_bbox_by_kmeans(img, bbox, bbox_text_dict):
    """Crop the image region covered by ``bbox``; the split itself is unfinished.

    Only the crop is computed and nothing is returned (``None``).
    ``bbox_text_dict`` is not used yet.
    """
    top, bottom = int(bbox[0][1]), int(bbox[2][1])
    left, right = int(bbox[0][0]), int(bbox[2][0])
    sub_img = img[top:bottom, left:right, :]

    # TODO: scan from left to right
- def get_table():
- # 1. 一个单元格多行合并需解决 √
- # 2. 一行多个单字合并 1007.jpg √
- # 3. ocr识别错误bbox剔除
- # 4. 上下表格合并 距离近,列数一样,或只少了第一列 1005.jpg 1014.jpg 1033.jpg √
- # 5. 相近行列线合并 1020.jpg 1025.jpg 1054.jpg 1068.jpg
- # 6. 行线在合并bbox中间,需向上或向下移动 105.jpg 1054.jpg 1020.jpg
- # 7. 贴着左边框的长bbox也当做标题分开表格 1047.jpg 1059.jpg √
- # 8. 判断非规整表格,单个单元格多个bbox,排除上下连接的bbox 105.jpg
- # 9. 判断非规整表格,ocr识别漏,黑色像素多 1050.jpg √
- # 10. 第一列序号ocr识别漏 1051.jpg
- # 11. 用其他列作为分行标准,作为辅助,挑平均间隔最大的,行数也够的列 1085.jpg
- # 12. 判断表格 两个bbox靠的太近的不能作为开始行 1106.jpg √
- # 13. 列中所有行间隔都很小,聚类距离统一值 1098.jpg √
- # 14. 漏列(需剔除表格中非表格部分) 1059.jpg
- # 15. 漏行 1064.jpg 1065.jpg 1067.jpg 1085.jpg 1097.jpg 1101.jpg √
- # 16. 表格分割错误 1045.jpg 1051.jpg 1078.jpg 1079.jpg √
- # 17. 分列时,第一行的表头选定 1051.jpg 1106.jpg 1129.jpg
- # 18. 分割同一行中多个列 1093.jpg 1095.jpg 110.jpg
- # 19. 表格漏了 1119.jpg 1141.jpg
- # 20. 非规整表格判断错误,黑色像素 1122.jpg 1121.jpg √
- # 21. 分列错误 1125.jpg 1158.jpg 1020.jpg √
- # 22. 分行分列错误(需在第一列排除过长bbox) 1131.jpg 1132.jpg √
- # 1135.jpg 1136.jpg 1147.jpg
- # 23. 表格范围外,与单元格内的文字上下相连 1134.jpg 1142.jpg
- # 24. 第一列空单元格太多可列为非规整
- # 25. 竖线跨越多个bbox的较中心位置,考虑剔除
- # 26. 竖线跨越bbox,考虑竖线缩短,将跨越的那一截去掉 1020.jpg
- # 27. 竖线插在一列中间,需调整其向右找到空白位置 1023.jpg
- # label_path = glob('../data/borderless_tables/*_label.jpg')
- # temp_label_path = []
- # label_row_dict = {}
- # for p in label_path:
- # img = cv2.imread(p)
- # row_img, col_img = get_lines_from_img(img)
- # label_row_list, is_standard = get_bbox_by_img(row_img, col_img)
- # label_row_dict[p] = label_row_list
- # if is_standard:
- # temp_label_path.append(p)
- # label_path = temp_label_path
- # print('len(label_path)', len(label_path))
- # for p in label_path:
- # print(p)
- with open('standard_table.txt', 'r') as f:
- label_path_list = f.readlines()
- # paths = glob('../data/borderless_tables/1.jpg') # merge_row
- # paths = glob('../data/borderless_tables/5.jpg') # title
- # paths = glob('../data/borderless_tables/26.jpg') # merge_col
- paths = glob('../data/borderless_tables/59.jpg') # split bbox
- paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/62.jpg')
- # paths = glob('../data/borderless_tables/57.jpg')
- paths = glob('../../hrnet-pytorch-main/my_dataset/borderless_tables/3.jpg') # not standard table
- # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png')
- # label_path_list.append(r'C:\Users\Administrator\Desktop\test_pdf_table\1_label.jpg\n')
- paths = glob('../data/borderless_tables/*.jpg')
- # paths = glob('../data/standard_tables/*.jpg')
- path_cnt = 0
- all_teds = 0
- all_standard_cnt = 0
- for p in paths:
- if 'label' in p:
- continue
- label_p = p[:-4] + '_label.jpg\n'
- if label_p not in label_path_list:
- continue
- # if path_cnt <= 10:
- # path_cnt += 1
- # continue
- path_cnt += 1
- img = cv2.imread(p)
- result = test_ocr_model(p)
- print(p)
- # print(result)
- bbox_list = eval(result.get('bbox'))
- text_list = eval(result.get('text'))
- bbox_text_dict = {}
- for i in range(len(text_list)):
- bbox_text_dict[str(bbox_list[i])] = text_list[i]
- # split_bbox(img, text_list, bbox_list)
- # 获取全局的按行排列bbox
- row_list = get_table_rows(bbox_list)
- # bbox预处理
- bbox_list, text_list, bbox_text_dict = bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict)
- # bbox处理后再按行排列bbox
- row_list = get_table_rows(bbox_list)
- # 获取表格区域,以及区域里的按行排列bbox
- table_location_list, area_row_list = get_table_location(row_list)
- # 表格分割
- table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict)
- table_location_list, area_row_list = split_table(table_location_list, area_row_list, bbox_text_dict)
- print('fix_table_location_list', table_location_list)
- # print('fix_area_row_list', area_row_list)
- # 获取表格区域里,按列排序bbox
- area_col_list = get_table_cols(bbox_list, table_location_list)
- # 合并一列中多行bbox
- area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list)
- # 排除非规整表格
- table_standard_list = delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict)
- # 上下表格合并
- area_row_list, area_col_list, table_location_list = merge_table(area_row_list, area_col_list, table_location_list, bbox_list)
- # 获取行线、列线
- area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
- # 根据行列线生成对应bbox行列
- area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list)
- # 添加列线
- add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict)
- for j in range(len(area_col_lines)):
- area_col_lines[j] += add_area_col_lines[j]
- # 判断列线合法
- area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
- area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1)
- area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0)
- #
- # for a in area_col_list:
- # for c in a:
- # print('area_col_list', [bbox_text_dict.get(str(x)) for x in c])
- #
- # # 合并一列中多行bbox
- # area_row_list = merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list)
- #
- # # 获取行线、列线
- # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
- #
- # add_area_col_lines = add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict)
- #
- # for j in range(len(area_col_lines)):
- # area_col_lines[j] += add_area_col_lines[j]
- #
- # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list)
- #
- # area_col_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=1)
- # area_row_list = get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0)
- #
- #
- add_area_row_lines = add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines)
- for j in range(len(area_row_lines)):
- area_row_lines[j] += add_area_row_lines[j]
- #
- area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
- # 合并相近线
- for j in range(len(area_col_lines)):
- area_col_lines[j] = merge_lines(area_col_lines[j], axis=1)
- area_row_lines[j] = merge_lines(area_row_lines[j], axis=0)
- # area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list)
- # area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list)
- # # 重新生成
- # table_location_list = []
- # temp_area_row_list = []
- # for temp_row_list in area_row_list:
- # location_list, temp_row_list = get_table_location(temp_row_list)
- # table_location_list += location_list
- # temp_area_row_list += temp_row_list
- # area_col_list = get_table_cols(bbox_list, table_location_list)
- # area_row_list = temp_area_row_list
- #
- # # 获取行线、列线
- # area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
- #
- # print('len(table_location_list)', len(table_location_list))
- # for bbox in bbox_list:
- # cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
- # (0, 0, 255), 1)
- #
- # for i in range(len(table_location_list)):
- # # location = table_location_list[i]
- # # cv2.rectangle(img, location[0], location[1], (0, 255, 0), 1)
- #
- # row_lines = area_row_lines[i]
- # col_lines = area_col_lines[i]
- # for r in row_lines:
- # cv2.line(img, r[0], r[1], (0, 255, 0), 1)
- # for c in col_lines:
- # cv2.line(img, c[0], c[1], (0, 255, 0), 1)
- #
- # cv2.imshow('img', img)
- # cv2.waitKey(0)
- # 计算标注表格和生成表格的相似度
- if len(table_location_list) == 1:
- # if not table_standard_list[0]:
- # continue
- row_lines = area_row_lines[0]
- col_lines = area_col_lines[0]
- row_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
- col_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
- for r in row_lines:
- cv2.line(row_img, r[0], r[1], (255, 255, 255), 1)
- for c in col_lines:
- cv2.line(col_img, c[0], c[1], (255, 255, 255), 1)
- row_list, is_standard = get_bbox_by_img(row_img, col_img)
- if not is_standard:
- continue
- row_list = merge_text_and_table(bbox_list, row_list)
- continue_flag = 0
- for row in row_list:
- for b in row:
- if len(b) > 1:
- continue_flag = 1
- break
- if continue_flag:
- continue
- max_len = 1
- continue_flag = 0
- for row in row_list:
- if abs(max_len - len(row)) > 2:
- continue_flag = 1
- break
- if len(row) > max_len:
- max_len = len(row)
- if continue_flag:
- continue
- img_label = cv2.imread(label_p[:-1])
- row_img1, col_img1 = get_lines_from_img(img_label)
- label_row_list, label_is_standard = get_bbox_by_img(row_img1, col_img1)
- if not label_is_standard:
- continue
- label_row_list = merge_text_and_table(bbox_list, label_row_list)
- add_flag = 0
- modify_flag = 0
- for i in range(len(row_list)):
- if i >= len(label_row_list):
- continue
- row = row_list[i]
- label_row = label_row_list[i]
- for r in label_row:
- if r not in row:
- add_flag += 1
- else:
- if label_row.index(r) != row.index(r):
- modify_flag += 1
- bbox_cnt = 0
- for row in row_list:
- for b in row:
- bbox_cnt += 1
- label_bbox_cnt = 0
- for row in label_row_list:
- for b in row:
- label_bbox_cnt += 1
- teds = 1 - (add_flag + modify_flag) / max(bbox_cnt, label_bbox_cnt)
- print('add_flag', add_flag, 'modify_flag', modify_flag, 'bbox_cnt', bbox_cnt, 'label_bbox_cnt', label_bbox_cnt)
- print('TEDS:', teds, p)
- all_teds += teds
- all_standard_cnt += 1
- # if teds <= 0.8:
- # print('row_list', [y for y in [x for x in row_list]])
- # print('label_row_list', [y for y in [x for x in label_row_list]])
- # cv2.imshow('model_table', row_img+col_img)
- # cv2.imshow('label_table', row_img1+col_img1)
- # cv2.waitKey(0)
- # for i in range(len(row_list)):
- try:
- avg_teds = all_teds / all_standard_cnt
- except:
- avg_teds = 0
- print('standard table cnt', all_standard_cnt)
- print('Avg TEDS', avg_teds)
- return
- def get_table_new():  # Debug driver: run OCR and borderless-table line detection on a hard-coded image list and display the result.
- with open('standard_table.txt', 'r') as f:
- label_path_list = f.readlines()  # NOTE(review): only referenced by the commented-out label check below
- # Images with table-splitting problems: 1019.jpg, 1020.jpg, 1023.jpg, 1027.jpg, 1029.jpg, 1030.jpg, 1031.jpg, 1035.jpg, 1040.jpg, 1042.jpg, 1046.jpg, 1047.jpg, 1061.jpg, 1064.jpg, 1067.jpg, 1072.jpg
- # Images with column-splitting problems: 1059.jpg,
- paths = glob('../data/borderless_tables/*.jpg')  # result is replaced by the hard-coded list assigned below
- # paths = glob(r'C:\Users\Administrator\Desktop\test_pdf_table\1.png')
- paths = ['1019.jpg', '1020.jpg', '1023.jpg', '1027.jpg', '1029.jpg', '1030.jpg', '1031.jpg', '1035.jpg', '1040.jpg', '1042.jpg', '1046.jpg', '1047.jpg', '1061.jpg', '1064.jpg', '1067.jpg', '1072.jpg']
- paths = ['../data/borderless_tables/' + x for x in paths]
- path_cnt = 0
- for p in paths:
- if 'label' in p:
- continue
- # label_p = p[:-4] + '_label.jpg\n'
- # if label_p not in label_path_list:
- # continue
- # if path_cnt <= 22:
- # path_cnt += 1
- # continue
- path_cnt += 1
- img = cv2.imread(p)
- result = test_ocr_model(p)
- print(p)
- bbox_list = eval(result.get('bbox'))  # NOTE(review): eval on the OCR response; consider ast.literal_eval if the payload is a plain literal
- text_list = eval(result.get('text'))  # NOTE(review): same eval concern as the line above
- # Process the bboxes: shrink the boxes
- bbox_list = shrink_bbox(img, bbox_list)
- # Build the bbox -> text lookup dict (keyed by str(bbox))
- bbox_text_dict = {}
- for i in range(len(text_list)):
- bbox_text_dict[str(bbox_list[i])] = text_list[i]  # assumes shrink_bbox keeps bbox order aligned with text_list -- TODO confirm
- # Get all bboxes on the page grouped by row
- row_list = get_table_rows(bbox_list, bbox_text_dict)
- # Get the table regions and the row-grouped bboxes inside each region
- table_location_list, area_row_list = get_table_location(row_list, bbox_text_dict)
- area_row_list = merge_row_bbox_list(area_row_list)
- # for a in area_row_list:
- # i = 0
- # for r in a:
- # print('row', i)
- # i += 1
- # for b in r:
- # print(bbox_text_dict.get(str(b)))
- # Get the bboxes grouped by column
- area_col_list = get_table_cols(bbox_list, table_location_list)
- # Get the row lines and column lines
- area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
- # Validate the column lines
- area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
- # # Validate the row lines
- area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
- # From the lines, get the bboxes arranged by row and column
- area_table_bbox_list, area_table_cell_list = get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list)
- for a in area_table_bbox_list:
- for r in a:
- for c in r:
- # cv2.rectangle(img, c[0], c[1], (255, 0, 0), 1)
- for b in c:
- cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1)
- cv2.imshow('table_cell', img)
- # Split tables (all splitting strategies are currently commented out)
- # table_location_list, _ = split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict)
- # table_location_list, _ = split_table(table_location_list, area_row_list, bbox_text_dict)
- # table_location_list = split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict)
- # table_location_list = split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict)
- # Regenerate the row-grouped bboxes
- area_row_list = get_table_rows2(area_row_list, table_location_list)
- # for a in area_row_list:
- # for r in a:
- # for b in r:
- # cv2.rectangle(img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (255, 0, 0), 1)
- # cv2.imshow('area_row_list', img)
- # Get the bboxes grouped by column
- area_col_list = get_table_cols(bbox_list, table_location_list)
- # Get the row lines and column lines
- area_row_lines, area_col_lines = get_table_borders(area_row_list, area_col_list, table_location_list)
- # Validate the column lines
- area_col_lines = judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict)
- # Validate the row lines
- area_row_lines = judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict)
- # Display
- show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list)
- return
def show_result(img, bbox_list, area_row_lines, area_col_lines, table_location_list):
    """Draw the OCR boxes and the detected table lines on ``img`` and show it.

    Every bbox is outlined using its first and third points as opposite
    corners; the row/column lines of each table are drawn on top.  The image
    is modified in place and displayed in a resizable window named 'img';
    the call blocks until a key is pressed.

    :param img: image array to draw on (modified in place)
    :param bbox_list: bboxes given as point lists; points 0 and 2 are used
    :param area_row_lines: per-table list of row lines ``[start, end]``
    :param area_col_lines: per-table list of column lines ``[start, end]``
    :param table_location_list: one entry per table; only its length is used
    :return: None
    """
    box_color = (0, 0, 255)
    line_color = (0, 255, 0)

    for box in bbox_list:
        corner_a = (int(box[0][0]), int(box[0][1]))
        corner_b = (int(box[2][0]), int(box[2][1]))
        cv2.rectangle(img, corner_a, corner_b, box_color, 1)

    for table_idx, _ in enumerate(table_location_list):
        # Fetch both line sets before drawing so a missing entry fails early.
        horizontal = area_row_lines[table_idx]
        vertical = area_col_lines[table_idx]
        for start, end in ((line[0], line[1]) for line in horizontal):
            cv2.line(img, start, end, line_color, 1)
        for start, end in ((line[0], line[1]) for line in vertical):
            cv2.line(img, start, end, line_color, 1)

    cv2.namedWindow('img', cv2.WINDOW_NORMAL)
    cv2.imshow('img', img)
    cv2.waitKey(0)
    return
def get_table_borders(area_row_list, area_col_list, table_location_list):
    """Build the horizontal and vertical border lines of every table.

    For each table a top line is placed at the table's top edge, followed by
    one line under every row (at the largest bottom-y of the row's bboxes).
    A left line is placed at the table's left edge, followed by one line to
    the right of every column (at the largest right-x of the column's bboxes).

    ``table_location_list`` is updated in place: a table's bottom-right
    corner is pushed outwards when the last row/column line lies beyond it.

    :param area_row_list: per table, a list of rows (each a list of bboxes)
    :param area_col_list: per table, a list of columns (each a list of bboxes)
    :param table_location_list: per table, ``[[left, top], [right, bottom]]``
    :return: ``(area_row_lines, area_col_lines)``; every line is
        ``[[x1, y1], [x2, y2]]``
    """
    area_row_lines = []
    area_col_lines = []
    # Loop over every table.
    for i, row_list in enumerate(area_row_list):
        col_list = area_col_list[i]
        location = table_location_list[i]
        left, top = location[0][0], location[0][1]

        # Row lines: the top edge first, then one line below each row.
        row_lines = [[[left, top], [location[1][0], top]]]
        for row in row_list:
            # 0 is the floor, matching the accumulator's starting value.
            max_h = int(max([0] + [bbox[2][1] for bbox in row]))
            row_lines.append([[left, max_h], [location[1][0], max_h]])

        # The last line may not sit above the table's bottom edge ...
        last_line = row_lines[-1]
        last_line[0][1] = max(location[1][1], last_line[0][1])
        last_line[1][1] = max(location[1][1], last_line[1][1])
        # ... and the table's bottom edge is extended down to it.
        location[1][1] = max(location[1][1], last_line[1][1])

        # Column lines: the left edge first, then one line right of each column.
        bottom = location[1][1]
        col_lines = [[[left, top], [left, bottom]]]
        for col in col_list:
            max_w = int(max([0] + [bbox[2][0] for bbox in col]))
            col_lines.append([[max_w, top], [max_w, bottom]])

        # Extend the table's right edge to the last column line.
        location[1][0] = max(location[1][0], col_lines[-1][1][0])

        # Stretch every row line over the (possibly widened) table.
        for line in row_lines:
            line[0][0] = location[0][0]
            line[1][0] = location[1][0]

        area_row_lines.append(row_lines)
        area_col_lines.append(col_lines)
    return area_row_lines, area_col_lines
- def get_table_location(row_list, bbox_text_dict):  # Scan row-grouped bboxes and collect table regions; returns (table_location_list, area_row_list).
- # for r in row_list:
- # print('row', r)
- up_h = 10000  # running top edge of the current candidate table
- bottom_h = 0  # running bottom edge
- left_w = 10000  # running left edge
- right_w = 0  # running right edge
- table_rows = 0  # number of multi-bbox rows accepted into the current candidate
- tolerance_list = []  # single-bbox rows tolerated inside the current candidate (at most 3, see below)
- area_row_list = []
- temp_row_list = []
- table_location_list = []
- catalog_text_cnt = 0
- for row in row_list:
- if len(row) >= 2:
- if not temp_row_list:
- # In the first row, the gap between bboxes must exceed a threshold
- max_distance = 0
- row.sort(key=lambda x: x[0][0])
- row_text_list = []
- catalog_text_cnt = 0
- bbox_height_list = [abs(row[-1][0][1] - row[-1][2][1])]
- for i in range(1, len(row)):
- dis = row[i][0][0] - row[i-1][2][0]
- if dis >= max_distance:
- max_distance = dis
- text = bbox_text_dict.get(str(row[i-1]))
- row_text_list.append(bbox_text_dict.get(str(row[i-1])))
- match = re.findall('\\.+\d+', text)  # pattern: one or more dots followed by digits; NOTE(review): '\d' sits in a non-raw string literal
- if match and len(match[0]) == len(text):
- catalog_text_cnt += 1
- bbox_height_list.append(abs(row[i][0][1] - row[i][2][1]))
- # Exclude
- # if len(row) == 2:
- # if max_distance <= abs(row[0][2][0] - row[0][0][0]):
- # continue
- # else:
- if max_distance <= 5:
- continue
- # Exclude spaced-out labels such as '地 址' (address), '名 称' (name)
- # if len(row) == 2 and len(bbox_text_dict.get(str(row[0]))) == 1:
- # continue
- row_text_list = []
- bbox_height_list = []
- for i in range(len(row)):
- text = bbox_text_dict.get(str(row[i-1]))  # NOTE(review): i starts at 0 here, so row[i-1] begins with row[-1]; .get may return None, which would make re.findall raise
- row_text_list.append(bbox_text_dict.get(str(row[i-1])))
- match = re.findall('\\.+\d+', text)
- if match and len(match[0]) == len(text):
- catalog_text_cnt += 1
- bbox_height_list.append(abs(row[i][0][1] - row[i][2][1]))
- # Exclude rows whose bbox heights differ too much
- bbox_height_list.sort(key=lambda x: x)
- if bbox_height_list[-1] - bbox_height_list[0] > bbox_height_list[0]:
- continue
- # Exclude table-of-contents rows
- if catalog_text_cnt >= 3:
- continue
- # Exclude watermark images (fewer than 2/3 of the row's texts are distinct)
- if len(list(set(row_text_list))) < 2/3 * len(row):
- continue
- # Exclude rows with underlines (no check is implemented after this comment)
- table_rows += 1
- temp_row_list.append(row)
- for bbox in row:
- if up_h > bbox[0][1]:
- up_h = bbox[0][1]
- if bottom_h < bbox[2][1]:
- bottom_h = bbox[2][1]
- if left_w > bbox[0][0]:
- left_w = bbox[0][0]
- if right_w < bbox[2][0]:
- right_w = bbox[2][0]
- else:
- if len(tolerance_list) < 3 and table_rows > 0:
- tolerance_list.append(row)
- temp_row_list.append(row)
- continue
- if table_rows > 2 and up_h < bottom_h:
- table_location_list.append([[int(left_w), int(up_h)],
- [int(right_w), int(bottom_h)]])
- if tolerance_list[-1] == temp_row_list[-1]:  # drop a trailing tolerated row from the emitted table
- area_row_list.append(temp_row_list[:-1])
- else:
- area_row_list.append(temp_row_list)
- up_h = 10000
- bottom_h = 0
- left_w = 10000
- right_w = 0
- table_rows = 0
- tolerance_list = []
- temp_row_list = []
- if temp_row_list:
- if table_rows > 2 and up_h < bottom_h:
- table_location_list.append([[int(left_w), int(up_h)],
- [int(right_w), int(bottom_h)]])
- area_row_list.append(temp_row_list)
- return table_location_list, area_row_list
def get_table_rows(bbox_list, bbox_text_dict):
    """Group bboxes into visual rows.

    ``bbox_list`` is sorted in place (top, bottom, left, right).  Each bbox
    not yet assigned starts a new row and collects every other unassigned
    bbox whose vertical centre lies within 10 px of its own and whose
    vertical overlap with it (``line_overlap``) is at least half the smaller
    of the two heights.  Bboxes equal to one already assigned are skipped.

    :param bbox_list: bboxes as point lists; points 0 and 2 are opposite corners
    :param bbox_text_dict: unused, kept for interface compatibility
    :return: list of rows, each a list of bboxes
    """
    def _key(bbox):
        # Hashable stand-in for a bbox so "already used" is an O(1) set
        # lookup instead of a linear scan over a list of lists.
        return tuple(tuple(point) for point in bbox)

    bbox_list.sort(key=lambda x: (x[0][1], x[2][1], x[0][0], x[2][0]))
    keys = [_key(bbox) for bbox in bbox_list]

    row_list = []
    used = set()
    for k1, b1 in zip(keys, bbox_list):
        if k1 in used:
            continue
        used.add(k1)
        row = [b1]
        center1 = (b1[0][1] + b1[2][1]) / 2
        height1 = b1[2][1] - b1[0][1]
        for k2, b2 in zip(keys, bbox_list):
            if k2 in used:
                continue
            center2 = (b2[0][1] + b2[2][1]) / 2
            if abs(center1 - center2) <= 10 \
                    and line_overlap(b1[0][1], b1[2][1], b2[0][1], b2[2][1]) \
                    >= 1/2*min(height1, b2[2][1]-b2[0][1]):
                row.append(b2)
                used.add(k2)
        row_list.append(row)
    return row_list
def get_table_rows2(area_row_list, table_location_list):
    """Re-distribute rows over (possibly changed) table regions.

    All rows of all areas are flattened, then every table location receives
    the rows whose first bbox lies vertically inside it.  A row can be
    assigned to more than one location if the locations overlap.  Empty rows
    carry no position and are skipped instead of raising ``IndexError``.

    :param area_row_list: per table, a list of rows (each a list of bboxes)
    :param table_location_list: per table, ``[[left, top], [right, bottom]]``
    :return: new per-table list of rows, aligned with ``table_location_list``
    """
    all_rows = [row for area in area_row_list for row in area]

    regrouped = []
    for location in table_location_list:
        top, bottom = location[0][1], location[1][1]
        regrouped.append([
            row for row in all_rows
            # Only the first bbox of the row decides membership.
            if row and top <= row[0][0][1] <= row[0][2][1] <= bottom
        ])
    return regrouped
def get_table_bbox_row_or_col(bbox_list, axis=0):
    """Group bboxes whose centres line up along one direction.

    With ``axis=0`` the bboxes are grouped by vertical centre (y), with
    ``axis=1`` by horizontal centre (x).  ``bbox_list`` is sorted in place
    first.  Each unassigned bbox starts a group and collects every other
    unassigned bbox whose centre is within 10 px of its own; bboxes equal to
    one already assigned are skipped.

    :param bbox_list: bboxes as point lists; points 0 and 2 are opposite corners
    :param axis: 0 to compare y centres, 1 to compare x centres
    :return: list of groups, each a list of bboxes
    """
    def _key(bbox):
        # Hashable stand-in for a bbox so "already used" is an O(1) lookup.
        return tuple(tuple(point) for point in bbox)

    other = 1 - axis
    bbox_list.sort(key=lambda x: (x[0][other], x[2][other], x[0][axis], x[2][axis]))
    keys = [_key(bbox) for bbox in bbox_list]
    centers = [(bbox[0][other] + bbox[2][other]) / 2 for bbox in bbox_list]

    group_list = []
    used = set()
    for i, seed in enumerate(bbox_list):
        if keys[i] in used:
            continue
        used.add(keys[i])
        group = [seed]
        for j, candidate in enumerate(bbox_list):
            if keys[j] in used:
                continue
            # Candidates are compared with the seed only, not with each other.
            if abs(centers[i] - centers[j]) <= 10:
                group.append(candidate)
                used.add(keys[j])
        group_list.append(group)
    return group_list
def get_table_cols(bbox_list, table_location_list):
    """Group the bboxes of every table into columns.

    ``bbox_list`` is sorted in place (left, right, top, bottom).  For each
    table the bboxes whose vertical centre lies inside the table are taken;
    each unassigned bbox starts a column and absorbs other unassigned bboxes
    that satisfy any of:

    1. horizontal centres within 10 px of the seed,
    2. left edges within 10 px of the seed,
    3. one horizontal range fully contains the other (against the column's
       running extent),
    4. ``line_iou`` with the column's running extent is at least 0.6.

    A bbox assigned in one table is not reused in a later table.

    The ``line_iou`` segment now ends at ``b2[2][0]`` (the right edge used
    everywhere else in this function) instead of ``b2[1][0]``.

    :param bbox_list: bboxes as point lists; points 0 and 2 are opposite corners
    :param table_location_list: per table, ``[[left, top], [right, bottom]]``
    :return: per table, a list of columns (each a list of bboxes)
    """
    def _key(bbox):
        # Hashable stand-in for a bbox so "already used" is an O(1) lookup.
        return tuple(tuple(point) for point in bbox)

    bbox_list.sort(key=lambda x: (x[0][0], x[2][0], x[0][1], x[2][1]))
    all_col_list = []
    used = set()
    for location in table_location_list:
        top, bottom = location[0][1], location[1][1]
        sub_bbox_list = [b for b in bbox_list
                         if top <= (b[0][1] + b[2][1]) / 2 <= bottom]
        sub_keys = [_key(b) for b in sub_bbox_list]

        col_list = []
        for k1, b1 in zip(sub_keys, sub_bbox_list):
            if k1 in used:
                continue
            used.add(k1)
            # Running horizontal extent of the column being built.
            col_width = [b1[0][0], b1[2][0]]
            column = [b1]
            for k2, b2 in zip(sub_keys, sub_bbox_list):
                if k2 in used:
                    continue
                same_column = (
                    abs((b1[0][0] + b1[2][0]) / 2 - (b2[0][0] + b2[2][0]) / 2) <= 10
                    or abs(b1[0][0] - b2[0][0]) <= 10
                    or col_width[0] <= b2[0][0] <= b2[2][0] <= col_width[1]
                    or b2[0][0] <= col_width[0] <= col_width[1] <= b2[2][0]
                    or line_iou([[col_width[0], 0], [col_width[1], 0]],
                                [[b2[0][0], 0], [b2[2][0], 0]], axis=0) >= 0.6
                )
                if same_column:
                    column.append(b2)
                    used.add(k2)
                    col_width[0] = min(col_width[0], b2[0][0])
                    col_width[1] = max(col_width[1], b2[2][0])
            col_list.append(column)
        all_col_list.append(col_list)
    return all_col_list
- def merge_col_bbox_by_cluster(img, area_row_list, area_col_list, bbox_text_dict, all_bbox_list, table_location_list):  # Merge vertically close bboxes (gap < 5) inside each column; mutates and returns all_bbox_list.
- temp_img = copy.deepcopy(img)  # drawing copy used only for the debug display below
- # Loop over every table
- for i in range(len(area_row_list)):
- row = area_row_list[i]
- col = area_col_list[i]
- # Loop over every column and compute the gaps between consecutive bboxes in the column
- new_col = []
- col_cnt = 0
- for bbox_list in col:
- # Compute the gaps
- distance_list = []
- bbox_list.sort(key=lambda x: (x[0][1], x[1][1]))
- text_list = [bbox_text_dict.get(str(x)) for x in bbox_list]
- for j in range(1, len(bbox_list)):
- dis = bbox_list[j][0][1] - bbox_list[j-1][2][1]
- if dis < 0:
- dis = 0.
- distance_list.append(dis)
- print("\n")
- print("distance_list", distance_list)
- # Cluster the gaps to obtain groups
- data_list = [[0, x] for x in distance_list]
- # Exclude gaps at or above the threshold
- data_mask_list = []
- temp_data_list = []
- for j in range(len(data_list)):
- if data_list[j][1] < 5.:
- data_mask_list.append(True)
- temp_data_list.append(data_list[j])
- else:
- data_mask_list.append(False)
- data_list = temp_data_list
- print("data_list", data_list)
- cluster_list = []
- if len(data_list) > 2:
- # Cluster
- pred_list = dbscan(data_list)
- print('pred_list', pred_list)
- temp_pred_list = []
- for j in data_mask_list:
- if j:
- temp_pred_list.append(pred_list.pop(0))  # re-align labels with distance_list; masked-out gaps get -1
- else:
- temp_pred_list.append(-1)
- pred_list = temp_pred_list
- print('pred_list', pred_list)
- cluster_num = len(list(set(pred_list)))
- for k in range(cluster_num):
- temp_list = []
- for j in range(len(pred_list)):
- if pred_list[j] == k:
- if temp_list:
- if j - temp_list[-1] == 1:
- temp_list.append(j)
- else:
- temp_list.append(j)
- else:
- if temp_list:
- cluster_list.append(temp_list)
- temp_list = []
- if temp_list:
- cluster_list.append(temp_list)
- elif len(data_list) > 0:
- temp_list = []
- for j in range(len(distance_list)):
- if distance_list[j] < 5.0:
- temp_list.append(j)
- else:
- if temp_list:
- cluster_list.append(temp_list)
- temp_list = []
- if temp_list:
- cluster_list.append(temp_list)
- # cluster_list.append([x for x in range(len(distance_list))])
- print('text_list', text_list)
- print('cluster_list', cluster_list)
- # Merge the bboxes
- new_bbox_list = copy.deepcopy(bbox_list)
- for cluster in cluster_list:
- merge_flag = 1
- for dis in [distance_list[x] for x in cluster]:
- if dis >= 5.0:
- merge_flag = 0
- break
- if merge_flag:
- b_list = bbox_list[cluster[0]:cluster[-1]+2]  # presumably gap index k lies between bbox k and k+1, hence the +2 -- TODO confirm
- t_list = text_list[cluster[0]:cluster[-1]+2]
- min_w = 10000
- max_w = 0
- min_h = 10000
- max_h = 0
- b_list = [eval(x) for x in list(set([str(x) for x in b_list]))]  # de-duplicates via a str()/eval() round trip; NOTE(review): eval is only safe while bboxes are plain numeric lists
- for bbox in b_list:
- if bbox in new_bbox_list:
- new_bbox_list.remove(bbox)
- if bbox in all_bbox_list:
- all_bbox_list.remove(bbox)
- if bbox[0][0] < min_w:
- min_w = bbox[0][0]
- if bbox[0][1] < min_h:
- min_h = bbox[0][1]
- if bbox[2][0] > max_w:
- max_w = bbox[2][0]
- if bbox[2][1] > max_h:
- max_h = bbox[2][1]
- new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]]
- new_bbox_list.append(new_bbox)
- all_bbox_list.append(new_bbox)
- # Use the first column's merge result to guide the merging of the other columns
- if col_cnt == 0:
- first_col_rows = get_first_col_rows(new_bbox_list, table_location_list[i])
- for r in first_col_rows:
- cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 0, 255), 1)
- cv2.imshow('temp_img', temp_img)
- # cv2.waitKey(0)
- col_cnt += 1
- # new_col.append(new_bbox_list)
- return all_bbox_list
- def merge_col_bbox_by_block(img, area_row_list, area_col_list, bbox_text_dict, bbox_list, table_location_list):  # Re-group each table's bboxes into rows using boundaries derived from its first column; updates and returns area_row_list.
- temp_img = copy.deepcopy(img)  # NOTE(review): only referenced by the commented-out debug drawing at the end
- # Loop over every table
- for i in range(len(area_row_list)):
- row_list = area_row_list[i]
- col_list = area_col_list[i]
- table_location = table_location_list[i]
- sub_bbox_list = []
- for bbox in bbox_list:
- if table_location[0][1] <= bbox[0][1] <= table_location[1][1] \
- or table_location[0][1] <= bbox[1][1] <= table_location[1][1]:
- sub_bbox_list.append(bbox)
- # Cluster and merge the first column, then split rows by the blank space between its bboxes
- first_col = col_list[0]
- cluster_list, distance_list = distance_cluster(first_col, axis=1)
- merge_first_col = merge_cluster(first_col, cluster_list, distance_list)
- merge_first_col.sort(key=lambda x: (x[0][1], x[0][0]))
- row_lines = get_first_col_rows(merge_first_col, table_location)  # list of y coordinates, not line segments
- # Cluster and merge the other columns
- # merge_bbox_list = [] + first_col
- # for col in col_list[1:]:
- # cluster_list = distance_cluster(col, axis=1)
- # merge_col = merge_cluster(col, cluster_list)
- # merge_bbox_list += merge_col
- # Loop over every column and merge according to the row split
- new_row_list = []
- row_lines.sort(key=lambda x: x)
- row_cnt = 0
- need_add_bbox = []  # bboxes deferred to the next row band
- # for c in first_col:
- # print('first col ', bbox_text_dict.get(str(c)))
- for j in range(1, len(row_lines)):
- print('\n')
- top_line = row_lines[j-1]
- bottom_line = row_lines[j]
- new_row = []
- if need_add_bbox:
- # print('add')
- new_row += need_add_bbox
- print('add', bbox_text_dict.get(str(new_row[0])))
- need_add_bbox = []
- # Merge conditions:
- # 1. the bbox is fully contained in the band
- # 2. the bbox straddles two rows: compare its vertical distance to the first-column bboxes of those two rows
- for bbox in sub_bbox_list:
- if top_line <= bbox[0][1] <= bbox[2][1] <= bottom_line:
- new_row.append(bbox)
- # print('bbox, line', bbox_text_dict.get(str(bbox)), top_line, bottom_line)
- else:
- if bbox in first_col:
- continue
- # If the first column has only one row, bboxes on the boundary are not counted
- if len(first_col) == 1:
- need_add_bbox.append(bbox)
- continue
- # Find the nearest first-column bboxes above and below this bbox
- first_col_center_h1 = 0
- first_col_center_h2 = 10000
- first_col_bbox1 = None
- first_col_bbox2 = None
- bbox_center_h = (bbox[0][1] + bbox[2][1]) / 2
- for b in first_col:
- b_center_h = (b[0][1] + b[2][1]) / 2
- # if bbox[0][1] <= b_center_h <= bbox[2][1]:
- # first_col_center_h2 = b_center_h
- # break
- if bbox_center_h >= b_center_h and bbox_center_h - b_center_h <= bbox_center_h - first_col_center_h1:
- first_col_center_h1 = b_center_h
- first_col_bbox1 = b
- if b_center_h >= bbox_center_h and b_center_h - bbox_center_h <= first_col_center_h2 - bbox_center_h:
- first_col_center_h2 = b_center_h
- first_col_bbox2 = b
- # If the nearest first-column bbox does not belong to this row
- if new_row and first_col_bbox1 != new_row[0] and top_line < bbox[0][1] < bottom_line:
- need_add_bbox.append(bbox)
- continue
- # if top_line <= bbox[2][1] <= bottom_line \
- # and abs(first_col_center_h1 - bbox_center_h) >= abs(first_col_center_h2 - bbox_center_h):
- # new_row.append(bbox)
- # if first_col_bbox1 and first_col_bbox2:
- # print('bbox1, bbox2', bbox_text_dict[str(first_col_bbox1)], bbox_text_dict[str(first_col_bbox2)],
- # bbox_text_dict[str(bbox)])
- if top_line < bbox[0][1] < bottom_line \
- and abs(first_col_center_h1 - bbox_center_h) <= abs(first_col_center_h2 - bbox_center_h):
- new_row.append(bbox)
- elif top_line < bbox[0][1] < bottom_line:
- need_add_bbox.append(bbox)
- for r in need_add_bbox:
- print("next_row bbox", bbox_text_dict.get(str(r)))
- print('row', row_cnt, len(new_row))
- for b in new_row:
- print(bbox_text_dict.get(str(b)))
- row_cnt += 1
- new_row_list.append(new_row)
- area_row_list[i] = new_row_list
- # show
- r_cnt = 0  # NOTE(review): only used by the commented-out drawing loop below
- # for r in row_lines:
- # if r_cnt == 0 or r_cnt == len(row_lines) - 1:
- # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (255, 0, 0), 1)
- # else:
- # cv2.line(temp_img, (0, int(r)), (temp_img.shape[1], int(r)), (0, 255, 0), 1)
- # r_cnt += 1
- # for b in merge_bbox_list:
- # cv2.rectangle(temp_img, [int(b[0][0]), int(b[0][1])], [int(b[2][0]), int(b[2][1])], (0, 0, 255), 1)
- # cv2.imshow('temp_img', temp_img)
- return area_row_list
def distance_cluster(bbox_list, max_distance=5., axis=1):
    """Group consecutive small gaps between neighbouring bboxes.

    ``bbox_list`` is sorted in place along ``axis`` and the gap between each
    bbox and its predecessor is measured along that axis (negative gaps,
    i.e. overlaps, count as 0).  Gaps below ``max_distance`` are candidates:

    * more than two candidates: they are labelled by ``dbscan`` and every
      run of consecutive gaps sharing a label becomes one cluster;
    * one or two candidates: every run of consecutive candidate gaps becomes
      one cluster;
    * none: no clusters.

    Earlier the list was always sorted by y, so ``axis=0`` measured
    horizontal gaps in vertical order; the sort now follows ``axis``
    (identical to the old behaviour for the default ``axis=1``).

    :param bbox_list: bboxes as point lists (sorted in place)
    :param max_distance: gaps at or above this value are never clustered
    :param axis: 1 for vertical gaps (y), 0 for horizontal gaps (x)
    :return: ``(cluster_list, distance_list)``; a cluster is a list of
        consecutive indices into ``distance_list``
    """
    def _consecutive_runs(flags):
        # Indices of each maximal run of consecutive truthy flags.
        runs = []
        current = []
        for idx, flag in enumerate(flags):
            if flag:
                current.append(idx)
            elif current:
                runs.append(current)
                current = []
        if current:
            runs.append(current)
        return runs

    # Measure the gaps between neighbours along the requested axis.
    bbox_list.sort(key=lambda x: (x[0][axis], x[1][axis]))
    distance_list = []
    for prev_bbox, bbox in zip(bbox_list, bbox_list[1:]):
        dis = bbox[0][axis] - prev_bbox[2][axis]
        if dis < 0:
            dis = 0.
        distance_list.append(dis)
    print("\n")
    print("distance_list", distance_list)

    # Keep only gaps below the threshold as clustering input; the mask
    # remembers where they sat in distance_list.
    data_mask_list = [dis < max_distance for dis in distance_list]
    data_list = [[0, dis] for dis in distance_list if dis < max_distance]
    print("data_list", data_list)

    cluster_list = []
    if len(data_list) > 2:
        # Cluster, then re-align the labels with distance_list; gaps that
        # were masked out get label -1.
        pred_list = dbscan(data_list)
        print('pred_list', pred_list)
        pred_iter = iter(pred_list)
        pred_list = [next(pred_iter) if keep else -1 for keep in data_mask_list]
        print('pred_list', pred_list)
        cluster_num = len(set(pred_list))
        for k in range(cluster_num):
            cluster_list += _consecutive_runs([label == k for label in pred_list])
    elif data_list:
        cluster_list = _consecutive_runs(data_mask_list)
    print('cluster_list', cluster_list)
    return cluster_list, distance_list
- def merge_cluster(bbox_list, cluster_list, distance_list):  # Merge the bboxes covered by each cluster into one bounding box; returns a new list built from a deep copy of bbox_list.
- new_bbox_list = copy.deepcopy(bbox_list)
- # Special case: the gaps between rows are small and regular, so everything lands in one cluster
- if len(cluster_list) == 1 and len(cluster_list[0]) >= 4:
- cluster_list = [[x] for x in cluster_list[0]]
- # The row gaps are small and uniform
- if distance_list:
- if max(distance_list) - min(distance_list) <= 5.5:
- cluster_list = [[i] for i in range(len(distance_list))]
- # After dropping the single largest gap, the remaining gaps are small and uniform
- if distance_list and max(distance_list) - min(distance_list) >= 10:
- index = distance_list.index(max(distance_list))
- if index <= 2 and len(distance_list[index+1:]) >= 3 and max(distance_list[index+1:]) - min(distance_list[index+1:]) <= 5.5:
- if index == 0:
- cluster_list = [[i] for i in range(len(distance_list[index+1:]))]  # NOTE(review): indices restart at 0 rather than at index+1 -- confirm this is intended
- else:
- if max(distance_list[:index]) - min(distance_list[:index]) <= 5.5:
- cluster_list = [[i] for i in range(len(distance_list[:index]))]
- cluster_list += [[i] for i in range(len(distance_list[index+1:]))]  # NOTE(review): same 0-based indices as above -- confirm
- for cluster in cluster_list:
- b_list = bbox_list[cluster[0]:cluster[-1]+2]  # presumably gap index k lies between bbox k and k+1, hence the +2 -- TODO confirm
- min_w = 10000
- max_w = 0
- min_h = 10000
- max_h = 0
- b_list = [eval(x) for x in list(set([str(x) for x in b_list]))]  # de-duplicates via a str()/eval() round trip; NOTE(review): eval is only safe while bboxes are plain numeric lists
- for bbox in b_list:
- if bbox in new_bbox_list:
- new_bbox_list.remove(bbox)
- if bbox[0][0] < min_w:
- min_w = bbox[0][0]
- if bbox[0][1] < min_h:
- min_h = bbox[0][1]
- if bbox[2][0] > max_w:
- max_w = bbox[2][0]
- if bbox[2][1] > max_h:
- max_h = bbox[2][1]
- new_bbox = [[min_w, min_h], [max_w, min_h], [max_w, max_h], [min_w, max_h]]
- new_bbox_list.append(new_bbox)
- return new_bbox_list
def get_first_col_rows(first_col, table_location):
    """Derive row boundaries (y coordinates) from the first column's bboxes.

    The first boundary is the table's top.  For each bbox, a boundary is
    placed below it at a distance equal to the smaller of (a) the blank
    space between the previous boundary and the bbox's top and (b) the gap
    to the next bbox (treated as 10000 for the last bbox).  Finally the
    table's bottom closes the list: it is appended when there is exactly one
    bbox, otherwise the last boundary is raised to at least the bottom.

    Note that for an empty ``first_col`` the single entry (the top) is
    replaced by ``max(top, bottom)``, so a one-element list is returned.

    :param first_col: bboxes of the first column, expected in top-to-bottom order
    :param table_location: ``[[left, top], [right, bottom]]``
    :return: list of boundary y coordinates
    """
    location_top = table_location[0][1]
    location_bottom = table_location[1][1]

    row_block_list = [location_top]
    for i, bbox in enumerate(first_col):
        if i + 1 < len(first_col):
            bbox_distance = abs(bbox[2][1] - first_col[i + 1][0][1])
        else:
            bbox_distance = 10000
        # For the first bbox the previous boundary is the table top, so one
        # formula covers both the first and the following bboxes.
        top_block = abs(bbox[0][1] - row_block_list[-1])
        row_block_list.append(bbox[2][1] + min(top_block, bbox_distance))

    if len(row_block_list) == 2:
        row_block_list.append(location_bottom)
    else:
        row_block_list[-1] = max(row_block_list[-1], location_bottom)
    return row_block_list
- def judge_standard_table(row_list):  # Collect table regions from runs of rows with the same bbox count (>= 2); returns (table_location_list, area_row_list).
- up_h = 10000  # running top edge of the current candidate table
- bottom_h = 0  # running bottom edge
- left_w = 10000  # running left edge
- right_w = 0  # running right edge
- table_rows = 0  # multi-bbox rows accepted into the current candidate
- now_row_len = 0  # bbox count of the first multi-bbox row seen; NOTE(review): never reset afterwards in this function
- init_flag = 0  # when set, the running state is reset at the start of the next iteration
- tolerance_list = []  # at most one single-bbox row is tolerated inside a candidate (see below)
- area_row_list = []
- temp_row_list = []
- table_location_list = []
- for row in row_list:
- if init_flag:
- up_h = 10000
- bottom_h = 0
- left_w = 10000
- right_w = 0
- table_rows = 0
- tolerance_list = []
- temp_row_list = []
- init_flag = 0
- if len(row) >= 2:
- if now_row_len == 0:
- now_row_len = len(row)
- else:
- if len(row) != now_row_len:
- init_flag = 1  # a row with a different bbox count discards the current candidate without emitting it
- continue
- table_rows += 1
- temp_row_list.append(row)
- for bbox in row:
- if up_h > bbox[0][1]:
- up_h = bbox[0][1]
- if bottom_h < bbox[2][1]:
- bottom_h = bbox[2][1]
- if left_w > bbox[0][0]:
- left_w = bbox[0][0]
- if right_w < bbox[2][0]:
- right_w = bbox[2][0]
- else:
- if len(tolerance_list) < 1 and table_rows > 0:
- tolerance_list.append(row)
- temp_row_list.append(row)
- continue
- if table_rows > 1 and up_h < bottom_h:
- table_location_list.append([[int(left_w), int(up_h)],
- [int(right_w), int(bottom_h)]])
- if tolerance_list[-1] == temp_row_list[-1]:  # drop a trailing tolerated row from the emitted table
- area_row_list.append(temp_row_list[:-1])
- else:
- area_row_list.append(temp_row_list)
- init_flag = 1
- return table_location_list, area_row_list  # NOTE(review): no flush after the loop, so a candidate still open when row_list ends is not emitted
- def split_bbox(img, bbox, bbox_text_dict):  # Split one bbox at wide blank vertical gaps in its image crop and divide its text among the pieces; returns (split_bbox_list, bbox_text_dict).
- text = bbox_text_dict.get(str(bbox))
- sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]  # crop of the bbox; indexing requires a 3-dimensional image array
- split_line_list = []
- last_i_status = 1  # 1 = light column, 0 = dark column; scanning starts as "light"
- # Scan the crop from left to right
- for i in range(1, sub_img.shape[1]):
- # If the number of dark values (< 200) in this pixel column exceeds a threshold
- if np.where(sub_img[:, i, :] < 200)[0].size > sub_img.shape[0]/5:
- i_status = 0
- else:
- i_status = 1
- # XOR: previous column dark and this one light, or previous column light and this one dark
- if last_i_status ^ i_status:
- split_line_list.append(int(i))
- last_i_status = i_status
- # De-duplicate split lines that are too close together
- min_len = 5
- last_l = split_line_list[0]  # NOTE(review): raises IndexError when no light/dark transition was found
- temp_list = [split_line_list[0]]
- for l in split_line_list[1:]:
- if l - last_l > min_len:
- temp_list.append(l)
- last_l = l
- split_line_list = temp_list
- # If there are (almost) no dark pixels between two split lines, that span is a place to split
- split_pair_list = []
- last_line = split_line_list[0]
- for line in split_line_list[1:]:
- print('last_line, line', last_line, line, np.where(sub_img[:, last_line:line, :] < 100)[0].size)
- if line - last_line >= 10 and np.where(sub_img[:, last_line:line, :] < 100)[0].size < 10:
- split_pair_list.append([last_line, line])
- last_line = line
- print('split_pair_list', split_pair_list)
- for l in split_line_list:
- l = int(l + bbox[0][0])
- cv2.line(img, (l, int(bbox[0][1])), (l, int(bbox[2][1])), (0, 255, 0), 2)  # debug drawing directly on the caller's image
- cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
- (0, 0, 255), 1)
- cv2.imshow('img', img)
- cv2.waitKey(0)  # blocks until a key is pressed
- # Split to obtain the new bboxes
- split_bbox_list = []
- if split_pair_list:
- start_line = 0
- for line1, line2 in split_pair_list:
- w1 = start_line + bbox[0][0]
- w2 = line1 + bbox[0][0]
- start_line = line2
- split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])  # points 1 and 3 are left as empty lists
- w1 = start_line + bbox[0][0]
- w2 = bbox[2][0]
- split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])
- print('split_bbox_list', split_bbox_list)
- # Compute the width of a single character
- all_len = 0
- bbox_len_list = []
- for bbox in split_bbox_list:
- _len = abs(bbox[2][0] - bbox[0][0])
- all_len += _len
- bbox_len_list.append(_len)
- single_char_len = all_len / len(text)  # NOTE(review): fails when text is None or empty
- # Derive each piece's text from the ratio of its bbox width to the single-character width
- split_text_list = []
- text_start = 0
- for _len in bbox_len_list:
- text_num = int(_len / single_char_len + 0.5)
- text_end = text_start+text_num
- if text_end >= len(text):
- text_end = len(text)
- split_text_list.append(text[text_start:text_end])
- text_start = text_end
- print('split_text_list', split_text_list)
- # Update bbox_text_dict
- for i, bbox in enumerate(split_bbox_list):
- bbox_text_dict[str(bbox)] = split_text_list[i]
- return split_bbox_list, bbox_text_dict
def split_table(table_location_list, area_row_list, bbox_text_dict):
    """Split tables at title-like rows.

    A row consisting of a single bbox is a split point when the bbox is
    centred on the table's horizontal middle (the middle falls in the
    bbox's central third), or when it starts within 5 px of the table's left
    edge and spans at least a fifth of the table width.  The rows between
    split points form new tables; segments with fewer than two rows are
    dropped, and the split rows themselves are excluded.  Tables without a
    split point are passed through unchanged.

    :param table_location_list: per table, ``[[left, top], [right, bottom]]``
    :param area_row_list: per table, a list of rows (each a list of bboxes)
    :param bbox_text_dict: unused, kept for interface compatibility
    :return: ``(location_list, area_row_list)`` for the resulting tables
    """
    temp_location_list = []
    temp_area_row_list = []
    for i, location in enumerate(table_location_list):
        sub_row_list = area_row_list[i]
        table_center = (location[0][0] + location[1][0]) / 2
        table_width = location[1][0] - location[0][0]

        # Find title-like rows: a lone bbox in the middle or at the start.
        need_split_index = []
        for j, row in enumerate(sub_row_list):
            if len(row) != 1:
                continue
            left, right = row[0][0][0], row[0][2][0]
            threshold = (right - left) * 1 / 3
            if left + threshold <= table_center <= right - threshold:
                need_split_index.append(j)
            elif abs(location[0][0] - left) <= 5 \
                    and right - left >= 1/5 * table_width:
                need_split_index.append(j)
        print('need_split_index', need_split_index)

        if not need_split_index:
            temp_location_list.append(location)
            temp_area_row_list.append(sub_row_list)
            continue

        last_index = 0
        # Sentinel so the rows after the last split point form a segment too.
        need_split_index.append(len(sub_row_list))
        for index in need_split_index:
            segment = sub_row_list[last_index:index]
            last_index = index + 1
            if len(segment) < 2:
                continue
            temp_area_row_list.append(segment)
            boxes = [bbox for row in segment for bbox in row]
            # 10000 / 0 are the accumulator start values of the extents.
            min_w = min([10000] + [bbox[0][0] for bbox in boxes])
            min_h = min([10000] + [bbox[0][1] for bbox in boxes])
            max_w = max([0] + [bbox[2][0] for bbox in boxes])
            max_h = max([0] + [bbox[2][1] for bbox in boxes])
            temp_location_list.append([[int(min_w), int(min_h)], [int(max_w), int(max_h)]])
    return temp_location_list, temp_area_row_list
def split_table_by_col(table_location_list, area_table_bbox_list, bbox_text_dict):
    """Report rows where a cell holds more bboxes than the cell above it.

    For every table, each row is compared cell by cell with the previous
    row.  A row index is recorded as a split candidate when some cell has
    more bboxes than the same cell in the previous row and that previous
    cell is not empty.  The candidates are only printed; the table locations
    are returned unchanged.

    Rows are compared pairwise over their common cells, so a row with more
    cells than the previous one no longer raises ``IndexError``.

    :param table_location_list: per table, ``[[left, top], [right, bottom]]``
    :param area_table_bbox_list: per table, rows of cells, each cell a list
        of bboxes
    :param bbox_text_dict: unused, kept for interface compatibility
    :return: ``table_location_list`` (the same object, unmodified)
    """
    for i, _ in enumerate(table_location_list):
        table_bbox_list = area_table_bbox_list[i]
        # Loop over every row, comparing it with the previous one.
        split_index_list = []
        for j in range(1, len(table_bbox_list)):
            row_bbox_cnt_list = [len(cell) for cell in table_bbox_list[j]]
            last_row_bbox_cnt_list = [len(cell) for cell in table_bbox_list[j - 1]]
            grew_from_non_empty = any(
                cnt > last_cnt and last_cnt != 0
                for cnt, last_cnt in zip(row_bbox_cnt_list, last_row_bbox_cnt_list)
            )
            if grew_from_non_empty:
                split_index_list.append(j)
        print('split_index_list', split_index_list)
    return table_location_list
- # def split_table_by_table_head(table_location_list, area_table_bbox_list, bbox_text_dict):
- # new_table_location_list = []
- # for i in range(len(table_location_list)):
- # location = table_location_list[i]
- # table_bbox_list = area_table_bbox_list[i]
- #
- # # 每行单独进行表头预测
- # table_head_row_list = []
- # for j in range(len(table_bbox_list)):
- # row = table_bbox_list[j]
- # print('row', row)
- #
- # if row.count([]) == len(row):
- # table_head_row_list.append([['', 0]])
- # continue
- #
- # row_bbox_list = []
- # for col in row:
- # for b in col:
- # new_b = bbox_text_dict.get(str(b))
- # new_b = re.sub("^[^\u4e00-\u9fa5a-zA-Z0-9]+", "", new_b)
- # new_b = re.sub("[^\u4e00-\u9fa5a-zA-Z0-9]+$", "", new_b)
- # row_bbox_list.append(new_b)
- # result_list = predict([row_bbox_list])
- # # 组合结果
- # for m in range(len(result_list)):
- # for n in range(len(result_list[m])):
- # result_list[m][n] = [row_bbox_list[n], int(result_list[m][n])]
- # result_list = result_list[0]
- # print('table_head', result_list)
- # table_head_row_list.append(result_list)
- #
- # # 根据表头分割
- # split_index_list = []
- # for j in range(1, len(table_head_row_list)):
- # row_head = [x[1] for x in table_head_row_list[j]]
- # last_row_head = [x[1] for x in table_head_row_list[j-1]]
- #
- # # [['6', 0], ['税费', 0], ['依法缴纳', 0], ['1', 0], ['次', 0], ['25000', 0], ['25000', 0]]
- # # [['大写', 1], ['肆抢柒万元整', 0]]
- # if 1 in row_head and 1 not in last_row_head:
- # split_index_list.append(j)
- #
- # # [['供应商', 1], ['广东一线达通网络科技有限公司', 0]]
- # # [['货物明细', 1], ['单价金额(元', 1], ['数量', 1], ['总计金额(元', 1]]
- # if 1 in row_head and 1 in last_row_head and 0 not in row_head and row_head.count(1) != last_row_head.count(1):
- # split_index_list.append(j)
- # print('split_index_list', split_index_list)
- #
- # new_location_list = table_split_by_index(location, split_index_list, table_bbox_list)
- # print('new_location_list, location', new_location_list, location)
- # new_table_location_list += new_location_list
- # print('new_table_location_list', new_table_location_list)
- # return new_table_location_list
def table_split_by_index(table_location, split_index_list, table_bbox_list):
    """Split one table into several bounding boxes at the given row indices.

    The rows of ``table_bbox_list`` are cut at every index in
    ``split_index_list`` (0 and ``len(table_bbox_list)`` are always added as
    boundaries).  For each resulting run of rows the bounding box enclosing
    all of its bboxes is computed.

    Compared with the previous version:
      * the extent is taken with ``min``/``max`` instead of the fixed
        10000 / 0 sentinels, so coordinates above 10000 are handled;
      * a run of rows that contains no bbox is skipped instead of producing
        the inverted box ``[[10000, 10000], [0, 0]]``.

    :param table_location: [[min_w, min_h], [max_w, max_h]] of the whole table,
        returned as the only element when nothing can be split
    :param split_index_list: row indices at which a new table starts
    :param table_bbox_list: rows -> columns -> bboxes; bbox[0] and bbox[2] are
        read as the two opposite corners (x at index 0, y at index 1)
    :return: list of [[min_w, min_h], [max_w, max_h]] boxes (ints)
    """
    if not split_index_list:
        return [table_location]

    boundaries = sorted(set([0] + split_index_list + [len(table_bbox_list)]))
    print('split_index_list', boundaries)

    new_location_list = []
    for start, stop in zip(boundaries, boundaries[1:]):
        boxes = [
            b
            for row in table_bbox_list[start:stop]
            for cell in row
            for b in cell
            if b
        ]
        if not boxes:
            # Nothing to enclose: do not emit a degenerate location.
            continue
        new_location = [
            [int(min(b[0][0] for b in boxes)), int(min(b[0][1] for b in boxes))],
            [int(max(b[2][0] for b in boxes)), int(max(b[2][1] for b in boxes))],
        ]
        new_location_list.append(new_location)
        print('new_location', new_location)

    return new_location_list or [table_location]
def split_table_new(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict):
    """
    Split every table at "separator" rows and return the resulting sub-tables.

    A row is recorded as a separator when the whole row holds exactly one bbox
    and either
      1. that bbox horizontally contains, or has an edge falling inside, at
         least two side-by-side bboxes found in the same column of one of the
         3 rows before / after it, or
      2. that bbox extends more than 15 px beyond its own cell horizontally.

    NOTE(review): the nesting below was reconstructed from a copy of the file
    whose indentation had been lost -- confirm against the original layout.

    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param area_table_bbox_list: per table, rows -> columns -> bboxes; bbox[0]
        and bbox[2] are read as two opposite corners (x at [0], y at [1])
    :param area_table_cell_list: per table, rows -> columns -> cell; cell[0][0]
        and cell[1][0] are read as the cell's left / right x
    :param area_row_list: per table, list of rows (only sliced here)
    :param bbox_text_dict: maps str(bbox) to text; used for debug prints only
    :return: (table_location_list, area_row_list) rebuilt for the split tables
    """
    temp_location_list = []
    temp_area_row_list = []
    for k in range(len(table_location_list)):
        table = area_table_bbox_list[k]
        location = table_location_list[k]
        row_list = area_row_list[k]
        table_cell_list = area_table_cell_list[k]
        split_row_index_list = []
        # Walk over all rows of the table
        for i in range(len(table)):
            row = table[i]
            # print('row', i)
            # for j in range(len(row)):
            # col = row[j]
            # print('col', j, ';'.join([bbox_text_dict.get(str(x)) for x in col]))
            # Decide whether this row is a table-separator row:
            # 1. only one column of the row has a value, and the longest bbox in
            #    that column contains 2+ bboxes of the same column in other rows
            # 2. only one column of the row has a value, and the longest bbox in
            #    that column spans several columns
            # Collect up to n rows before and after the current one
            n = 3
            if i-n < 0:
                last_n_rows = table[0:i]
            else:
                last_n_rows = table[i-n:i]
            if i+1 >= len(table):
                next_n_rows = []
            elif i+n+1 >= len(table):
                next_n_rows = table[i+1:len(table)]
            else:
                next_n_rows = table[i+1:i+n+1]
            # Look for a row in which only a single cell holds data.
            # Note: the counter sums bboxes, so "== 1" means one bbox in total.
            not_empty_col_cnt = 0
            only_one_index = -1
            for j in range(len(row)):
                col = row[j]
                if col:
                    not_empty_col_cnt += len(col)
                    only_one_index = j
            if not_empty_col_cnt == 1:
                print('only_one_index, i', only_one_index, i)
                # Compare with the same column of the surrounding n rows
                for r in last_n_rows+next_n_rows:
                    col = r[only_one_index]
                    if len(col) > 1:
                        print('col', [bbox_text_dict.get(str(x)) for x in col])
                        # Find bboxes laid out side by side inside the same cell
                        # of the other row (horizontally nested ones collapse
                        # into a single entry)
                        col_bbox_list = [col[0]]
                        for bbox in col:
                            # range() is evaluated once, so entries appended
                            # below are not revisited for the current bbox
                            for j in range(len(col_bbox_list)):
                                bbox1 = col_bbox_list[j]
                                if bbox1[0][0] <= bbox[0][0] <= bbox[2][0] <= bbox1[2][0]:
                                    col_bbox_list[j] = bbox
                                elif bbox[0][0] <= bbox1[0][0] <= bbox1[2][0] <= bbox[2][0]:
                                    continue
                                else:
                                    col_bbox_list.append(bbox)
                        if len(col_bbox_list) > 1:
                            # Does the longest bbox of the current row/column
                            # contain several bboxes of that other row?
                            col = row[only_one_index]
                            print('long col', [bbox_text_dict.get(str(x)) for x in col])
                            # in-place sort by width: the widest bbox ends up last
                            col.sort(key=lambda x: abs(x[2][0]-x[0][0]))
                            longest_bbox = col[-1]
                            contain_cnt = 0
                            cross_cnt = 0
                            for bbox in col_bbox_list:
                                if longest_bbox[0][0] <= bbox[0][0] <= bbox[2][0] <= longest_bbox[2][0]:
                                    contain_cnt += 1
                                if bbox[0][0] < longest_bbox[0][0] < bbox[2][0] or bbox[0][0] < longest_bbox[2][0] < bbox[2][0]:
                                    cross_cnt += 1
                            print('cross_cnt', cross_cnt)
                            if contain_cnt >= 2 or cross_cnt >= 2:
                                print('包含多个横向排列bbox', i)
                                split_row_index_list.append(i)
                # Check whether the longest bbox of this row/column crosses its cell
                col = row[only_one_index]
                col.sort(key=lambda x: abs(x[2][0]-x[0][0]))
                longest_bbox = col[-1]
                cell_row = table_cell_list[i]
                cell_col = cell_row[only_one_index]
                threshold = 15
                if cell_col[0][0]-threshold <= longest_bbox[0][0] <= longest_bbox[2][0] <= cell_col[1][0]+threshold:
                    pass
                else:
                    print('最长bbox跨单元格', i)
                    split_row_index_list.append(i)
        if split_row_index_list:
            # Split the table: -1 and len(table) act as virtual separators
            # before the first and after the last row
            split_row_index_list.insert(0, -1)
            split_row_index_list.insert(len(split_row_index_list), len(table))
            split_row_index_list = list(set(split_row_index_list))
            split_row_index_list.sort(key=lambda x: x)
            print('split_row_index_list', split_row_index_list, len(table))
            for l in range(1, len(split_row_index_list)):
                index = split_row_index_list[l]
                last_index = split_row_index_list[l-1]
                # fewer than two rows between the separators: not a table
                if index - last_index <= 2:
                    continue
                start_row_index = last_index+1
                end_row_index = index-1
                start_row = table[last_index+1]
                end_row = table[index-1]
                # flatten columns -> bboxes and drop empty entries
                start_row = [x for y in start_row for x in y]
                end_row = [x for y in end_row for x in y]
                start_row = list(filter(lambda x: x != [], start_row))
                end_row = list(filter(lambda x: x != [], end_row))
                # an empty boundary row is replaced by its inner neighbour
                if not start_row:
                    start_row_index = last_index + 2
                    start_row = table[start_row_index]
                    start_row = [x for y in start_row for x in y]
                    start_row = list(filter(lambda x: x != [], start_row))
                if not end_row:
                    end_row_index = index - 2
                    end_row = table[end_row_index]
                    end_row = [x for y in end_row for x in y]
                    end_row = list(filter(lambda x: x != [], end_row))
                if not start_row or not end_row or end_row_index-start_row_index < 1:
                    continue
                # vertical extent comes from the boundary rows, horizontal
                # extent is inherited from the original table location
                start_row.sort(key=lambda x: x[0][1])
                min_h = int(start_row[0][0][1])
                min_w = location[0][0]
                end_row.sort(key=lambda x: x[2][1])
                max_h = int(end_row[-1][2][1])
                max_w = location[1][0]
                new_location = [[min_w, min_h], [max_w, max_h]]
                temp_location_list.append(new_location)
                # NOTE(review): slices by separator indices, not by the possibly
                # adjusted start_row_index / end_row_index -- confirm intended
                temp_area_row_list.append(row_list[last_index+1:index])
        else:
            # no separator row found: keep the table as it is
            temp_location_list.append(location)
            temp_area_row_list.append(row_list)
    table_location_list = temp_location_list
    area_row_list = temp_area_row_list
    return table_location_list, area_row_list
def split_table_new2(table_location_list, area_table_bbox_list, area_table_cell_list, area_row_list, bbox_text_dict):
    """
    Split every table at "separator" rows (variant that looks backwards only).

    For each row the widest bbox is determined.  The row is recorded as a
    separator when, for one of the up to 2 preceding rows of the current
    sub-table, that widest bbox contains -- or has an edge falling well inside
    -- at least two horizontally disjoint bboxes of that preceding row.  After
    a separator is found the next sub-table starts on the following row.

    NOTE(review): the nesting below was reconstructed from a copy of the file
    whose indentation had been lost -- confirm against the original layout.
    ``line_overlap`` is defined elsewhere; it is presumably the overlap length
    of two 1-D intervals -- TODO confirm.

    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param area_table_bbox_list: per table, rows -> columns -> bboxes
    :param area_table_cell_list: per table cell grid (read but not used here)
    :param area_row_list: per table, list of rows (only sliced here)
    :param bbox_text_dict: only referenced by commented-out debug prints
    :return: (table_location_list, area_row_list) rebuilt for the split tables
    """
    temp_location_list = []
    temp_area_row_list = []
    for k in range(len(table_location_list)):
        table = area_table_bbox_list[k]
        location = table_location_list[k]
        row_list = area_row_list[k]
        table_cell_list = area_table_cell_list[k]
        split_row_index_list = []
        # Walk over all rows of the table
        table_start_row_index = 0
        for i in range(len(table)):
            row = table[i]
            # Decide whether this row is a table-separator row:
            # 1. only one column of the row has a value, and the longest bbox in
            #    that column contains 2+ bboxes of the same column in other rows
            # 2. only one column of the row has a value, and the longest bbox in
            #    that column spans several columns
            # print(i, [bbox_text_dict.get(str(y)) for x in row for y in x])
            # table_start_row_index is updated each time a separator is found
            if table_start_row_index >= len(table):
                break
            # Take the previous n rows, never reaching before the start of the
            # current sub-table
            n = 2
            if i-n < table_start_row_index:
                last_n_rows = table[table_start_row_index:i]
            else:
                last_n_rows = table[i-n:i]
            # Find the widest bbox of the row ([] when the row has no bbox)
            max_len_bbox = []
            for col in row:
                for b in col:
                    if not max_len_bbox:
                        max_len_bbox = b
                    else:
                        if b[2][0] - b[0][0] > max_len_bbox[2][0]-max_len_bbox[0][0]:
                            max_len_bbox = b
            # Compare against the data of the previous n rows
            for r in last_n_rows:
                b_list = [y for x in r for y in x]
                # Keep only bboxes of that row whose x-ranges do not overlap
                # (i.e. drop vertically stacked ones)
                temp_b_list = []
                for b in b_list:
                    if not temp_b_list:
                        temp_b_list.append(b)
                    else:
                        find_flag = 0
                        for tb in temp_b_list:
                            if line_overlap(tb[0][0], tb[2][0], b[0][0], b[2][0]) > 0:
                                find_flag = 1
                                break
                        if not find_flag:
                            temp_b_list.append(b)
                b_list = temp_b_list
                if len(b_list) > 1 and max_len_bbox:
                    # Does the widest bbox contain several bboxes of that row?
                    contain_cnt = 0
                    for b in b_list:
                        # an edge only counts when it lies at least a quarter of
                        # b's width away from b's own borders
                        threshold = (b[2][0]-b[0][0])/4
                        if max_len_bbox[0][0] <= b[0][0] <= b[2][0] <= max_len_bbox[2][0]:
                            contain_cnt += 1
                        if b[0][0]+threshold < max_len_bbox[0][0] < b[2][0]-threshold \
                                or b[0][0]+threshold < max_len_bbox[2][0] < b[2][0]-threshold:
                            contain_cnt += 1
                    # print('contain_cnt', contain_cnt)
                    if contain_cnt >= 2:
                        # print('包含多个横向排列bbox', i)
                        split_row_index_list.append(i)
                        table_start_row_index = i+1
        if split_row_index_list:
            # Split the table: -1 and len(table) act as virtual separators
            # before the first and after the last row
            split_row_index_list.insert(0, -1)
            split_row_index_list.insert(len(split_row_index_list), len(table))
            split_row_index_list = list(set(split_row_index_list))
            split_row_index_list.sort(key=lambda x: x)
            print('split_row_index_list', split_row_index_list, len(table))
            for l in range(1, len(split_row_index_list)):
                index = split_row_index_list[l]
                last_index = split_row_index_list[l-1]
                # fewer than two rows between the separators: not a table
                if index - last_index <= 2:
                    continue
                start_row_index = last_index+1
                end_row_index = index-1
                start_row = table[last_index+1]
                end_row = table[index-1]
                # flatten columns -> bboxes and drop empty entries
                start_row = [x for y in start_row for x in y]
                end_row = [x for y in end_row for x in y]
                start_row = list(filter(lambda x: x != [], start_row))
                end_row = list(filter(lambda x: x != [], end_row))
                # an empty boundary row is replaced by its inner neighbour
                if not start_row:
                    start_row_index = last_index + 2
                    start_row = table[start_row_index]
                    start_row = [x for y in start_row for x in y]
                    start_row = list(filter(lambda x: x != [], start_row))
                if not end_row:
                    end_row_index = index - 2
                    end_row = table[end_row_index]
                    end_row = [x for y in end_row for x in y]
                    end_row = list(filter(lambda x: x != [], end_row))
                if not start_row or not end_row or end_row_index-start_row_index < 1:
                    continue
                # vertical extent comes from the boundary rows, horizontal
                # extent is inherited from the original table location
                start_row.sort(key=lambda x: x[0][1])
                min_h = int(start_row[0][0][1])
                min_w = location[0][0]
                end_row.sort(key=lambda x: x[2][1])
                # print('end_row', [bbox_text_dict.get(str(x)) for x in end_row])
                max_h = int(end_row[-1][2][1])
                max_w = location[1][0]
                new_location = [[min_w, min_h], [max_w, max_h]]
                temp_location_list.append(new_location)
                # rows are sliced with the (possibly adjusted) boundary indices
                temp_area_row_list.append(row_list[start_row_index:end_row_index+1])
        else:
            # no separator row found: keep the table as it is
            temp_location_list.append(location)
            temp_area_row_list.append(row_list)
    table_location_list = temp_location_list
    area_row_list = temp_area_row_list
    return table_location_list, area_row_list
def delete_not_standard_table(img, area_row_list, area_col_list, table_location_list, bbox_list, bbox_text_dict):
    """Flag, for every detected table, whether it looks like a regular table.

    A table is marked non-standard (``False``) when
      1. it has at most one row or at most one column,
      2. a single cell (row/column intersection) holds 8 or more distinct
         bboxes, or
      3. fewer than half of the dark pixels inside the table area lie inside
         text bboxes.

    Changes over the previous version: the pixel ratio no longer divides by
    zero when the table area has no dark pixel (the check is then simply not
    triggered), and distinct bboxes are counted without an ``eval`` round trip.

    :param img: image array indexed as img[y, x, channel]
    :param area_row_list: per table, list of rows (lists of bboxes)
    :param area_col_list: per table, list of columns (lists of bboxes)
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param bbox_list: all bboxes; bbox[0] / bbox[2] are opposite corners
    :param bbox_text_dict: unused
    :return: list of booleans, one per table
    """
    table_standard_list = []
    for i in range(len(table_location_list)):
        row_list = area_row_list[i]
        col_list = area_col_list[i]
        location = table_location_list[i]

        # 1. only a single row or a single column
        if len(row_list) <= 1 or len(col_list) <= 1:
            table_standard_list.append(False)
            continue

        # 2. too many bboxes inside one cell
        crowded_cell = any(
            len({str(b) for b in row if b in col}) >= 8
            for row in row_list
            for col in col_list
        )
        table_standard = not crowded_cell

        # 3. share of the table's dark pixels that is covered by bboxes
        table_black_cnt = count_black(img[location[0][1]:location[1][1], location[0][0]:location[1][0], :])
        bbox_black_cnt = 0
        for bbox in bbox_list:
            # bboxes are selected by their top y only
            if location[0][1] <= bbox[0][1] <= location[1][1]:
                sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
                if sub_img.shape[0] >= 3 and sub_img.shape[1] >= 3:
                    bbox_black_cnt += count_black(sub_img)
        # With no dark pixel in the table the ratio is undefined; use +inf so
        # that the "< 0.5" test below does not fire.
        ratio = bbox_black_cnt / table_black_cnt if table_black_cnt else float('inf')
        print('table_black_cnt, bbox_black_cnt', table_black_cnt, bbox_black_cnt, ratio)
        if ratio < 0.5:
            table_standard = False
        table_standard_list.append(table_standard)
    print('table_standard_list', table_standard_list)
    return table_standard_list
def bbox_preprocess(bbox_list, text_list, row_list, bbox_text_dict):
    """Merge runs of single-character bboxes that sit on the same row.

    Each row is sorted left to right.  A run of three or more consecutive
    bboxes whose text is exactly one character long is replaced, in
    ``bbox_list`` / ``text_list``, by one bbox spanning the run and the
    concatenated text; the merged pair is also added to ``bbox_text_dict``.
    ``row_list`` itself still refers to the original bboxes afterwards.

    Changes over the previous version:
      * the text removed from ``text_list`` is the one stored at the same
        position as the removed bbox (the lists are assumed to be parallel --
        TODO confirm with the caller).  Previously the first equal text was
        removed, which shifts the pairing when a character occurs more than
        once on the page.  If the entry at that position does not match, the
        old "remove first equal text" behaviour is used as a fallback;
      * a bbox without an entry in ``bbox_text_dict`` is treated as
        "not a single character" instead of raising ``TypeError``.

    :param bbox_list: all bboxes (modified in place)
    :param text_list: texts of the bboxes (modified in place)
    :param row_list: list of rows, each a list of bboxes (sorted in place)
    :param bbox_text_dict: maps str(bbox) to text (extended in place)
    :return: (bbox_list, text_list, bbox_text_dict)
    """
    for row in row_list:
        row.sort(key=lambda x: x[0][0])
        last_pos = len(row) - 1
        single_bbox_list = []
        for i, bbox in enumerate(row):
            is_single = len(bbox_text_dict.get(str(bbox)) or "") == 1
            if is_single and i != last_pos:
                single_bbox_list.append(bbox)
                continue

            # The run ends here (non-single bbox or end of the row).
            if len(single_bbox_list) >= 3:
                if is_single:
                    # a single character closing the row still belongs to the run
                    single_bbox_list.append(bbox)
                single_bbox_list.sort(key=lambda x: x[0][0])
                new_bbox = single_bbox_list[0]
                merged_chars = []
                for b in single_bbox_list:
                    # left/top come from the first bbox, right/bottom from the
                    # bbox processed last
                    new_bbox = [[new_bbox[0][0], new_bbox[0][1]],
                                [b[2][0], new_bbox[0][1]],
                                [b[2][0], b[2][1]],
                                [new_bbox[0][0], b[2][1]],
                                ]
                    text = bbox_text_dict.get(str(b)) or ""
                    merged_chars.append(text)
                    pos = bbox_list.index(b)
                    del bbox_list[pos]
                    if pos < len(text_list) and text_list[pos] == text:
                        del text_list[pos]
                    else:
                        text_list.remove(text)
                new_text = "".join(merged_chars)
                bbox_list.append(new_bbox)
                text_list.append(new_text)
                bbox_text_dict[str(new_bbox)] = new_text
            single_bbox_list = []
    return bbox_list, text_list, bbox_text_dict
def merge_table(area_row_list, area_col_list, table_location_list, bbox_list):
    """Merge vertically adjacent tables that look like parts of one table.

    Tables are processed top to bottom.  Table ``i`` is merged with table
    ``i - 1`` when both have the same number of columns and no bbox lying
    between them (within 5 px vertically) either sticks out of the column
    range of table ``i`` or starts inside one of its columns while being
    wider than that column.

    Changes over the previous version:
      * ``area_row_list`` and ``area_col_list`` are reordered together with
        ``table_location_list``; before, only the locations were sorted, so
        rows/columns could be paired with the wrong table;
      * merged row/column lists are new lists, the caller's per-table lists
        are no longer extended in place;
      * group indices are processed in ascending order explicitly;
      * empty columns / tables without columns no longer raise ``IndexError``.

    :param area_row_list: per table, list of rows
    :param area_col_list: per table, list of columns (lists of bboxes); the
        columns of every table but the topmost are sorted in place, as before
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]];
        sorted in place by top y, as before
    :param bbox_list: all bboxes; bbox[0] / bbox[2] are opposite corners
    :return: (new_area_row_list, new_area_col_list, new_table_location_list)
    """
    if not table_location_list:
        return [], [], []

    # Reorder the parallel lists consistently (both sorts are stable).
    order = sorted(range(len(table_location_list)),
                   key=lambda idx: table_location_list[idx][0][1])
    area_row_list = [area_row_list[idx] for idx in order]
    area_col_list = [area_col_list[idx] for idx in order]
    table_location_list.sort(key=lambda x: x[0][1])

    threshold = 5
    merge_index_list = []
    temp_merge_list = []
    for i in range(1, len(table_location_list)):
        last_col_list = area_col_list[i-1]
        col_list = area_col_list[i]
        last_location = table_location_list[i-1]
        location = table_location_list[i]

        # horizontal extent [min_w, max_w] of every column of the lower table
        col_width_list = []
        for col in col_list:
            if not col:
                continue
            col.sort(key=lambda x: x[0][0])
            min_w = col[0][0][0]
            col.sort(key=lambda x: x[2][0])
            max_w = col[-1][2][0]
            col_width_list.append([min_w, max_w])

        # Look at the bboxes between the two tables: one that leaves the
        # column range or spans more than its column forbids the merge.
        can_merge = True
        for bbox in bbox_list:
            if not (last_location[1][1]-threshold <= bbox[0][1] <= bbox[2][1] <= location[0][1]+threshold):
                continue
            if col_width_list and (bbox[0][0] < col_width_list[0][0] or bbox[2][0] > col_width_list[-1][1]):
                can_merge = False
                break
            if any(w[0] <= bbox[0][0] <= w[1] and bbox[2][0] - bbox[0][0] > w[1] - w[0]
                   for w in col_width_list):
                can_merge = False
                break

        if can_merge and len(last_col_list) == len(col_list):
            temp_merge_list += [i-1, i]
        else:
            # close the current group (or emit the previous table on its own)
            merge_index_list.append(temp_merge_list or [i-1])
            temp_merge_list = []
    merge_index_list.append(temp_merge_list or [len(table_location_list)-1])

    new_table_location_list = []
    new_area_row_list = []
    new_area_col_list = []
    for index_list in merge_index_list:
        indices = sorted(set(index_list))
        first = indices[0]
        merged_location = table_location_list[first]
        merged_rows = list(area_row_list[first])
        merged_cols = list(area_col_list[first])
        for index in indices[1:]:
            other = table_location_list[index]
            merged_location = [[min(merged_location[0][0], other[0][0]),
                                min(merged_location[0][1], other[0][1])],
                               [max(merged_location[1][0], other[1][0]),
                                max(merged_location[1][1], other[1][1])]
                               ]
            merged_rows += area_row_list[index]
            merged_cols += area_col_list[index]
        new_area_row_list.append(merged_rows)
        new_area_col_list.append(merged_cols)
        new_table_location_list.append(merged_location)
    return new_area_row_list, new_area_col_list, new_table_location_list
def add_col_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict):
    """
    Propose extra vertical lines for cells that contain several columns.

    For every cell (row/column intersection) holding more than one bbox, the
    bboxes are grouped into text lines by vertical overlap.  The text line
    with the most bboxes is taken, horizontally overlapping bboxes in it are
    collapsed to the one reaching furthest right, and a vertical line spanning
    the table height is proposed at the right edge of each remaining bbox.

    Changes over the previous version:
      * the horizontal-overlap test compared against
        ``min(width1, bbox2[2][0], bbox2[0][0])`` -- a comma where a minus was
        meant; it now uses ``min(width1, width2)``;
      * distinct bboxes of a cell are collected without an ``eval`` round trip;
      * unused locals and the commented-out legacy code were removed.

    NOTE(review): ``line_overlap`` is defined elsewhere and is presumably the
    overlap length of two 1-D intervals -- TODO confirm.  The file's
    indentation was lost in the reviewed copy; lines are emitted once per
    column after all of its rows were scanned -- confirm against the original.

    :param area_row_list: per table, list of rows (sorted in place)
    :param area_col_list: per table, list of columns (sorted in place)
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param bbox_text_dict: maps str(bbox) to text; used for debug prints only
    :return: per table, list of lines [[x, min_h], [x, max_h]]
    """
    add_area_col_lines = []
    for i in range(len(table_location_list)):
        row_list = area_row_list[i]
        col_list = area_col_list[i]
        location = table_location_list[i]
        table_col_lines = []
        for col in col_list:
            cell_col_lines = []
            col.sort(key=lambda x: (x[0][1], x[0][0]))
            for row in row_list:
                row.sort(key=lambda x: (x[0][0], x[0][1]))
                # distinct bboxes shared by this row and this column
                seen = set()
                inter = []
                for b in row:
                    key = str(b)
                    if b in col and key not in seen:
                        seen.add(key)
                        inter.append(b)
                inter.sort(key=lambda x: (x[0][1], x[0][0]))
                print('inter', [bbox_text_dict.get(str(x)) for x in inter])

                # only cells holding several bboxes are of interest
                if len(inter) <= 1:
                    continue

                # split the cell content into text lines
                cell_row = []
                temp_row = [inter[0]]
                row_len = [inter[0][0][1], inter[0][2][1]]
                for bbox in inter[1:]:
                    temp_bbox = temp_row[0]
                    bbox_h_len = bbox[2][1] - bbox[0][1]
                    temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1]
                    if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len):
                        temp_row.append(bbox)
                        row_len[0] = min(row_len[0], bbox[0][1])
                        row_len[1] = max(row_len[1], bbox[2][1])
                    else:
                        cell_row.append(temp_row)
                        temp_row = [bbox]
                        row_len = [bbox[0][1], bbox[2][1]]
                cell_row.append(temp_row)

                print('row_cnt', 0)
                for c in cell_row:
                    c.sort(key=lambda x: x[0][0])
                    print('cell_row', [bbox_text_dict.get(str(x)) for x in c])

                # take the text line with the most bboxes (last one on ties)
                temp_cell_row = copy.deepcopy(cell_row)
                temp_cell_row.sort(key=lambda x: len(x))
                max_cell_row = temp_cell_row[-1]

                # collapse bboxes of that line which overlap horizontally
                max_cell_row.sort(key=lambda x: (x[0][0], x[0][1]))
                used_bbox_list = []
                merge_bbox_list = []
                for bbox1 in max_cell_row:
                    if bbox1 in used_bbox_list:
                        continue
                    temp_merge_bbox = [bbox1]
                    for bbox2 in max_cell_row:
                        if bbox2 in used_bbox_list:
                            continue
                        min_width = min(bbox1[2][0] - bbox1[0][0], bbox2[2][0] - bbox2[0][0])
                        if line_overlap(bbox1[0][0], bbox1[2][0], bbox2[0][0], bbox2[2][0]) >= 2/3 * min_width:
                            temp_merge_bbox.append(bbox2)
                            used_bbox_list += [bbox1, bbox2]
                    # keep the bbox reaching furthest right (widest on ties)
                    temp_merge_bbox.sort(key=lambda x: (x[2][0], -x[0][0]))
                    merge_bbox_list.append(temp_merge_bbox[-1])
                print('temp_cell_row', [bbox_text_dict.get(str(x)) for x in merge_bbox_list])

                for c in merge_bbox_list:
                    cell_col_lines.append([c[0][0], c[2][0]])
                cell_col_lines.sort(key=lambda x: x[0])

            # one full-height line at the right edge of every collected span
            for c in cell_col_lines:
                table_col_lines.append([[int(c[1]), location[0][1]], [int(c[1]), location[1][1]]])
        add_area_col_lines.append(table_col_lines)
    return add_area_col_lines
def judge_col_lines(img, area_col_lines, table_location_list, bbox_list, bbox_text_dict):
    """Filter and adjust proposed vertical lines, then add the table borders.

    Per table:
      1. duplicate lines are dropped;
      2. a line is kept only if some bbox of the table fits completely
         between it and the previously kept line (5 px tolerance); the first
         and the last line are always kept;
      3. a line that runs through a bbox is moved right, to the right edge of
         the bboxes it crosses (extended over bboxes that start inside the
         moved range);
      4. the left and right table borders are appended if missing.

    Changes over the previous version: duplicates are removed with an
    order-preserving equality check on deep copies instead of an
    ``eval(str(...))`` round trip through a ``set`` (whose order depends on
    string hashing), and the last line is no longer appended a second time
    when it was already kept.

    :param img: unused
    :param area_col_lines: per table, list of lines [[x, y0], [x, y1]]; each
        per-table list is sorted in place, the lines themselves are copied
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param bbox_list: all bboxes; bbox[0] / bbox[2] are opposite corners
    :param bbox_text_dict: unused
    :return: per table, the adjusted list of lines ([] if there was none)
    """
    threshold = 5
    new_area_col_lines = []
    for i in range(len(table_location_list)):
        location = table_location_list[i]
        area_col_lines[i].sort(key=lambda x: x[0][0])

        # bboxes whose top y lies inside the table
        sub_bbox_list = [bbox for bbox in bbox_list
                         if location[0][1] <= bbox[0][1] <= location[1][1]]

        # de-duplicate on private copies so the caller's lines stay untouched
        col_lines = []
        for line in area_col_lines[i]:
            if line not in col_lines:
                col_lines.append(copy.deepcopy(line))
        if not col_lines:
            new_area_col_lines.append([])
            continue

        # keep a line only if a complete bbox lies between it and its
        # predecessor
        kept_lines = [col_lines[0]]
        for line in col_lines[1:]:
            left_w = kept_lines[-1][0][0]
            right_w = line[0][0]
            if any(left_w-threshold <= bbox[0][0] <= bbox[2][0] <= right_w+threshold
                   for bbox in sub_bbox_list):
                kept_lines.append(line)
        if kept_lines[-1] is not col_lines[-1]:
            kept_lines.append(col_lines[-1])
        col_lines = kept_lines

        # a line crossing a bbox is moved right into free space
        for col in col_lines:
            cross_bbox_list = [bbox for bbox in sub_bbox_list
                               if bbox[0][0] < col[0][0] < bbox[2][0]]
            if not cross_bbox_list:
                continue
            line_move_w = max(bbox[2][0] for bbox in cross_bbox_list)
            line_now_w = col[0][0]
            for bbox1 in sub_bbox_list:
                if bbox1 in cross_bbox_list:
                    continue
                # another bbox starts inside the range the line moves over:
                # continue to that bbox's right edge
                if line_now_w <= bbox1[0][0] <= line_move_w:
                    line_now_w = line_move_w
                    line_move_w = bbox1[2][0]
            col[0][0] = int(line_move_w)
            col[1][0] = int(line_move_w)

        # add the table borders
        left_col = [[location[0][0], location[0][1]], [location[0][0], location[1][1]]]
        right_col = [[location[1][0], location[0][1]], [location[1][0], location[1][1]]]
        if left_col not in col_lines:
            col_lines.append(left_col)
        if right_col not in col_lines:
            col_lines.append(right_col)
        new_area_col_lines.append(col_lines)
    return new_area_col_lines
def add_row_lines(area_row_list, area_col_list, table_location_list, bbox_text_dict, area_row_lines):
    """
    Propose extra horizontal lines between text lines found inside cells.

    For every cell (row/column intersection) the bboxes are grouped into text
    lines by vertical overlap.  A cell with a single text line gets a line at
    the bottom of its first bbox; for cells with several text lines a line is
    proposed below each text line that is not already separated from the next
    one by an existing row line.

    NOTE(review): the nesting below was reconstructed from a copy of the file
    whose indentation had been lost -- confirm against the original layout.
    ``line_overlap`` is defined elsewhere; it is presumably the overlap length
    of two 1-D intervals -- TODO confirm.

    :param area_row_list: per table, list of rows (sorted in place)
    :param area_col_list: per table, list of columns (sorted in place)
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param bbox_text_dict: maps str(bbox) to text; used for debug prints only
    :param area_row_lines: per table, existing horizontal lines; l[0][1] is
        read as the line's y
    :return: per table, list of lines [[min_w, h], [max_w, h]]
    """
    add_area_row_lines = []
    for i in range(len(table_location_list)):
        row_list = area_row_list[i]
        col_list = area_col_list[i]
        location = table_location_list[i]
        row_lines = area_row_lines[i]
        add_row_lines = []
        for row in row_list:
            col_cnt = 0
            row.sort(key=lambda x: (x[0][0], x[0][1]))
            # # Use only the first column as reference
            # first_col = col_list[0]
            # first_col.sort(key=lambda x: (x[0][1], x[0][0]))
            # inter = [j for j in row if j in first_col]
            # inter = [eval(x) for x in list(set([str(x) for x in inter]))]
            # inter.sort(key=lambda x: (x[0][1], x[0][0]))
            # All columns take part
            for col in col_list:
                col.sort(key=lambda x: (x[0][1], x[0][0]))
                # bboxes belonging to both this row and this column
                inter = [j for j in row if j in col]
                print('col', col_cnt, [bbox_text_dict.get(str(x)) for x in col], [bbox_text_dict.get(str(x)) for x in row])
                # de-duplicate via the string form; eval() rebuilds fresh lists
                inter = [eval(x) for x in list(set([str(x) for x in inter]))]
                inter.sort(key=lambda x: (x[0][1], x[0][0]))
                print('add_row_lines inter', [bbox_text_dict.get(str(x)) for x in inter])
                if len(inter) > 0:
                    # Split the cell content into text lines
                    cell_row = []
                    temp_row = [inter[0]]
                    # vertical extent [top, bottom] of the text line being built
                    row_len = [inter[0][0][1], inter[0][2][1]]
                    for bbox in inter[1:]:
                        temp_bbox = temp_row[0]
                        bbox_h_len = bbox[2][1] - bbox[0][1]
                        temp_bbox_h_len = temp_bbox[2][1] - temp_bbox[0][1]
                        # same text line when the vertical overlap reaches a
                        # third of the smaller height
                        if line_overlap(row_len[0], row_len[1], bbox[0][1], bbox[2][1]) >= 1/3 * min(bbox_h_len, temp_bbox_h_len):
                            temp_row.append(bbox)
                            row_len[0] = min(row_len[0], bbox[0][1])
                            row_len[1] = max(row_len[1], bbox[2][1])
                        else:
                            cell_row.append(temp_row)
                            temp_row = [bbox]
                            row_len = [bbox[0][1], bbox[2][1]]
                    if temp_row:
                        cell_row.append(temp_row)
                    print('col_cnt', col_cnt)
                    for c in cell_row:
                        c.sort(key=lambda x: x[0][0])
                        print('cell_row', [bbox_text_dict.get(str(x)) for x in c])
                    # Handle the text lines of the cell
                    if len(cell_row) > 0:
                        if len(cell_row) == 1:
                            # single text line: line at the bottom of its
                            # leftmost bbox
                            h = int(cell_row[0][0][2][1])
                            add_row_lines.append([[location[0][0], h], [location[1][0], h]])
                        for j in range(1, len(cell_row)):
                            last_row = cell_row[j-1]
                            row1 = cell_row[j]
                            # after these sorts last_row[-1] has the lowest
                            # bottom and row1[0] the highest top
                            last_row.sort(key=lambda x: x[2][1])
                            row1.sort(key=lambda x: x[0][1])
                            # is there already a row line between the two text lines?
                            find_flag = 0
                            for l in row_lines:
                                if last_row[-1][2][1] <= l[0][1] <= row1[0][0][1]:
                                    find_flag = 1
                                    break
                            if not find_flag:
                                h = int(last_row[-1][2][1])
                                # the line is pushed below the text line by the
                                # gap measured above that text line
                                if j == 1:
                                    last_row.sort(key=lambda x: x[0][1])
                                    h += int(last_row[0][0][1] - location[0][1])
                                else:
                                    last_two_row = cell_row[j-2]
                                    last_two_row.sort(key=lambda x: x[2][1])
                                    last_row.sort(key=lambda x: x[0][1])
                                    h += int(last_row[0][0][1] - last_two_row[-1][2][1])
                                add_row_lines.append([[location[0][0], h], [location[1][0], h]])
                col_cnt += 1
        add_area_row_lines.append(add_row_lines)
    return add_area_row_lines
def judge_row_lines(img, area_row_lines, table_location_list, bbox_list, bbox_text_dict):
    """Adjust and filter proposed horizontal lines, then add the table borders.

    Per table:
      1. a line that runs through a bbox is moved down to the bottom of the
         lowest bbox it crosses, repeatedly, until it crosses none;
      2. a line is kept only if some bbox of the table fits completely
         between it and the previously kept line (5 px tolerance); the first
         and the last line are always kept;
      3. the top and bottom table borders are appended if missing.

    Changes over the previous version:
      * the new y of a moved line is rounded *up*; truncating a fractional
        bbox bottom left the line inside the same bbox, so the ``while`` loop
        never terminated;
      * a table without row lines yields ``[]`` instead of ``IndexError``
        (same result as ``judge_col_lines`` for a table without lines);
      * the last line is not appended a second time when it was already kept.

    :param img: unused
    :param area_row_lines: per table, list of lines [[x0, y], [x1, y]]; sorted
        in place, and moved lines are updated in place as before
    :param table_location_list: per table, [[min_w, min_h], [max_w, max_h]]
    :param bbox_list: all bboxes; bbox[0] / bbox[2] are opposite corners
    :param bbox_text_dict: unused
    :return: per table, the adjusted list of lines ([] if there was none)
    """
    threshold = 5
    new_area_row_lines = []
    for i in range(len(table_location_list)):
        location = table_location_list[i]
        row_lines = area_row_lines[i]

        # bboxes whose top y lies inside the table
        sub_bbox_list = [bbox for bbox in bbox_list
                         if location[0][1] <= bbox[0][1] <= location[1][1]]

        if not row_lines:
            new_area_row_lines.append([])
            continue

        # a line crossing a bbox is moved down into free space
        row_lines.sort(key=lambda x: x[0][1])
        for row in row_lines:
            while True:
                crossed_bottoms = [bbox[2][1] for bbox in sub_bbox_list
                                   if bbox[0][1] < row[0][1] < bbox[2][1]]
                if not crossed_bottoms:
                    break
                lowest = max(crossed_bottoms)
                line_move_h = int(lowest)
                if line_move_h < lowest:
                    # round up so the line really leaves the bbox
                    line_move_h += 1
                row[0][1] = line_move_h
                row[1][1] = line_move_h

        # keep a line only if a complete bbox lies between it and its
        # predecessor
        row_lines.sort(key=lambda x: x[0][1])
        kept_lines = [row_lines[0]]
        for line in row_lines[1:]:
            upper_h = kept_lines[-1][0][1]
            lower_h = line[0][1]
            if any(upper_h-threshold <= bbox[0][1] <= bbox[2][1] <= lower_h+threshold
                   for bbox in sub_bbox_list):
                kept_lines.append(line)
        if kept_lines[-1] is not row_lines[-1]:
            kept_lines.append(row_lines[-1])
        row_lines = kept_lines

        # add the table borders
        up_row = [[location[0][0], location[0][1]], [location[1][0], location[0][1]]]
        bottom_row = [[location[0][0], location[1][1]], [location[1][0], location[1][1]]]
        if up_row not in row_lines:
            row_lines.append(up_row)
        if bottom_row not in row_lines:
            row_lines.append(bottom_row)
        new_area_row_lines.append(row_lines)
    return new_area_row_lines
def merge_lines(lines, axis=0, threshold=5):
    """Collapse lines that lie close to each other into a single line.

    Lines are ordered by the coordinate ``line[0][1 - axis]`` (the y of the
    first point for ``axis=0``, its x for ``axis=1``).  Neighbouring lines
    whose coordinates differ by at most ``threshold`` are chained into one
    group, and only the last line of each group (the right-most / bottom-most
    one) is kept.

    Changes over the previous version: a single pass over the sorted list
    replaces the nested scans with repeated list membership tests, and three
    or more identical lines now collapse to one (a duplicate used to survive).

    :param lines: list of lines [[x0, y0], [x1, y1]]; sorted in place
    :param axis: 0 to compare y coordinates, 1 to compare x coordinates
    :param threshold: maximum distance between neighbours of one group
    :return: new list with one representative line per group, sorted
    """
    coord = 1 - axis
    lines.sort(key=lambda x: x[0][coord])
    new_lines = []
    for line in lines:
        if new_lines and abs(line[0][coord] - new_lines[-1][0][coord]) <= threshold:
            # same group: the later (further right / lower) line wins
            new_lines[-1] = line
        else:
            new_lines.append(line)
    return new_lines
def _rows_form_split_header(prev_row, row, next_row):
    """Tell whether three consecutive rows look like one multi-line row.

    True when ``prev_row`` has at least two bboxes and every one of them
    overlaps horizontally (>= 80 % of the narrower width) with some bbox of
    ``next_row`` but with no bbox of ``row``.
    """
    def overlaps(a, b):
        return line_overlap(a[0][0], a[2][0], b[0][0], b[2][0]) \
            >= 0.8*min(a[2][0] - a[0][0], b[2][0] - b[0][0])

    if len(prev_row) <= 1:
        return False
    for prev_b in prev_row:
        if not any(overlaps(prev_b, next_b) for next_b in next_row):
            return False
        if any(overlaps(prev_b, b) for b in row):
            return False
    return True


def merge_row_bbox_list(area_row_list):
    """Merge rows of a header whose lines were split over three rows.

    Three consecutive rows are joined into one when the first and the third
    line up column-wise while the middle one lines up with neither bbox of
    the first (see ``_rows_form_split_header``).

    Change over the previous version: rows are consumed left to right while
    the result is built, so a second merge in the same table can no longer
    slice the already shortened result with stale indices; rows that were
    merged are not considered again.  Tables with at most one merge give the
    same result as before.

    NOTE(review): ``line_overlap`` is defined elsewhere and is presumably the
    overlap length of two 1-D intervals -- TODO confirm.

    :param area_row_list: per table, list of rows (lists of bboxes)
    :return: per table, a deep-copied list of rows with header rows merged
    """
    new_area_row_list = []
    for row_list in area_row_list:
        new_row_list = []
        start = 0
        total = len(row_list)
        while start < total:
            if start + 2 < total and _rows_form_split_header(
                    row_list[start], row_list[start+1], row_list[start+2]):
                merged = row_list[start] + row_list[start+1] + row_list[start+2]
                new_row_list.append(copy.deepcopy(merged))
                start += 3
            else:
                new_row_list.append(copy.deepcopy(row_list[start]))
                start += 1
        new_area_row_list.append(new_row_list)
    return new_area_row_list
def count_black(image_np, threshold=150):
    """Count the pixels that ``cv2.inRange`` accepts as dark.

    A pixel is counted when it falls between ``[0, 0, 0]`` and
    ``[threshold, threshold, threshold]`` (assumes a 3-channel image --
    confirm against callers).

    :param image_np: image array passed straight to ``cv2.inRange``
    :param threshold: upper bound applied to each channel
    :return: number of non-zero entries in the resulting mask
    """
    floor = np.array([0, 0, 0])
    ceiling = np.array([threshold] * 3)
    dark_mask = cv2.inRange(image_np, floor, ceiling)
    return np.sum(dark_mask != 0)
def get_bbox_list_by_lines(img, area_row_lines, area_col_lines, table_location_list, bbox_list, axis=0):
    """Group boxes into table rows using the crossings of row and column lines.

    For each table area the crossings of its row/column lines are computed
    with ``get_points_by_line`` and grouped into rows of points that share
    the same ``point[1 - axis]`` value. For two consecutive point rows, the
    pair ``(previous_row[k-1], current_row[k])`` delimits a cell; every
    not-yet-used box whose centre falls inside a cell is added to that row.

    An area without any crossing yields an empty row list. Previously this
    raised ``IndexError``; ``get_table_bbox_list`` already handled the case.

    :param img: image handed to ``get_points_by_line``
    :param area_row_lines: per-area list of row lines
    :param area_col_lines: per-area list of column lines
    :param table_location_list: only its length is used (number of areas)
    :param bbox_list: boxes indexed as ``bbox[0][*]`` and ``bbox[2][*]``
    :param axis: 0 groups crossings by index 1, 1 groups them by index 0
    :return: per-area list of rows, each row a flat list of boxes
    """
    major = 1 - axis
    area_row_list = []
    for area_idx in range(len(table_location_list)):
        # Crossings of this area's row and column lines.
        cross_points = get_points_by_line(img, area_row_lines[area_idx], area_col_lines[area_idx])
        if not cross_points:
            # No grid to assign boxes to; keep the output aligned per area.
            area_row_list.append([])
            continue

        # Split the crossings into rows sharing the same major coordinate.
        cross_points.sort(key=lambda pt: (pt[major], pt[axis]))
        point_rows = [[cross_points[0]]]
        for pt in cross_points[1:]:
            if point_rows[-1][0][major] == pt[major]:
                point_rows[-1].append(pt)
            else:
                point_rows.append([pt])

        used_bbox_list = []
        row_list = []
        for row_idx in range(1, len(point_rows)):
            prev_points = point_rows[row_idx - 1]
            cur_points = point_rows[row_idx]
            sub_row_list = []
            for k in range(1, len(cur_points)):
                start_pt = prev_points[k - 1]
                end_pt = cur_points[k]
                for bbox in bbox_list:
                    if bbox in used_bbox_list:
                        continue
                    center_major = (bbox[0][major] + bbox[2][major]) / 2
                    center_minor = (bbox[0][axis] + bbox[2][axis]) / 2
                    if start_pt[major] <= center_major <= end_pt[major] \
                            and start_pt[axis] <= center_minor <= end_pt[axis]:
                        sub_row_list.append(bbox)
                        used_bbox_list.append(bbox)
            row_list.append(sub_row_list)
        area_row_list.append(row_list)
    return area_row_list
def get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list):
    """Arrange boxes in table layout and report the cell corners used.

    For each area, the crossings returned by ``get_points_by_line`` are
    grouped into rows that share the same ``point[1]``. For consecutive point
    rows the pair ``(upper_row[k-1], lower_row[k])`` is taken as the two
    opposite corners of a cell, and every unused box whose centre lies inside
    the cell is placed in it.

    :param img: image handed to ``get_points_by_line``
    :param area_row_lines: per-area list of row lines
    :param area_col_lines: per-area list of column lines
    :param table_location_list: only its length is used (number of areas)
    :param bbox_list: boxes indexed as ``bbox[0][*]`` and ``bbox[2][*]``
    :return: ``(bbox_tables, cell_tables)``; both are per-area lists of rows
        of cells, holding the boxes and the ``[corner, corner]`` pairs
    """
    all_bbox_tables = []
    all_cell_tables = []
    for area_idx in range(len(table_location_list)):
        # Crossings of this area's row and column lines.
        points = get_points_by_line(img, area_row_lines[area_idx], area_col_lines[area_idx])
        points.sort(key=lambda pt: (pt[1], pt[0]))
        if not points:
            all_bbox_tables.append([])
            all_cell_tables.append([])
            continue

        # Split the crossings into rows sharing the same second coordinate.
        point_rows = [[points[0]]]
        for pt in points[1:]:
            if point_rows[-1][0][1] == pt[1]:
                point_rows[-1].append(pt)
            else:
                point_rows.append([pt])

        # Lay the boxes out in table order.
        taken = []
        bbox_table = []
        cell_table = []
        for row_idx in range(1, len(point_rows)):
            upper_points = point_rows[row_idx - 1]
            lower_points = point_rows[row_idx]
            bbox_row = []
            cell_row = []
            for col_idx in range(1, len(lower_points)):
                corner_a = upper_points[col_idx - 1]
                corner_b = lower_points[col_idx]
                members = []
                for bbox in bbox_list:
                    if bbox in taken:
                        continue
                    center_y = (bbox[0][1] + bbox[2][1]) / 2
                    center_x = (bbox[0][0] + bbox[2][0]) / 2
                    if corner_a[1] <= center_y <= corner_b[1] and corner_a[0] <= center_x <= corner_b[0]:
                        members.append(bbox)
                        taken.append(bbox)
                bbox_row.append(members)
                cell_row.append([corner_a, corner_b])
            bbox_table.append(bbox_row)
            cell_table.append(cell_row)
        all_bbox_tables.append(bbox_table)
        all_cell_tables.append(cell_table)
    return all_bbox_tables, all_cell_tables
def get_lines_from_img(img):
    """Return two morphologically opened versions of the grayscale image.

    The image is converted with ``cv2.COLOR_BGR2GRAY`` and opened once with a
    ``(7, 1)`` rectangular kernel (meant to extract horizontal lines) and once
    with a ``(1, 7)`` kernel (meant to extract vertical lines).

    :param img: colour image accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
    :return: ``(horizontal_result, vertical_result)``
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def opened_with(kernel_size):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        return cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)

    # Opening to extract the horizontal lines.
    horizontal = opened_with((7, 1))
    # Opening to extract the vertical lines.
    vertical = opened_with((1, 7))
    return horizontal, vertical
def get_bbox_by_img(row_img, col_img):
    """Build a regular grid of cells from row-line and column-line images.

    Pixels brighter than 200 in the bitwise AND of both images are treated as
    line crossings. Crossings within 5 pixels of an already kept crossing (in
    both directions) are dropped, the rest are grouped into rows whose second
    coordinate is within 5 pixels of the row's first point, and each row is
    ordered by its first coordinate. If the rows do not all contain the same
    number of points the grid is rejected.

    Each cell is ``[upper_row[k-1], lower_row[k]]`` for consecutive rows.

    Changes from the previous version: the full-size ``new_img`` buffer that
    was allocated and never used is gone, and the length checks that could
    not fail after the equal-row-length test were removed. Results are the
    same.

    :param row_img: single-channel image of the row lines
    :param col_img: single-channel image of the column lines, same shape
    :return: ``(row_list, standard_flag)``; ``([], False)`` when no usable
        grid is found, otherwise the rows of cells and ``True``
    """
    # Crossings are the pixels that are bright in both line images.
    point_img = np.bitwise_and(row_img, col_img)
    ys, xs = np.where(point_img > 200)
    candidates = sorted(zip(xs, ys), key=lambda pt: (pt[0], pt[1]))

    # Thin every blob of crossing pixels down to the first pixel seen.
    cross_points = []
    for cand in candidates:
        if not any(abs(cand[1] - kept[1]) <= 5 and abs(cand[0] - kept[0]) <= 5
                   for kept in cross_points):
            cross_points.append(cand)
    if not cross_points:
        return [], False
    print('cross_points', len(cross_points))

    # Split the crossings into rows, anchored on each row's first point.
    cross_points.sort(key=lambda pt: (pt[1], pt[0]))
    point_rows = [[cross_points[0]]]
    for pt in cross_points[1:]:
        if abs(point_rows[-1][0][1] - pt[1]) <= 5:
            point_rows[-1].append(pt)
        else:
            point_rows.append([pt])
    for row_points in point_rows:
        row_points.sort(key=lambda pt: pt[0])

    # A regular grid has the same number of crossings on every row.
    width = len(point_rows[0])
    if any(len(row_points) != width for row_points in point_rows):
        return [], False

    row_list = [
        [[upper[k - 1], lower[k]] for k in range(1, width)]
        for upper, lower in zip(point_rows, point_rows[1:])
    ]
    if not row_list:
        return [], False

    # Every row was built from equally long point rows, so the grid is regular.
    standard_flag = True
    print('standard_flag', standard_flag)
    return row_list, standard_flag
def get_points_by_line(img, row_lines, col_lines):
    """Return the pixel positions where row lines and column lines cross.

    Row lines and column lines are each drawn, 1 pixel wide, onto their own
    blank ``uint8`` canvas shaped like ``img[:, :, 0]``; the crossings are the
    non-zero pixels of the bitwise AND of the two canvases.

    :param img: 3-dimensional image array; only its first two dimensions
        are used, to size the canvases
    :param row_lines: lines given as ``(start_point, end_point)`` pairs
    :param col_lines: lines given as ``(start_point, end_point)`` pairs
    :return: list of ``(x, y)`` tuples sorted by x, then y
    """

    def rasterize(lines):
        canvas = np.zeros_like(img[:, :, 0], dtype=np.uint8)
        for line in lines:
            cv2.line(canvas, line[0], line[1], (255, 255, 255), 1)
        return canvas

    row_img = rasterize(row_lines)
    col_img = rasterize(col_lines)
    # Pick out the white crossing pixels and read off their coordinates.
    ys, xs = np.where(np.bitwise_and(row_img, col_img) > 0)
    return sorted(zip(xs, ys), key=lambda pt: (pt[0], pt[1]))
def merge_text_and_table(text_bbox_list, table_row_list):
    """Distribute text boxes over table cells by box centre.

    Every row of cells is sorted in place by ``cell[0][0]``. A text box is
    placed in the first cell (row by row, left to right) whose two corner
    points ``cell[0]`` / ``cell[1]`` enclose the box centre; a box equal to
    one already placed is never placed again.

    :param text_bbox_list: boxes indexed as ``box[0][*]`` and ``box[2][*]``
    :param table_row_list: rows of cells, each cell ``[corner, corner]``
    :return: rows of cells, each cell being the list of text boxes it holds
    """
    assigned = []

    def collect(cell):
        hits = []
        for text_box in text_bbox_list:
            if text_box in assigned:
                continue
            center_y = (text_box[0][1] + text_box[2][1]) / 2
            center_x = (text_box[0][0] + text_box[2][0]) / 2
            if cell[0][1] <= center_y <= cell[1][1] and cell[0][0] <= center_x <= cell[1][0]:
                hits.append(text_box)
                assigned.append(text_box)
        return hits

    merged_rows = []
    for cells in table_row_list:
        cells.sort(key=lambda cell: cell[0][0])
        merged_rows.append([collect(cell) for cell in cells])
    return merged_rows
def shrink_bbox(img, bbox_list):
    """Tighten each box around the dark pixels it contains.

    A box is cropped from ``img`` using ``bbox[0]`` and ``bbox[2]`` as
    opposite corners. Pixels accepted by ``cv2.inRange`` between
    ``[0, 0, 0]`` and ``[150, 150, 150]`` count as dark. The new box spans
    from the first to the last dark row and column. Boxes with an empty crop
    or without dark pixels are returned unchanged.

    :param img: 3-dimensional image array
    :param bbox_list: boxes indexed as ``bbox[0][0|1]`` and ``bbox[2][0|1]``
    :return: list of boxes, shrunk ones given as four ``[x, y]`` corners
    """

    def dark_extent(region):
        # First and last index along axis 0 holding a dark pixel, or None.
        mask = cv2.inRange(region, np.array([0, 0, 0]), np.array([150, 150, 150]))
        hits = np.where(mask != 0)
        if hits[0].size == 0 or hits[1].size == 0:
            return None
        return hits[0][0], hits[0][-1]

    shrunk = []
    for bbox in bbox_list:
        crop = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
        # Scan rows first, then columns through the axis-swapped crop.
        vertical = None if 0 in crop.shape else dark_extent(crop)
        horizontal = None if vertical is None else dark_extent(np.swapaxes(crop, 0, 1))
        if horizontal is None:
            shrunk.append(bbox)
            continue

        left = bbox[0][0] + horizontal[0]
        right = bbox[0][0] + horizontal[1]
        top = bbox[0][1] + vertical[0]
        bottom = bbox[0][1] + vertical[1]
        shrunk.append([[left, top], [left, bottom], [right, bottom], [right, top]])
    return shrunk
def affinity_propagation(data_list):
    """Cluster the samples with affinity propagation.

    The model is first fitted with ``damping=0.5, convergence_iter=15``. As
    long as the first label comes back as ``-1`` the fit is retried with the
    next setting in a fixed ladder; if the last attempt still yields ``-1``
    every sample is put into cluster 0.

    :param data_list: samples, converted with ``np.array``
    :return: list with one cluster label per sample
    """
    data_np = np.array(data_list)
    random_state = 170
    labels = AffinityPropagation(damping=0.5, convergence_iter=15, random_state=random_state).fit(data_np).labels_

    # (message printed before the retry, extra estimator arguments)
    retries = (
        ('ap dp0.5 ci50', {'convergence_iter': 50}),
        ('ap dp0.7 ci15', {'damping': 0.7, 'convergence_iter': 15}),
        ('ap dp0.7 ci50', {'damping': 0.7, 'convergence_iter': 50}),
    )
    for message, params in retries:
        if labels[0] != -1:
            break
        print(message)
        labels = AffinityPropagation(random_state=random_state, **params).fit(data_np).labels_

    if labels[0] == -1:
        print('all -1')
        labels = np.zeros(labels.shape[0])
    return labels.tolist()
def dbscan(data_list):
    """Cluster the samples with DBSCAN (``eps=3``, ``min_samples=2``).

    :param data_list: samples, converted with ``np.array``
    :return: list with one cluster label per sample
    """
    samples = np.array(data_list)
    fitted = DBSCAN(eps=3, min_samples=2).fit(samples)
    return fitted.labels_.tolist()
def test_ocr_model(img_path):
    """Post an image file to the OCR service and return the decoded reply.

    The file is read as bytes, base64-encoded and sent with ``request_post``
    under the key ``"data"`` together with ``"md5": 0``.

    :param img_path: path of the image file to send
    :return: the service response parsed with ``json.loads``
    """
    # Alternative local endpoint: "http://127.0.0.1:17000/ocr"
    _url = "http://192.168.2.103:17000/ocr"
    with open(img_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    payload = {"data": encoded, "md5": 0}
    response = request_post(_url, payload)
    return json.loads(response)
def test_cho_model(text):
    """Send the characters of ``text`` to the "cho" service.

    The text is split into a list of its items, JSON-encoded and posted with
    ``request_post`` under the key ``"data"``.

    :param text: iterable of characters, e.g. a string
    :return: the ``"data"`` field of the reply when ``"success"`` is truthy,
        otherwise ``None`` (after printing ``failed!``)
    """
    chars = list(text)
    _url = "http://192.168.2.103:17058/cho"
    reply = json.loads(request_post(_url, {"data": json.dumps(chars)}))
    if not reply.get("success"):
        print("failed!")
        return None
    decode_list = reply.get("data")
    print("char_list", chars)
    print("decode_list", decode_list)
    return decode_list
if __name__ == "__main__":
    # Script entry point: run the table extraction.
    get_table_new()

    # Scratch snippets for manual experiments:
    # _l = [[18, 0], [0, 0], [14, 0], [0, 0], [12, 0], [0, 0], [14, 0], [2, 0], [15, 0], [0, 0]]
    # # _l = [[27, 0], [26, 0], [17, 0]]
    # print(affinity_propagation(_l))
    # print(dbscan(_l))
    # _img = cv2.imread(r'C:\Users\Administrator\Desktop\111.jpg')
    # shrink_bbox(_img, [[[0, 0], [0, 0], [_img.shape[1], _img.shape[0]], [_img.shape[1], _img.shape[0]]]])
|