#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 9 23:11:51 2020
table line detect
@author: chineseocr
"""
- import copy
- import io
- import logging
- import sys
- import traceback
- import tensorflow as tf
- import tensorflow.keras.backend as K
- from tensorflow.keras.models import Model
- from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D
- from tensorflow.keras.layers import LeakyReLU
- from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_line, draw_boxes
- import numpy as np
- import cv2
- import time
- from format_convert import _global
- from format_convert.utils import log
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Return the smoothed Dice coefficient of two tensors.

    Both inputs are flattened with the Keras backend first, so the score is
    computed over all elements at once.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.
        smooth: Constant added to the numerator and the denominator.

    Returns:
        Backend scalar ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator
def dice_coef_loss():
    """Build a loss callable that returns the negated Dice coefficient.

    Returns:
        A function ``(y_true, y_pred) -> -dice_coef(y_true, y_pred)``.
    """
    def dice_coef_loss_fixed(y_true, y_pred):
        score = dice_coef(y_true, y_pred)
        return -score

    return dice_coef_loss_fixed
def focal_loss(gamma=3., alpha=.5):
    """Build a focal-loss callable for 0/1 labelled targets.

    The returned function sums (it does not average) two terms: one over
    elements labelled 1, weighted by ``alpha``, and one over elements
    labelled 0, weighted by ``1 - alpha``. Elements whose label is neither
    0 nor 1 are replaced by fill values (1 for the positive term, 0 for the
    negative term).

    Args:
        gamma: Exponent applied to the modulating factor of each term.
        alpha: Weight of the positive term; the negative term gets ``1 - alpha``.

    Returns:
        A function ``(y_true, y_pred) -> scalar loss tensor``.
    """
    # Training notes kept from the original author (gamma, alpha, setup):
    # 3 0.85 2000e acc-0.6 p-0.99 r-0.99 val_acc-0.56 val_p-0.86 val_r-0.95
    # 2 0.85 double_gpu acc-
    # 3 0.25 gpu 50e acc-0.5 p-0.99 r-0.99 val_acc-0.45 val_p-0.96 val_r-0.88
    # 2 0.25 gpu acc-
    # 3 0.5 double_gpu acc-0.6 p-0.99 r-0.99 val_acc-0.60 val_p-0.93 val_r-0.93
    def focal_loss_fixed(y_true, y_pred):
        # Predictions at positive labels; every other position becomes 1.
        is_positive = tf.equal(y_true, 1)
        pt_1 = tf.where(is_positive, y_pred, tf.ones_like(y_pred))
        # Predictions at negative labels; every other position becomes 0.
        is_negative = tf.equal(y_true, 0)
        pt_0 = tf.where(is_negative, y_pred, tf.zeros_like(y_pred))

        positive_term = K.sum(
            alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon() + pt_1))
        negative_term = K.sum(
            (1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
        return -positive_term - negative_term

    return focal_loss_fixed
def table_net_large(input_shape=(1152, 896, 3), num_classes=1):
    """Build the deep U-Net-style segmentation model.

    The encoder has six stages (16, 32, 64, 128, 256, 512 filters), each made
    of two Conv-BatchNorm-LeakyReLU units followed by 2x2 max pooling. A
    1024-filter centre stage follows. The decoder mirrors the encoder: each
    stage upsamples by 2, concatenates the matching encoder output on the
    channel axis and applies three Conv-BatchNorm-LeakyReLU units. A 1x1
    sigmoid convolution produces the output.

    Args:
        input_shape: Shape passed to ``Input`` (height, width, channels).
        num_classes: Number of filters of the final 1x1 convolution.

    Returns:
        The assembled, uncompiled ``Model``.
    """
    inputs = Input(shape=input_shape)
    use_bias = False

    def conv_stack(tensor, filters, depth):
        # `depth` repetitions of Conv(3x3, same) -> BatchNorm -> LeakyReLU(0.1).
        for _ in range(depth):
            tensor = Conv2D(filters, (3, 3), padding='same', use_bias=use_bias)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = LeakyReLU(alpha=0.1)(tensor)
        return tensor

    # Encoder: remember each stage's pre-pooling output for the skip links.
    skips = []
    tensor = inputs
    for filters in (16, 32, 64, 128, 256, 512):
        tensor = conv_stack(tensor, filters, 2)
        skips.append(tensor)
        tensor = MaxPooling2D((2, 2), strides=(2, 2))(tensor)

    # Centre stage at the coarsest resolution.
    tensor = conv_stack(tensor, 1024, 2)

    # Decoder: walk the skip links from the deepest stage back to the first.
    for filters, skip in zip((512, 256, 128, 64, 32, 16), reversed(skips)):
        tensor = UpSampling2D((2, 2))(tensor)
        tensor = concatenate([skip, tensor], axis=3)
        tensor = conv_stack(tensor, filters, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(tensor)
    return Model(inputs=inputs, outputs=classify)
def table_net(input_shape=(1152, 896, 3), num_classes=1):
    """Build the shallow U-Net-style segmentation model.

    Four encoder stages (16, 32, 64, 128 filters) of two
    Conv-BatchNorm-LeakyReLU units each are followed by three decoder stages
    (64, 32, 16 filters) that upsample by 2, concatenate the matching encoder
    output on the channel axis and apply three Conv-BatchNorm-LeakyReLU
    units. A 1x1 sigmoid convolution produces the output.

    Args:
        input_shape: Shape passed to ``Input`` (height, width, channels).
        num_classes: Number of filters of the final 1x1 convolution.

    Returns:
        The assembled, uncompiled ``Model``.
    """
    inputs = Input(shape=input_shape)
    use_bias = False

    def conv_stack(tensor, filters, depth):
        # `depth` repetitions of Conv(3x3, same) -> BatchNorm -> LeakyReLU(0.1).
        for _ in range(depth):
            tensor = Conv2D(filters, (3, 3), padding='same', use_bias=use_bias)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = LeakyReLU(alpha=0.1)(tensor)
        return tensor

    def pool(tensor):
        return MaxPooling2D((2, 2), strides=(2, 2))(tensor)

    # Encoder.
    down0a = conv_stack(inputs, 16, 2)
    down0 = conv_stack(pool(down0a), 32, 2)
    down1 = conv_stack(pool(down0), 64, 2)
    down2 = conv_stack(pool(down1), 128, 2)
    # The original also pools `down2` although nothing consumes the result;
    # the call is kept so the same set of layers is instantiated.
    pool(down2)

    # Decoder: (filters, encoder output to concatenate with).
    tensor = down2
    for filters, skip in ((64, down1), (32, down0), (16, down0a)):
        tensor = UpSampling2D((2, 2))(tensor)
        tensor = concatenate([skip, tensor], axis=3)
        tensor = conv_stack(tensor, filters, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(tensor)
    return Model(inputs=inputs, outputs=classify)
# Module-level network built at import time: unconstrained height/width,
# 3 input channels, 2 output channels.
model = table_net(input_shape=(None, None, 3), num_classes=2)
def draw_pixel(pred, prob=0.2, is_test=1):
    """Display a colour-coded view of a two-channel prediction map.

    Each pixel becomes (0, 0, 255) when channel 0 exceeds ``prob``,
    otherwise (255, 0, 0) when channel 1 exceeds ``prob``, otherwise
    (255, 255, 255). The image is shown with matplotlib.

    Args:
        pred: Nested sequence indexed as ``pred[row][col][channel]``.
        prob: Threshold a channel value must exceed to be coloured.
        is_test: When falsy the function returns immediately without
            importing matplotlib or drawing anything.

    Returns:
        None.
    """
    if not is_test:
        return

    # Imported lazily so matplotlib is only needed when actually drawing.
    import matplotlib.pyplot as plt

    def colour_of(pixel):
        if pixel[0] > prob:
            return (0, 0, 255)
        if pixel[1] > prob:
            return (255, 0, 0)
        return (255, 255, 255)

    canvas = [
        [colour_of(pred[row][col]) for col in range(len(pred[row]))]
        for row in range(len(pred))
    ]
    plt.axis('off')
    plt.imshow(np.array(canvas))
    plt.show()
    return
def expansionAndShrinkage(pred,width=3):
    """Dilate channels 0 and 1 of a prediction map with a cross-shaped window.

    For each of the two channels, the output at a pixel is the maximum of the
    channel over the pixel itself and the pixels up to ``width`` steps away
    in the same column or the same row. Positions outside the map count as 0.

    Unlike the previous implementation, a ``width`` larger than a map
    dimension no longer raises: a shift that moves the whole map out of view
    simply contributes an all-zero layer.

    Args:
        pred: Array-like indexed as (row, col, channel); only channels 0 and
            1 are read.
        width: Maximum shift, in pixels, applied in each of the four
            directions.

    Returns:
        ``numpy.ndarray`` of shape (rows, cols, 2) holding the dilated
        channel 0 and channel 1.
    """
    def shifted_copies(channel):
        # Return [channel, down 1..width, up 1..width, right 1..width,
        # left 1..width], each zero-filled where the shift exposes new cells.
        copies = [channel]
        for axis, forward in ((0, True), (0, False), (1, True), (1, False)):
            size = channel.shape[axis]
            for step in range(1, width + 1):
                # Clamp so over-long shifts yield a correctly sized zero layer.
                shift = min(step, size)
                src = [slice(None), slice(None)]
                dst = [slice(None), slice(None)]
                if forward:
                    src[axis] = slice(0, size - shift)
                    dst[axis] = slice(shift, size)
                else:
                    src[axis] = slice(shift, size)
                    dst[axis] = slice(0, size - shift)
                moved = np.zeros_like(channel)
                moved[tuple(dst)] = channel[tuple(src)]
                copies.append(moved)
        return copies

    pred_array = np.array(pred)
    print("pred_array=====",pred_array.shape)

    h_copies = shifted_copies(pred_array[...,0])
    for _copy in h_copies:
        print(_copy.shape)
    h_array = np.max(np.stack(h_copies,axis=0),axis=0,keepdims=False)

    v_stack = np.stack(shifted_copies(pred_array[...,1]),axis=0)
    print("v_array=====",v_stack.shape)
    v_array = np.max(v_stack,axis=0,keepdims=False)

    print("h_array=====",h_array.shape)
    print("v_array=====",v_array.shape)
    last_array = np.stack([h_array,v_array],axis=-1)
    print("pred_array=====",last_array.shape)
    return last_array
def getIOU(bbox0, bbox1):
    """Return the overlap of two boxes relative to the smaller box's area.

    Despite the name this is not intersection-over-union: the overlap area
    is divided by the smaller of the two box areas.

    Args:
        bbox0: Box as ``[x0, y0, x1, y1]``.
        bbox1: Box as ``[x0, y0, x1, y1]``.

    Returns:
        ``overlap_area / min(area0, area1)`` when the boxes overlap on both
        axes, otherwise ``0`` (boxes that merely touch count as
        non-overlapping).
    """
    def axis_excess(lo, hi):
        # Combined extent minus the sum of both extents along one axis;
        # negative exactly when the two intervals overlap.
        span = abs(max(bbox0[hi], bbox1[hi]) - min(bbox0[lo], bbox1[lo]))
        return span - (abs(bbox0[hi] - bbox0[lo]) + abs(bbox1[hi] - bbox1[lo]))

    width = axis_excess(0, 2)
    height = axis_excess(1, 3)
    if not (width < 0 and height < 0):
        return 0

    area0 = abs((bbox0[2] - bbox0[0]) * (bbox0[3] - bbox0[1]))
    area1 = abs((bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))
    return abs(width * height / min(area0, area1))
def lines_cluster(list_lines,line_width):
    """Merge line boxes that overlap once padded by half the line width.

    Lines are visited in order; each one is compared against the clusters
    collected so far, newest first, and absorbed into the first cluster whose
    padded box overlaps its own padded box. Passes repeat until the number of
    clusters stops changing between consecutive passes.

    Note: the ``"bbox"`` entry of the dicts taken from ``list_lines`` is
    overwritten in place when a cluster grows.

    Args:
        list_lines: List of dicts, each with a ``"bbox"`` entry
            ``[x0, y0, x1, y1]``.
        line_width: Line thickness; half of it (integer division) is used as
            padding on every side before testing overlap.

    Returns:
        List of the surviving dicts with their merged ``"bbox"``.
    """
    log("len lines %d"%len(list_lines))
    margin = line_width//2

    def inflate(box):
        # Grow the box by `margin` on each side, clamping the low corner at 0.
        return [max(0,box[0]-margin),max(0,box[1]-margin),box[2]+margin,box[3]+margin]

    previous_count = 0
    while True:
        clusters = []
        for candidate in list_lines:
            cand_box = candidate["bbox"]
            for cluster in reversed(clusters):
                seen_box = cluster["bbox"]
                if getIOU(inflate(cand_box),inflate(seen_box))>0:
                    # Replace the cluster box by the hull of both boxes.
                    cluster["bbox"] = [
                        min(cand_box[0],cand_box[2],seen_box[0],seen_box[2]),
                        min(cand_box[1],cand_box[3],seen_box[1],seen_box[3]),
                        max(cand_box[0],cand_box[2],seen_box[0],seen_box[2]),
                        max(cand_box[1],cand_box[3],seen_box[1],seen_box[3]),
                    ]
                    break
            else:
                # No existing cluster overlaps: this line starts a new one.
                clusters.append(candidate)

        current_count = len(clusters)
        if current_count==previous_count:
            break
        previous_count = current_count
        list_lines = clusters
    return clusters
def points2lines(pred,sourceP_LB=True, prob=0.2, line_width=8, padding=3, min_len=10,
                 cell_width=13):
    """Turn a two-channel probability map into merged line segments.

    Channel 0 is scanned row by row for horizontal runs and channel 1 column
    by column for vertical runs, sampling with a stride derived from
    ``line_width``. The raw runs are widened slightly, merged with
    ``lines_cluster`` and returned as plain coordinate lists.

    Args:
        pred: Array indexed as ``pred[row][col] -> (h_prob, v_prob)``; it must
            also support ``pred[..., k]`` slicing.
        sourceP_LB: When true, the row coordinate stored for a line is
            ``height - 1 - row`` instead of ``row``.
            NOTE(review): presumably a switch to a bottom-left origin - confirm.
        prob: Threshold a probability must exceed to count as "on a line".
        line_width: Sampling stride of the scans and the width passed on to
            ``lines_cluster``.
        padding: Not used in this function.
        min_len: Rows/columns with fewer than this many above-threshold
            pixels are skipped entirely.
        cell_width: Not used in this function.

    Returns:
        List of ``[x0, y0, x1, y1]`` lists: all merged horizontal lines first,
        then all merged vertical lines. Coordinates may be floats because
        midpoints are computed with ``/``.

    NOTE(review): the original indentation of this function was lost in the
    source; the nesting of the two "step back" statements marked below is a
    reconstruction and should be checked against the upstream file.
    """
    _time = time.time()
    log("starting points2lines")
    height = len(pred)
    width = len(pred[0])
    # Per-row count of channel-0 pixels above the threshold.
    _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1))
    h_index = -1
    h_lines = []
    v_lines = []
    _step = line_width
    # ---- Pass 1: horizontal runs, one image row at a time. ----
    while 1:
        h_index += 1
        if h_index>=height:
            break
        w_index = -1
        if sourceP_LB:
            h_i = height-1-h_index
        else:
            h_i = h_index
        _start = None
        # NOTE(review): the row filter uses `h_index` while the pixel lookup
        # below uses `h_i`; with sourceP_LB=True these are different rows.
        if _sum[h_index]<min_len:
            continue
        while 1:
            w_index += _step
            if w_index>=width:
                # A run still open at the right edge is dropped here.
                break
            _h,_v = pred[h_i][w_index]
            if _h>prob:
                if _start is None:
                    _start = w_index
            else:
                if _start is not None:
                    # Close the current run just before this sample.
                    _end = w_index-1
                    _bbox = [_start,h_i,_end,h_i]
                    _dict = {"bbox":_bbox}
                    h_lines.append(_dict)
                    _start = None
                # NOTE(review): nesting reconstructed - assumed to run for
                # every off-line sample, giving a finer stride between runs.
                w_index -= _step//2
    log("starting points2lines 1")
    # ---- Pass 2: vertical runs, one image column at a time. ----
    w_index = -1
    # Per-column count of channel-1 pixels above the threshold.
    _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=0))
    _step = line_width
    while 1:
        w_index += 1
        if w_index>=width:
            break
        if _sum[w_index]<min_len:
            continue
        h_index = -1
        _start = None
        while 1:
            h_index += _step
            if h_index>=height:
                # A run still open at the bottom edge is dropped here.
                break
            if sourceP_LB:
                h_i = height-1-h_index
            else:
                h_i = h_index
            # Unlike pass 1, the lookup uses the unflipped row `h_index`;
            # `h_i` is only used for the stored coordinates.
            _h,_v = pred[h_index][w_index]
            if _v>prob:
                if _start is None:
                    _start = h_i
            else:
                if _start is not None:
                    # Close the run at the previously sampled row.
                    _end = last_h
                    _bbox = [w_index,_start,w_index,_end]
                    _dict = {"bbox":_bbox}
                    v_lines.append(_dict)
                    _start = None
                # NOTE(review): nesting reconstructed - see pass 1.
                h_index -= _step//2
            # Remember this sample's row coordinate for closing a run later.
            last_h = h_i
    log("starting points2lines 2")
    # Collapse each run onto its centre line and extend it by 2 px per end
    # (the low end is clamped at 0) before clustering.
    for _line in h_lines:
        _bbox = _line["bbox"]
        _bbox = [max(_bbox[0]-2,0),(_bbox[1]+_bbox[3])/2,_bbox[2]+2,(_bbox[1]+_bbox[3])/2]
        _line["bbox"] = _bbox
    for _line in v_lines:
        _bbox = _line["bbox"]
        _bbox = [(_bbox[0]+_bbox[2])/2,max(_bbox[1]-2,0),(_bbox[0]+_bbox[2])/2,_bbox[3]+2]
        _line["bbox"] = _bbox
    h_lines = lines_cluster(h_lines,line_width=line_width)
    v_lines = lines_cluster(v_lines,line_width=line_width)
    # Flatten the merged boxes to centre lines again, extended by 1 px per end.
    list_line = []
    for _line in h_lines:
        _bbox = _line["bbox"]
        _bbox = [max(_bbox[0]-1,0),(_bbox[1]+_bbox[3])/2,_bbox[2]+1,(_bbox[1]+_bbox[3])/2]
        list_line.append(_bbox)
    for _line in v_lines:
        _bbox = _line["bbox"]
        _bbox = [(_bbox[0]+_bbox[2])/2,max(_bbox[1]-1,0),(_bbox[0]+_bbox[2])/2,_bbox[3]+1]
        list_line.append(_bbox)
    log("points2lines cost %.2fs"%(time.time()-_time))
    # Disabled debug plot of the resulting lines, kept from the original:
    # import matplotlib.pyplot as plt
    # plt.figure()
    # for _line in list_line:
    # x0,y0,x1,y1 = _line
    # plt.plot([x0,x1],[y0,y1])
    # for _line in list_line:
    # x0,y0,x1,y1 = _line.bbox
    # plt.plot([x0,x1],[y0,y1])
    # for point in list_crosspoints:
    # plt.scatter(point.get("point")[0],point.get("point")[1])
    # plt.show()
    return list_line
def points2lines_bak(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_len=10,
                     cell_width=13):
    """Convert a per-pixel line-probability map into axis-aligned segments.

    Legacy ("_bak") variant of the points-to-lines conversion.

    ``pred`` is indexed as ``pred[row][col]`` and every entry unpacks to
    ``(horizontal_prob, vertical_prob)``; it must also support ``pred[..., k]``
    (e.g. a NumPy array).  Only every second row and column (indices 1, 3, 5,
    ...) is sampled.

    :param pred: prediction map as described above
    :param sourceP_LB: if true, y is flipped to ``height - 1 - row``
    :param prob: probability threshold above which a pixel counts as a line pixel
    :param line_width: tolerance used when growing a cluster's bounding box and
        when matching the position of two lines during merging
    :param padding: amount each emitted line is extended at both ends
    :param min_len: minimum number of line pixels a column/row must contain to
        be scanned, and minimum bounding-box extent for a cluster to be kept
    :param cell_width: maximum gap between the ends of two lines to merge them
    :return: list of ``[x0, y0, x1, y1]``; vertical lines first, then horizontal
    """
    def inBbox(bbox, point, line_width):
        # Returns (hit, grown_bbox): hit when the point lies within the bbox
        # enlarged by line_width on every side.
        x, y = point
        if (bbox[0] - line_width <= x <= bbox[2] + line_width
                and bbox[1] - line_width <= y <= bbox[3] + line_width):
            return True, [min(x, bbox[0]), min(y, bbox[1]), max(x, bbox[2]), max(y, bbox[3])]
        return False, None

    def add_point(clusters, point):
        # Clusters are searched newest-first; the first match absorbs the point.
        for cluster in reversed(clusters):
            hit, grown = inBbox(cluster["bbox"], point, line_width)
            if hit:
                cluster["points"].append(point)
                cluster["bbox"] = grown
                return
        x, y = point
        clusters.append({"points": [point], "bbox": [x, y, x, y]})

    _time = time.time()
    height = len(pred)
    width = len(pred[0])
    clust_horizontal = []
    clust_vertical = []

    # Vertical pixels: scan row by row, skipping columns with too few hits.
    col_hits = np.sum(pred[..., 1] > prob, axis=0)
    for h_index in range(1, height, 2):
        h_i = height - 1 - h_index if sourceP_LB else h_index
        for w_index in range(1, width, 2):
            if col_hits[w_index] < min_len:
                continue
            _h, _v = pred[h_index][w_index]
            if _v > prob:
                add_point(clust_vertical, (w_index, h_i))

    # Horizontal pixels: scan column by column, skipping rows with too few hits.
    row_hits = np.sum(pred[..., 0] > prob, axis=1)
    for w_index in range(1, width, 2):
        for h_index in range(1, height, 2):
            if row_hits[h_index] < min_len:
                continue
            h_i = height - 1 - h_index if sourceP_LB else h_index
            _h, _v = pred[h_index][w_index]
            if _h > prob:
                add_point(clust_horizontal, (w_index, h_i))

    # Collapse each sufficiently large cluster to its centre line, padded.
    tmp_vertical = []
    for cluster in clust_vertical:
        bx0, by0, bx1, by1 = cluster["bbox"]
        if bx1 - bx0 >= min_len or by1 - by0 >= min_len:
            tmp_vertical.append([(bx0 + bx1) / 2, by0 - padding, (bx0 + bx1) / 2, by1 + padding])
    tmp_horizontal = []
    for cluster in clust_horizontal:
        bx0, by0, bx1, by1 = cluster["bbox"]
        if bx1 - bx0 >= min_len or by1 - by0 >= min_len:
            tmp_horizontal.append([bx0 - padding, (by0 + by1) / 2, bx1 + padding, (by0 + by1) / 2])

    # Merge lines
    tmp_vertical.sort(key=lambda x: x[3], reverse=True)
    tmp_horizontal.sort(key=lambda x: x[0])

    final_vertical = []
    for x0, y0, x1, y1 in tmp_vertical:
        for idx, (x2, y2, x3, y3) in enumerate(final_vertical):
            # The position test must hold together with either end-gap test.
            # The previous un-parenthesised "a and b or c" let the second gap
            # test merge lines regardless of their x distance.
            if abs(x0 - x2) < line_width and (abs(y0 - y3) < cell_width or abs(y1 - y2) < cell_width):
                # Replace the matched line; appending would keep the stale one.
                final_vertical[idx] = [x0, min(y0, y2), x1, max(y1, y3)]
                break
        else:
            final_vertical.append([x0, y0, x1, y1])

    final_horizontal = []
    for x0, y0, x1, y1 in tmp_horizontal:
        for idx, (x2, y2, x3, y3) in enumerate(final_horizontal):
            if abs(y0 - y2) < line_width and (abs(x0 - x3) < cell_width or abs(x1 - x2) < cell_width):
                final_horizontal[idx] = [min(x0, x2), y0, max(x1, x3), y1]
                break
        else:
            final_horizontal.append([x0, y0, x1, y1])

    list_line = final_vertical + final_horizontal
    log("points2lines cost %.2fs"%(time.time()-_time))
    return list_line
def get_line_from_binary_image(image_np, point_value=1, axis=0):
    """Turn runs of set pixels in a binary image into two-point lines.

    Every pixel whose value is ``>= point_value`` is collected, and each
    maximal run of consecutive pixels along the chosen direction becomes one
    line.  Requires a 2-D (binarised) image; only horizontal or vertical lines
    are produced.  A run of a single pixel yields a zero-length line.

    :param image_np: 2-D numpy image
    :param point_value: threshold value a pixel must reach to be considered set
    :param axis: 0 to extract horizontal runs (pixels grouped by y, consecutive
        x); 1 to extract vertical runs (grouped by x, consecutive y)
    :return: list of ``[x0, y0, x1, y1]`` ordered by the grouping coordinate and
        then by position along the run; empty list when no pixel is set
    """
    fixed = 1 - axis  # coordinate that stays constant along a run

    # Coordinates of the pixels whose value reaches point_value
    ys, xs = np.where(image_np >= point_value)
    points = sorted(zip(xs.tolist(), ys.tolist()), key=lambda p: (p[fixed], p[axis]))

    lines = []
    run_start = run_end = None
    for p in points:
        continues_run = (
            run_end is not None
            and p[fixed] == run_end[fixed]
            and p[axis] - run_end[axis] == 1
        )
        if continues_run:
            run_end = p
            continue
        # A new row/column or a gap: close the run in progress (if any) and
        # start the next one at this very point so that no pixel is lost.
        if run_start is not None:
            lines.append([run_start[0], run_start[1], run_end[0], run_end[1]])
        run_start = run_end = p
    if run_start is not None:
        lines.append([run_start[0], run_start[1], run_end[0], run_end[1]])
    return lines
def table_preprocess(img_data, prob=0.2):
    """Decode encoded image bytes and wrap the image as a one-item batch.

    :param img_data: encoded image bytes (anything ``cv2.imdecode`` accepts)
    :param prob: only written to the log here
    :return: ``(image_np, inputs)`` where ``inputs`` is ``np.array([image_np])``;
        ``([-1], [-1])`` if any step raises
    """
    try:
        log("into table_preprocess, prob is " + str(prob))
        began = time.time()

        # Raw byte stream -> uint8 buffer -> decoded 3-channel image
        buffer = np.frombuffer(img_data, np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        # Swap the first and third channels
        image_np = cv2.cvtColor(decoded, cv2.COLOR_RGB2BGR)

        # Model input: a batch holding the single image
        inputs = np.array([image_np])

        elapsed = round(float(time.time()-began), 4)
        log("otr preprocess time: " + str(elapsed) + "s")
    except Exception:
        log("table pre process failed!")
        traceback.print_exc()
        return [-1], [-1]
    return image_np, inputs
- def table_postprocess(img_new, pred, prob=0.2, is_test=0):
- # Turn a model prediction into the final list of table lines.
- #
- # img_new: the image the prediction belongs to; img_new.shape is read and the
- #          image is handed on to the split/outline helpers.
- # pred:    prediction passed to points2lines together with prob.
- # is_test: forwarded to mat_plot / cv_plot, which return immediately when it is 0.
- #
- # Stages, in order: points2lines -> delete_short_lines -> split into rows and
- # columns -> merge_line -> get_points -> get_split_line -> fix_outline ->
- # fix_corner -> fix_inner -> merge_line.
- #
- # Returns list_rows + list_cols, [] when get_points finds no intersection,
- # or [-1] when any stage raises (the traceback is printed).
- # NOTE(review): indentation is not preserved in this copy of the file, so the
- # exact extent of the "if new_rows or new_cols" block cannot be read off here.
- try:
- # Horizontal-line prediction result
- # row_pred = pred[..., 0] > hprob
- # row_pred = row_pred.astype(np.uint8)
- # # Vertical-line prediction result
- # col_pred = pred[..., 1] > vprob
- # col_pred = col_pred.astype(np.uint8)
- # # Show the model output
- # cv2.imshow("predict", (col_pred+row_pred)*255)
- # cv2.waitKey(0)
- start_time = time.time()
- list_line = points2lines(pred, False, prob=prob)
- mat_plot(list_line, "points2lines", is_test)
- log("points2lines " + str(time.time()-start_time))
- # Remove short lines
- # print(img_new.shape)
- start_time = time.time()
- list_line = delete_short_lines(list_line, img_new.shape)
- mat_plot(list_line, "delete_short_lines", is_test)
- log("delete_short_lines " + str(time.time()-start_time))
- # # Remove lines without intersections -- not needed, it would affect later results
- # start_time = time.time()
- # list_line = delete_no_cross_lines(list_line)
- # mat_plot(list_line, "delete_no_cross_lines", is_test)
- # log("delete_no_cross_lines " + str(time.time()-start_time))
- # Split into horizontal and vertical lines
- # (equal x ends -> column, else equal y ends -> row; anything else is dropped)
- start_time = time.time()
- list_rows = []
- list_cols = []
- for line in list_line:
- if line[0] == line[2]:
- list_cols.append(line)
- elif line[1] == line[3]:
- list_rows.append(line)
- log("divide rows and cols " + str(time.time()-start_time))
- # Merge staggered lines
- start_time = time.time()
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- log("merge_line " + str(time.time()-start_time))
- # Compute intersection points and split lines
- start_time = time.time()
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- if not cross_points:
- return []
- log("get_points " + str(time.time()-start_time))
- # Remove useless outer lines
- # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
- # mat_plot(list_rows+list_cols, "delete_outline", is_test)
- # Split lines between multiple tables
- start_time = time.time()
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- log("get_split_line " + str(time.time()-start_time))
- # Repair the outer border
- start_time = time.time()
- new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
- split_y)
- # If any lines were added
- if new_rows or new_cols:
- # Existing lines extended to reach the added lines
- if long_rows:
- list_rows = long_rows
- if long_cols:
- list_cols = long_cols
- # The newly added lines
- if new_rows:
- list_rows += new_rows
- if new_cols:
- list_cols += new_cols
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- # Recompute intersection points and split lines after the border repair
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- cv_plot(cross_points, img_new.shape, 0, is_test)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- print("fix new split_y", split_y)
- print("fix new split_lines", split_lines)
- # Repair missing inner lines
- # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
- # if not cross_points:
- # return []
- mat_plot(list_rows+list_cols, "fix_outline", is_test)
- # Flatten each [(x0, y0), (x1, y1)] split line to [x0, y0, x1, y1] for plotting
- split_lines_show = []
- for _l in split_lines:
- split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
- mat_plot(split_lines_show+list_cols,
- "split_lines", is_test)
- log("fix_outline " + str(time.time()-start_time))
- # Repair the four corners of the table
- start_time = time.time()
- list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
- mat_plot(list_rows+list_cols, "fix_corner", is_test)
- log("fix_corner " + str(time.time()-start_time))
- # Repair missing inner lines
- start_time = time.time()
- list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
- mat_plot(list_rows+list_cols, "fix_inner", is_test)
- log("fix_inner " + str(time.time()-start_time))
- # Merge staggered lines
- start_time = time.time()
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- log("merge_line " + str(time.time()-start_time))
- list_line = list_rows + list_cols
- # Plot the processed lines
- mat_plot(list_line, "all", is_test)
- # NOTE(review): start_time was reset for the last merge stage, so this
- # figure is not the total post-processing time.
- log("otr postprocess table_line " + str(time.time()-start_time))
- return list_line
- except Exception as e:
- log("table post process failed!")
- traceback.print_exc()
- return [-1]
- def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
- # Resize img to `size` (width, height), run model.predict on it and convert
- # the first prediction into the final list of table lines.
- #
- # The stages after the prediction are the same sequence table_postprocess
- # runs, but there is no try/except here: exceptions propagate to the caller.
- #
- # is_test is forwarded to draw_pixel / mat_plot / cv_plot.
- # Returns list_rows + list_cols, or [] when get_points finds no intersection.
- # NOTE(review): indentation is not preserved in this copy of the file, so the
- # exact extent of the "if new_rows or new_cols" block cannot be read off here.
- log("into table_line, prob is " + str(prob))
- sizew, sizeh = size
- img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
- start_time = time.time()
- pred = model.predict(np.array([img_new]))
- log("otr model predict time " + str(time.time()-start_time))
- # Only the first item of the predicted batch is used
- pred = pred[0]
- draw_pixel(pred, prob, is_test)
- # Horizontal-line prediction result
- # row_pred = pred[..., 0] > hprob
- # row_pred = row_pred.astype(np.uint8)
- # # Vertical-line prediction result
- # col_pred = pred[..., 1] > vprob
- # col_pred = col_pred.astype(np.uint8)
- # # Show the model output
- # cv2.imshow("predict", (col_pred+row_pred)*255)
- # cv2.waitKey(0)
- start_time = time.time()
- list_line = points2lines(pred, False, prob=prob)
- mat_plot(list_line, "points2lines", is_test)
- log("points2lines " + str(time.time()-start_time))
- # Remove short lines
- # print(img_new.shape)
- start_time = time.time()
- list_line = delete_short_lines(list_line, img_new.shape)
- mat_plot(list_line, "delete_short_lines", is_test)
- log("delete_short_lines " + str(time.time()-start_time))
- # # Remove lines without intersections -- not needed, it would affect later results
- # start_time = time.time()
- # list_line = delete_no_cross_lines(list_line)
- # mat_plot(list_line, "delete_no_cross_lines", is_test)
- # log("delete_no_cross_lines " + str(time.time()-start_time))
- # Split into horizontal and vertical lines
- # (equal x ends -> column, else equal y ends -> row; anything else is dropped)
- start_time = time.time()
- list_rows = []
- list_cols = []
- for line in list_line:
- if line[0] == line[2]:
- list_cols.append(line)
- elif line[1] == line[3]:
- list_rows.append(line)
- log("divide rows and cols " + str(time.time()-start_time))
- # Merge staggered lines
- start_time = time.time()
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- log("merge_line " + str(time.time()-start_time))
- # Compute intersection points and split lines
- start_time = time.time()
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- if not cross_points:
- return []
- log("get_points " + str(time.time()-start_time))
- # Remove useless outer lines
- # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
- # mat_plot(list_rows+list_cols, "delete_outline", is_test)
- # Split lines between multiple tables
- start_time = time.time()
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- log("get_split_line " + str(time.time()-start_time))
- # Repair the outer border
- start_time = time.time()
- new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
- split_y)
- # If any lines were added
- if new_rows or new_cols:
- # Existing lines extended to reach the added lines
- if long_rows:
- list_rows = long_rows
- if long_cols:
- list_cols = long_cols
- # The newly added lines
- if new_rows:
- list_rows += new_rows
- if new_cols:
- list_cols += new_cols
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- # Recompute intersection points and split lines after the border repair
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- cv_plot(cross_points, img_new.shape, 0, is_test)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- print("fix new split_y", split_y)
- print("fix new split_lines", split_lines)
- # Repair missing inner lines
- # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
- # if not cross_points:
- # return []
- mat_plot(list_rows+list_cols, "fix_outline", is_test)
- # Flatten each [(x0, y0), (x1, y1)] split line to [x0, y0, x1, y1] for plotting
- split_lines_show = []
- for _l in split_lines:
- split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
- mat_plot(split_lines_show+list_cols,
- "split_lines", is_test)
- log("fix_outline " + str(time.time()-start_time))
- # Repair the four corners of the table
- start_time = time.time()
- list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
- mat_plot(list_rows+list_cols, "fix_corner", is_test)
- log("fix_corner " + str(time.time()-start_time))
- # Repair missing inner lines
- start_time = time.time()
- list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
- mat_plot(list_rows+list_cols, "fix_inner", is_test)
- log("fix_inner " + str(time.time()-start_time))
- # Merge staggered lines
- start_time = time.time()
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- log("merge_line " + str(time.time()-start_time))
- list_line = list_rows + list_cols
- # Plot the processed lines
- mat_plot(list_line, "all", is_test)
- # NOTE(review): start_time was reset for the last merge stage, so this
- # figure is not the total post-processing time.
- log("otr postprocess table_line " + str(time.time()-start_time))
- return list_line
def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
    """Predict table lines and return row boxes, column boxes and the resized image.

    :param img: image handed to ``cv2.resize``
    :param model: object whose ``predict`` is called with a one-image batch
    :param size: ``(width, height)`` the image is resized to
    :param hprob: threshold applied to channel 0 of the prediction (row boxes)
    :param vprob: threshold applied to channel 1 of the prediction (column boxes)
    :param row: ``lineW`` passed to ``get_table_line`` for rows
    :param col: ``lineW`` passed to ``get_table_line`` for columns
    :param alph: not used by this function
    :return: ``(rowboxes, colboxes, img_new)``
    """
    sizew, sizeh = size
    img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)

    batch_pred = model.predict(np.array([img_new]))
    pred = batch_pred[0]
    draw_pixel(pred)

    # points2lines is run for its timing log only; its result is discarded.
    _time = time.time()
    points2lines(pred)
    log("points2lines takes %ds"%(time.time()-_time))

    vpred = (pred[..., 1] > vprob).astype(int)
    hpred = (pred[..., 0] > hprob).astype(int)

    colboxes = get_table_line(vpred, axis=1, lineW=col)
    rowboxes = get_table_line(hpred, axis=0, lineW=row)

    # Adjust every row/column pair against each other.
    row_total = len(rowboxes)
    col_total = len(colboxes)
    for r_idx in range(row_total):
        for c_idx in range(col_total):
            rowboxes[r_idx] = line_to_line(rowboxes[r_idx], colboxes[c_idx], 10)
            colboxes[c_idx] = line_to_line(colboxes[c_idx], rowboxes[r_idx], 10)

    threshold = 5

    def drop_border_lines(boxes, pos_idx, extent, order_idx):
        # Skip lines hugging the image border; swap the end points of the rest
        # when they are stored in descending order.
        kept = []
        for box in boxes:
            near_start = box[pos_idx]-0 <= threshold
            near_end = extent-box[pos_idx] <= threshold
            if near_start or near_end:
                continue
            if box[order_idx] > box[order_idx + 2]:
                box = [box[2], box[3], box[0], box[1]]
            kept.append(box)
        return kept

    rowboxes = drop_border_lines(rowboxes, 1, size[1], 0)
    colboxes = drop_border_lines(colboxes, 0, size[0], 1)
    return rowboxes, colboxes, img_new
def fix_in_split_lines(_rows, _cols, _img):
    """Pull lines that touch the top/bottom image edge 6 pixels inwards.

    Added lines lying right on the edge would prevent a split y from being
    found, which in turn breaks the partitioning into regions.

    Rows and columns are modified in place.

    :param _rows: horizontal lines ``[x0, y0, x1, y1]``
    :param _cols: vertical lines ``[x0, y0, x1, y1]``
    :param _img: object whose ``shape[0]`` is used as the vertical extent
    :return: ``(_rows, _cols)`` -- the same list objects that were passed in
    """
    for row_line in _rows:
        lowest = _img.shape[0] - 6
        if row_line[1] >= _img.shape[0] - 5:
            row_line[1] = row_line[3] = lowest
            print("_row", row_line)
        if row_line[1] <= 0 + 5:
            row_line[1] = row_line[3] = 6

    for col_line in _cols:
        if col_line[3] >= _img.shape[0] - 5:
            col_line[3] = _img.shape[0] - 6
        if col_line[1] <= 0 + 5:
            col_line[1] = 6

    return _rows, _cols
def mat_plot(list_line, name="", is_test=1):
    """Plot the given lines with matplotlib; does nothing when ``is_test`` is falsy.

    :param list_line: iterable of ``(x0, y0, x1, y1)``
    :param name: figure title
    :param is_test: plotting switch
    """
    if is_test:
        # Imported lazily so matplotlib is only loaded when plotting is requested.
        from matplotlib import pyplot as plt
        plt.figure()
        plt.title(name)
        for x0, y0, x1, y1 in list_line:
            plt.plot([x0, x1], [y0, y1])
        plt.show()
def cv_plot(_list, img_shape, line_or_point=1, is_test=1):
    """Draw lines or points on a white canvas and show it; no-op when ``is_test == 0``.

    :param _list: lines ``[x0, y0, x1, y1]`` when ``line_or_point`` is truthy,
        otherwise points ``(x, y)``
    :param img_shape: shape of the canvas to create
    :param line_or_point: truthy to draw lines, falsy to draw points
    :param is_test: drawing switch
    """
    if is_test == 0:
        return

    canvas = np.zeros(img_shape, np.uint8)
    canvas.fill(255)
    colour = (255, 0, 0)

    if line_or_point:
        for seg in _list:
            begin = (int(seg[0]), int(seg[1]))
            finish = (int(seg[2]), int(seg[3]))
            cv2.line(canvas, begin, finish, colour)
    else:
        for pt in _list:
            centre = (int(pt[0]), int(pt[1]))
            cv2.circle(canvas, centre, 1, colour, 2)

    # Blocks until a key is pressed.
    cv2.imshow("cv_plot", canvas)
    cv2.waitKey(0)
def delete_no_cross_lines(list_lines):
    """Keep only the lines that cross at least one other line.

    Two lines cross only when one is vertical (equal x ends) and the other
    horizontal (equal y ends) and each one's fixed coordinate lies within the
    other's span.  The span checks compare ``start <= value <= end``.

    :param list_lines: list of ``[x0, y0, x1, y1]``
    :return: new list of the crossing lines, without duplicates
    """
    def get_cross_point(l1, l2):
        # 1 if l1 and l2 form a crossing horizontal/vertical pair, else 0.
        one_each = (l1[0] == l1[2] and l2[1] == l2[3]) or (l1[1] == l1[3] and l2[0] == l2[2])
        if not one_each:
            return 0
        if l1[0] <= l2[0] <= l1[2] and l2[1] <= l1[1] <= l2[3]:
            return 1
        if l2[0] <= l1[0] <= l2[2] and l1[1] <= l2[1] <= l1[3]:
            return 1
        return 0

    kept = []
    for idx, first in enumerate(list_lines):
        crossed = False
        for second in list_lines[idx+1:]:
            if not get_cross_point(first, second):
                continue
            crossed = True
            if second not in kept:
                kept.append(second)
        if crossed and first not in kept:
            kept.append(first)
    return kept
def delete_short_lines(list_lines, image_shape, scale=100):
    """Drop lines shorter than a minimum length derived from the image shape.

    Minimum lengths are ``max(5, int(image_shape[k] / scale))``: ``k = 0`` for
    lines measured along x, ``k = 1`` for lines measured along y.
    NOTE(review): if image_shape is (height, width, ...) the two indices look
    swapped -- confirm against the callers before changing.

    :param list_lines: list of ``[x0, y0, x1, y1]``
    :param image_shape: indexable whose items 0 and 1 are read
    :param scale: divisor applied to the image extents
    :return: new list with the surviving lines, in input order
    """
    x_min_len = max(5, int(image_shape[0] / scale))
    y_min_len = max(5, int(image_shape[1] / scale))

    def long_enough(line):
        # Equal x ends: measure along y; everything else is measured along x.
        if line[0] == line[2]:
            return abs(line[3] - line[1]) >= y_min_len
        return abs(line[2] - line[0]) >= x_min_len

    return [line for line in list_lines if long_enough(line)]
def get_outline(points, image_np):
    """Draw a rectangle around all points on a copy of the image.

    The rectangle spans the minimum/maximum x and y of ``points`` enlarged by
    5 pixels on every side and is drawn in black with thickness 2.

    The x extremes used to be read from the first and last element, which is
    only right when the caller passes the points sorted by x; they are now
    computed directly.  As before, ``points`` is left sorted by ``(y, x)``.

    :param points: list of ``(x, y)``; sorted in place by ``(y, x)``
    :param image_np: image to copy and draw on
    :return: the annotated copy; an unmodified copy when ``points`` is empty
    """
    outline_img = np.copy(image_np)
    if not points:
        return outline_img

    # Plain ints: cv2 drawing functions reject some NumPy scalar types.
    x_min = int(min(p[0] for p in points))
    x_max = int(max(p[0] for p in points))
    y_min = int(min(p[1] for p in points))
    y_max = int(max(p[1] for p in points))

    # Kept for backward compatibility: callers may rely on the in-place sort.
    points.sort(key=lambda x: (x[1], x[0]))

    cv2.rectangle(outline_img, (x_min-5, y_min-5), (x_max+5, y_max+5), (0, 0, 0), 2)
    return outline_img
- def get_split_line(points, col_lines, image_np, threshold=5):
- # Find the y positions that separate vertically stacked tables.
- #
- # points:    intersection points indexed as (x, y); sorted in place by (y, x).
- # col_lines: vertical lines indexed as [x0, y0, x1, y1].
- # image_np:  only image_np.shape[0] and image_np.shape[1] are read.
- # threshold: margin used when skipping points close to the last recorded
- #            split and when placing the first/last split around the extreme y.
- #
- # Returns (split_line, split_line_y):
- #   split_line   -- [(0, y), (shape[1], y)] pairs sorted by y; always contains
- #                   y = 0 and y = shape[0] in addition to the detected splits.
- #   split_line_y -- de-duplicated, sorted split y values thinned with a
- #                   20-pixel spacing rule.
- # NOTE(review): points[0] is read unconditionally further down, so an empty
- # `points` raises IndexError -- confirm callers never pass an empty list.
- # print("get_split_line", image_np.shape)
- points.sort(key=lambda x: (x[1], x[0]))
- # Walk the y coordinates and check whether a connecting column exists between
- # each y and the previous one
- i = 0
- split_line_y = []
- for point in points:
- # Resume checking only below the split that was already recorded
- if split_line_y:
- if point[1] <= split_line_y[-1] + threshold:
- last_y = point[1]
- continue
- if last_y <= split_line_y[-1] + threshold:
- last_y = point[1]
- continue
- # The very first point only initialises last_y
- if i == 0:
- last_y = point[1]
- i += 1
- continue
- current_line = (last_y, point[1])
- split_flag = 1
- for col in col_lines:
- # As soon as one column spans the gap (with 3 px slack) it is not a split
- if current_line[0] >= col[1]-3 and current_line[1] <= col[3]+3:
- split_flag = 0
- # print("img", img.shape)
- # print("col", col)
- # print("current_line", current_line)
- break
- if split_flag:
- # Record both edges of the gap, each moved 5 px into the gap
- split_line_y.append(current_line[0]+5)
- split_line_y.append(current_line[1]-5)
- last_y = point[1]
- # Add the leading and trailing split lines
- points.sort(key=lambda x: (x[1], x[0]))
- y_min = points[0][1]
- y_max = points[-1][1]
- # print("add leading/trailing split lines", y_min, y_max)
- # Clamp the outer splits to the range [0, image_np.shape[0]]
- if y_min-threshold < 0:
- split_line_y.append(0)
- else:
- split_line_y.append(y_min-threshold)
- if y_max+threshold > image_np.shape[0]:
- split_line_y.append(image_np.shape[0])
- else:
- split_line_y.append(y_max+threshold)
- split_line_y = list(set(split_line_y))
- # Drop split lines that are too close to each other
- temp_split_line_y = []
- split_line_y.sort(key=lambda x: x)
- last_y = -20
- for y in split_line_y:
- # print(y)
- if y - last_y >= 20:
- # print(y, last_y)
- temp_split_line_y.append(y)
- last_y = y
- split_line_y = temp_split_line_y
- # print("split_line_y", split_line_y)
- # Build the split lines
- split_line = []
- last_y = 0
- for y in split_line_y:
- # if y - last_y <= 15:
- # continue
- split_line.append([(0, y), (image_np.shape[1], y)])
- last_y = y
- # The top and bottom image edges are always included
- split_line.append([(0, 0), (image_np.shape[1], 0)])
- split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])])
- split_line.sort(key=lambda x: x[0][1])
- # print("split_line", split_line)
- # Draw the lines (debug)
- # split_line_img = np.copy(image_np)
- # for y in split_line_y:
- # cv2.line(split_line_img, (0, y), (image_np.shape[1], y), (0, 0, 0), 1)
- # cv2.imshow("split_line_img", split_line_img)
- # cv2.waitKey(0)
- return split_line, split_line_y
def get_points(row_lines, col_lines, image_size):
    """Return the pixels where row lines and column lines overlap.

    Rows and columns are rasterised onto two separate blank images (each line
    lengthened by 3 pixels at both ends) and the images are AND-ed together.

    :param row_lines: horizontal lines ``[x0, y0, x1, y1]``
    :param col_lines: vertical lines ``[x0, y0, x1, y1]``
    :param image_size: shape of the scratch images
    :return: list of ``(x, y)`` tuples sorted by x, then y
    """
    thresh = 3
    white = (255, 255, 255)

    # Blank canvases, one per orientation
    row_img = np.zeros(image_size, np.uint8)
    col_img = np.zeros(image_size, np.uint8)

    for seg in row_lines:
        left = (int(seg[0]-thresh), int(seg[1]))
        right = (int(seg[2]+thresh), int(seg[3]))
        cv2.line(row_img, left, right, white, 1)
    for seg in col_lines:
        top = (int(seg[0]), int(seg[1]-thresh))
        bottom = (int(seg[2]), int(seg[3]+thresh))
        cv2.line(col_img, top, bottom, white, 1)

    # Pixels set in both images are the intersections
    overlap = np.bitwise_and(row_img, col_img)
    ys, xs = np.where(overlap > 0)

    points = list(zip(xs, ys))
    points.sort(key=lambda p: (p[0], p[1]))
    return points
def get_minAreaRect(image):
    """Return ``cv2.minAreaRect`` of the pixels that are set after thresholding.

    The image is converted to gray, inverted and binarised with Otsu's
    threshold; the coordinates of all non-zero pixels are then passed on.

    :param image: BGR image
    :return: the value returned by ``cv2.minAreaRect``
    """
    inverted = cv2.bitwise_not(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    _, binary = cv2.threshold(inverted, 0, 255,
                              cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # np.where yields (rows, cols); the columns are stacked in that order.
    nonzero = np.where(binary > 0)
    return cv2.minAreaRect(np.column_stack(nonzero))
def get_contours(image):
    """Draw the external contours of ``image`` onto it and display the result.

    The image is converted to gray and binarised at 127 before the contour
    search.  ``image`` is modified in place, and the call blocks until a key is
    pressed in the display window.

    :param image: BGR image
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
    # (image, contours, hierarchy) in 3.x; the contours are always second to last.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(image, contours, -1, (0, 0, 255), 3)
    cv2.imshow("get contours", image)
    cv2.waitKey(0)
def merge_line(lines, axis, threshold=5):
    """
    Merge lines that the model predicted as several staggered pieces of one line.

    Lines are grouped when their fixed coordinate is within ``threshold`` of the
    group's first line and their span touches the span accumulated so far; each
    group becomes one line at the truncated average fixed coordinate.
    ``lines`` is sorted in place.

    :param lines: list of lines ``[x0, y0, x1, y1]``
    :param axis: 0: horizontal lines 1: vertical lines
    :param threshold: pixel tolerance between two lines' fixed coordinates
    :return: list of merged lines
    """
    fixed = 1 - axis   # coordinate shared by the pieces of one line
    tail = axis + 2    # index of the end coordinate along the line

    # Horizontal lines look downwards for partners, vertical ones to the right
    lines.sort(key=lambda x: (x[axis], x[fixed]))

    groups = []
    consumed = []
    for seed in lines:
        if seed in consumed:
            continue
        group = [seed]
        consumed.append(seed)
        # Span covered by the group so far (10000 and 0 are the initial bounds)
        low = min(10000, seed[axis])
        high = max(0, seed[tail])
        for other in lines:
            if other in consumed:
                continue
            if not (seed[fixed]-threshold <= other[fixed] <= seed[fixed]+threshold):
                continue
            # Does either end of the candidate fall inside the group's span?
            if low <= other[axis] <= high or low <= other[tail] <= high:
                group.append(other)
                consumed.append(other)
                low = min(low, other[axis])
                high = max(high, other[tail])
        groups.append(group)

    # Collapse every group into a single line
    result_lines = []
    for group in groups:
        total = 0
        for member in group:
            total += member[fixed]
        mean_fixed = int(total/len(group))

        by_start = sorted(group, key=lambda x: (x[axis]))
        by_end = sorted(by_start, key=lambda x: (x[tail]))
        first = by_start[0][axis]
        last = by_end[-1][tail]

        if axis:
            result_lines.append([mean_fixed, first, mean_fixed, last])
        else:
            result_lines.append([first, mean_fixed, last, mean_fixed])
    return result_lines
- def fix_inner2(row_points, col_points, row_lines, col_lines, threshold=3):
- # Extend row/column lines that stop part-way between two neighbouring
- # intersection points.
- #
- # row_points: list of point lists, one per row; each is sorted in place by (y, x).
- # col_points: list of point lists, one per column; a list is sorted in place by
- #             (x, y) when it contains the current point.
- # row_lines / col_lines: lines [x0, y0, x1, y1]; end coordinates are updated
- #             in place.
- # threshold:  tolerance when matching a point to a line's fixed coordinate.
- # Returns (row_lines, col_lines).
- #
- # A line end is only moved when the existing part already covers at least
- # 1/3 of the gap between the two points.
- # NOTE(review): indentation is not preserved in this copy of the file; the
- # nesting of the blocks below has to be confirmed against the original.
- for i in range(len(row_points)):
- row = row_points[i]
- row.sort(key=lambda x: (x[1], x[0]))
- for j in range(len(row)):
- # Current point
- point = row[j]
- # Next point in the current point's row ([] when there is none)
- if j >= len(row) - 1:
- next_row_point = []
- else:
- next_row_point = row[j+1]
- if next_row_point:
- for k in range(len(row_lines)):
- line = row_lines[k]
- if line[1] - threshold <= point[1] <= line[1] + threshold:
- # Only act when the line does not already span both points
- if not line[0] <= point[0] <= next_row_point[0] <= line[2]:
- # Line ends inside the gap: stretch its end to the next point
- if point[0] <= line[2] < next_row_point[0]:
- if line[2] - point[0] >= 1/3 * (next_row_point[0] - point[0]):
- row_lines[k][2] = next_row_point[0]
- # Line starts inside the gap: pull its start back to the current point
- if point[0] < line[0] <= next_row_point[0]:
- if next_row_point[0] - line[0] >= 1/3 * (next_row_point[0] - point[0]):
- row_lines[k][0] = point[0]
- # Next point in the current point's column
- next_col_point = []
- for col in col_points:
- if point in col:
- col.sort(key=lambda x: (x[0], x[1]))
- if col.index(point) < len(col) - 1:
- next_col_point = col[col.index(point)+1]
- break
- # Diagonal point of the current point, obtained as the next point in the
- # row of the column-wise next point
- next_row_next_col_point = []
- if next_col_point:
- for row2 in row_points:
- if next_col_point in row2:
- row2.sort(key=lambda x: (x[1], x[0]))
- if row2.index(next_col_point) < len(row2) - 1:
- next_row_next_col_point = row2[row2.index(next_col_point)+1]
- break
- # There is a column-wise next point, but it has no next point in its own row
- if not next_row_next_col_point:
- # If the current row has a next point, synthesise the diagonal from it
- if next_row_point:
- next_row_next_col_point = [next_row_point[0], next_col_point[1]]
- if next_col_point:
- for k in range(len(col_lines)):
- line = col_lines[k]
- if line[0] - threshold <= point[0] <= line[0] + threshold:
- # Only act when the line does not already span both points
- if not line[1] <= point[1] <= next_col_point[1] <= line[3]:
- # Line ends inside the gap: stretch its end to the next point
- if point[1] <= line[3] < next_col_point[1]:
- if line[3] - point[1] >= 1/3 * (next_col_point[1] - point[1]):
- col_lines[k][3] = next_col_point[1]
- # Line starts inside the gap: pull its start back to the current point
- if point[1] < line[1] <= next_col_point[1]:
- if next_col_point[1] - line[1] >= 1/3 * (next_col_point[1] - point[1]):
- col_lines[k][1] = point[1]
- if next_row_next_col_point:
- # Same start-side repair for the column running through the diagonal point
- for k in range(len(col_lines)):
- line = col_lines[k]
- if line[0] - threshold <= next_row_next_col_point[0] <= line[0] + threshold:
- if not line[1] <= point[1] <= next_row_next_col_point[1] <= line[3]:
- if point[1] < line[1] <= next_row_next_col_point[1]:
- if next_row_next_col_point[1] - line[1] >= 1/3 * (next_row_next_col_point[1] - point[1]):
- col_lines[k][1] = point[1]
- return row_lines, col_lines
- def fix_inner1(row_lines, col_lines, points, split_y):
- # Propose extra intersection points where a line overshoots its last
- # intersection but stops short of the nearest perpendicular line.
- #
- # row_lines / col_lines: lines [x0, y0, x1, y1].
- # points:  existing intersection points (x, y).
- # split_y: y values delimiting the regions; lines and points are processed per
- #          region [split_y[i-1], split_y[i]].
- # Returns points + new_points.
- # NOTE(review): indentation is not preserved in this copy of the file; in
- # particular which `if` the `else:` inside fix() belongs to must be confirmed.
- def fix(fix_lines, assist_lines, split_points, axis):
- # fix_lines are the lines being examined, assist_lines the perpendicular ones;
- # axis is the index of the coordinate that runs along a fix_line.
- new_points = []
- for line1 in fix_lines:
- min_assist_line = [[], []]
- min_distance = [1000, 1000]
- if_find = [0, 0]
- # Collect every intersection point lying on fix_line (within 2 px). The two
- # end points may be missing: the points are intersections, and an end point
- # need not be one.
- fix_line_points = []
- for point in split_points:
- if abs(point[1-axis] - line1[1-axis]) <= 2:
- if line1[axis] <= point[axis] <= line1[axis+2]:
- fix_line_points.append(point)
- # For each of the two end points find the closest assist_line that does not
- # intersect fix_line
- line1_point = [line1[:2], line1[2:]]
- for i in range(2):
- point = line1_point[i]
- for line2 in assist_lines:
- # NOTE(review): the lower bound below uses line1[1-axis] whereas the
- # matching test further down uses line2[1-axis] -- confirm which is meant.
- if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
- if line1[1-axis] <= point[1-axis] <= line2[1-axis+2]:
- # print("line1, match line2", line1, line2)
- if_find[i] = 1
- break
- else:
- if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1-axis] <= point[1-axis] <= line2[1-axis+2]:
- # Skip assist lines that lie within fix_line's own span
- if line1[axis] <= line2[axis] <= line1[axis+2]:
- continue
- # NOTE(review): the stored distance is measured from line1[axis] while the
- # comparison above measures from point[axis] -- confirm this is intended.
- min_distance[i] = abs(line1[axis] - line2[axis])
- min_assist_line[i] = line2
- # Find the intersection point closest to the chosen assist_line.
- # The distance from the end point to that intersection (the overhang) must be
- # at least 1/3 of the distance from the assist_line to the intersection (the
- # bbox edge).
- min_distance = [1000, 1000]
- min_col_point = [[], []]
- for i in range(2):
- # print("end point", i, line1_point[i])
- if min_assist_line[i]:
- for point in fix_line_points:
- if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]:
- min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
- min_col_point[i] = point
- if min_col_point[i]:
- bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis])
- line_distance = abs(min_col_point[i][axis] - line1_point[i][axis])
- if bbox_len/3 <= line_distance <= bbox_len:
- add_point = (line1_point[i][1-axis], min_assist_line[i][axis])
- print("============================table line==")
- print("fix_inner add point", add_point)
- print(min_col_point[i][axis], line1_point[i][axis], min_col_point[i][axis], min_assist_line[i][axis])
- print(abs(min_col_point[i][axis] - line1_point[i][axis]), abs(min_col_point[i][axis] - min_assist_line[i][axis])/3)
- print("line1, line2", line1, min_assist_line[i])
- new_points.append(add_point)
- return new_points
- new_points = []
- for i in range(1, len(split_y)):
- last_y = split_y[i-1]
- y = split_y[i]
- # First partition the lines and points into the current region
- split_row_lines = []
- split_col_lines = []
- split_points = []
- for row in row_lines:
- if last_y <= row[1] <= y:
- split_row_lines.append(row)
- for col in col_lines:
- if last_y <= col[1] <= y:
- split_col_lines.append(col)
- for point in points:
- if last_y <= point[1] <= y:
- split_points.append(point)
- # Columns are checked against rows, then rows against columns
- new_points += fix(split_col_lines, split_row_lines, split_points, axis=1)
- new_points += fix(split_row_lines, split_col_lines, split_points, axis=0)
- # (A commented-out earlier draft that looked for column end points not lying
- # on a row, and row end points not lying on a column, was removed from here.)
- return points+new_points
def fix_inner(row_lines, col_lines, points, split_y):
    """Extend line ends that stop short of a neighbouring perpendicular line.

    Lines are ``[x1, y1, x2, y2]`` lists and ``points`` are ``(x, y)``
    intersections.  The lines are processed band by band, a band being the
    range between two consecutive values of ``split_y``.  Column lines are
    repaired first (using the rows as reference), then row lines (using the
    columns as reference).

    ``row_lines`` and ``col_lines`` are updated in place: a repaired line is
    replaced by a new 4-element list at the same index.

    :param row_lines: horizontal lines
    :param col_lines: vertical lines
    :param points: intersection points of the lines
    :param split_y: y coordinates delimiting the bands
    :return: ``(row_lines, col_lines)``; if anything goes wrong the traceback
        is printed and deep copies taken before any modification are returned
    """

    def fix(fix_lines, assist_lines, split_points, axis):
        """Return ``[line, add_point]`` pairs describing proposed extensions.

        ``axis`` is the coordinate index along which ``fix_lines`` run
        (1 for columns, 0 for rows); ``assist_lines`` are the perpendicular
        lines of the same band.
        """
        new_points = []

        def propose(line, endpoint, assist_line, nearest_point):
            # The overhang (endpoint -> nearest intersection) must be between
            # 1/3 and the full length of the cell edge (assist line ->
            # nearest intersection) for the line to be extended.
            bbox_len = abs(nearest_point[axis] - assist_line[axis])
            line_distance = abs(nearest_point[axis] - endpoint[axis])
            if bbox_len / 3 <= line_distance <= bbox_len:
                if axis == 1:
                    add_point = (endpoint[1 - axis], assist_line[axis])
                else:
                    add_point = (assist_line[axis], endpoint[1 - axis])
                new_points.append([line, add_point])

        for line1 in fix_lines:
            min_assist_line = [[], []]
            min_distance = [1000, 1000]
            if_find = [0, 0]

            # Intersections lying on this line.  The two endpoints are not
            # necessarily among them: an endpoint need not be an intersection.
            fix_line_points = [
                point for point in split_points
                if abs(point[1 - axis] - line1[1 - axis]) <= 2
                and line1[axis] <= point[axis] <= line1[axis + 2]
            ]

            # For each endpoint find the closest assist line that does not
            # cross the line being fixed.  Scanning stops as soon as an
            # assist line touching the endpoint is met.
            line1_point = [line1[:2], line1[2:]]
            for i, point in enumerate(line1_point):
                for line2 in assist_lines:
                    if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
                        if line1[1 - axis] <= point[1 - axis] <= line2[1 - axis + 2]:
                            if_find[i] = 1
                            break
                    else:
                        if abs(point[axis] - line2[axis]) < min_distance[i] \
                                and line2[1 - axis] <= point[1 - axis] <= line2[1 - axis + 2]:
                            if line1[axis] <= line2[axis] <= line1[axis + 2]:
                                continue
                            min_distance[i] = abs(line1[axis] - line2[axis])
                            min_assist_line[i] = line2

            # For each endpoint find the intersection closest to its
            # assist line.
            min_distance = [1000, 1000]
            min_col_point = [[], []]
            for i in range(2):
                if min_assist_line[i]:
                    for point in fix_line_points:
                        distance = abs(point[axis] - min_assist_line[i][axis])
                        if distance < min_distance[i]:
                            min_distance[i] = distance
                            min_col_point[i] = point

            if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]:
                # Both endpoints chose the same assist line: only the
                # endpoint facing that line may be extended.
                if min_assist_line[0][axis] < line1_point[0][axis]:
                    propose(line1, line1_point[0], min_assist_line[0], min_col_point[0])
                elif min_assist_line[1][axis] > line1_point[1][axis]:
                    propose(line1, line1_point[1], min_assist_line[1], min_col_point[1])
            else:
                for i in range(2):
                    if min_col_point[i]:
                        propose(line1, line1_point[i], min_assist_line[i], min_col_point[i])
        return new_points

    def apply(lines, proposals, axis):
        """Write the proposed endpoints back into ``lines`` (in place)."""
        for line, new_point in proposals:
            if line in lines:
                index = lines.index(line)
                point1 = line[:2]
                point2 = line[2:]
                if new_point[axis] >= point2[axis]:
                    lines[index] = [point1[0], point1[1], new_point[0], new_point[1]]
                elif new_point[axis] <= point1[axis]:
                    lines[index] = [new_point[0], new_point[1], point2[0], point2[1]]

    # Pristine copies returned when the repair fails half-way.
    row_lines_copy = copy.deepcopy(row_lines)
    col_lines_copy = copy.deepcopy(col_lines)
    try:
        for i in range(1, len(split_y)):
            last_y = split_y[i - 1]
            y = split_y[i]

            # Assign lines and points to the current band.
            split_row_lines = [row for row in row_lines if last_y <= row[1] <= y]
            split_col_lines = [col for col in col_lines if last_y <= col[1] <= y]
            split_points = [point for point in points if last_y <= point[1] <= y]

            apply(col_lines, fix(split_col_lines, split_row_lines, split_points, axis=1), 1)
            apply(row_lines, fix(split_row_lines, split_col_lines, split_points, axis=0), 0)
        return row_lines, col_lines
    except Exception:
        # Best effort: report the problem and hand back the untouched lines.
        # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt and
        # SystemExit still propagate.
        traceback.print_exc()
        return row_lines_copy, col_lines_copy
def fix_corner(row_lines, col_lines, split_y, threshold=0):
    """Close the top-left and top-right corners of every band.

    Lines are ``[x1, y1, x2, y2]`` lists.  For each band between two
    consecutive distinct values of ``split_y`` the top-most row line and the
    left-most / right-most column lines are picked; when their end points do
    not coincide, the row is stretched horizontally to the column's x and the
    column's upper end is moved to the row's y.

    The selected line lists are modified in place.  Bands that contain no row
    line or no column line contribute nothing to the result.

    :param row_lines: horizontal lines
    :param col_lines: vertical lines
    :param split_y: y coordinates delimiting the bands
    :param threshold: tolerance added on both sides of a band when assigning
        lines to it
    :return: ``(new_row_lines, new_col_lines)``, the lines of all processed
        bands concatenated; two empty lists when ``split_y`` is empty
    """
    new_row_lines = []
    new_col_lines = []
    if not split_y:
        # Nothing to delimit; previously this raised IndexError.
        return new_row_lines, new_col_lines

    last_y = split_y[0]
    for y in split_y:
        if y == last_y:
            continue

        low = last_y - threshold
        high = y + threshold
        split_row_lines = [row for row in row_lines
                           if low <= row[1] <= high or low <= row[3] <= high]
        # Both ends are tested because lines are easily lost at band borders.
        split_col_lines = [col for col in col_lines
                           if low <= col[1] <= high or low <= col[3] <= high]

        if split_row_lines and split_col_lines:
            split_row_lines.sort(key=lambda x: (x[1], x[0]))
            split_col_lines.sort(key=lambda x: (x[0], x[1]))
            up_line = split_row_lines[0]
            left_line = split_col_lines[0]
            right_line = split_col_lines[-1]

            # Top-left corner.
            if up_line[0:2] != left_line[0:2]:
                corner_x, corner_y = left_line[0], up_line[1]
                up_line[0] = corner_x
                left_line[1] = corner_y

            # Top-right corner.
            if up_line[2:] != right_line[:2]:
                corner_x, corner_y = right_line[0], up_line[1]
                up_line[2] = corner_x
                right_line[1] = corner_y

            new_row_lines.extend(split_row_lines)
            new_col_lines.extend(split_col_lines)
        last_y = y
    return new_row_lines, new_col_lines
def delete_outline(row_lines, col_lines, points):
    """Drop outer border lines that carry fewer than three intersections.

    Lines are ``[x1, y1, x2, y2]`` lists, ``points`` are ``(x, y)`` pairs.
    Both line lists are sorted in place; the first and last row line and the
    first and last column line are the border candidates.  A point counts for
    a candidate when it lies within 2 units of the line and inside the line's
    extent.  Candidates with fewer than three such points are removed from
    their list.

    :return: the (mutated) ``row_lines`` and ``col_lines``
    """
    row_lines.sort(key=lambda ln: (ln[1], ln[0]))
    col_lines.sort(key=lambda ln: (ln[0], ln[1]))

    tolerance = 2
    # (candidate, list owning it, index of the coordinate that is constant
    # along the line: 1 -> y for rows, 0 -> x for columns)
    candidates = (
        (row_lines[0], row_lines, 1),
        (row_lines[-1], row_lines, 1),
        (col_lines[0], col_lines, 0),
        (col_lines[-1], col_lines, 0),
    )

    hits = [0] * len(candidates)
    for pt in points:
        for k, (border, _, fixed) in enumerate(candidates):
            along = 1 - fixed
            if not border[fixed] - tolerance <= pt[fixed] <= border[fixed] + tolerance:
                continue
            if border[along] <= pt[along] <= border[along + 2]:
                hits[k] += 1

    # A border line has to contain at least three intersections.
    for (border, owner, _), count in zip(candidates, hits):
        if count < 3 and border in owner:
            owner.remove(border)
    return row_lines, col_lines
def _outline2_split_regions(row_lines, col_lines, split_y):
    """Group the lines into the bands delimited by consecutive ``split_y`` values.

    Returns one ``(rows, cols)`` pair per band; rows are sorted by
    ``(y1, x1)`` and columns by ``x1``.
    """
    regions = []
    for last_y, y in zip(split_y, split_y[1:]):
        rows = [row for row in row_lines if last_y <= row[3] <= y]
        cols = [col for col in col_lines
                if last_y <= col[1] <= y or last_y <= col[3] <= y
                or col[1] < last_y < y < col[3]]
        rows.sort(key=lambda line: (line[1], line[0]))
        cols.sort(key=lambda line: line[0])
        regions.append((rows, cols))
    return regions


def _outline2_box_height(rows):
    """Return the most frequent spacing of consecutive rows (ties: the smallest)."""
    if len(rows) <= 1:
        return 10
    counts = {}
    for upper, lower in zip(rows, rows[1:]):
        height = abs(int(upper[3] - lower[3]))
        counts[height] = counts.get(height, 0) + 1
    return max(counts, key=lambda height: (counts[height], -height))


def _outline2_snap(lines, index, target, limit, pick):
    """Move coordinate ``index`` of every line within ``limit`` of ``target`` onto it."""
    for line in lines:
        if abs(target - line[index]) <= limit:
            line[index] = pick([target, line[index]])


def fix_outline2(image, row_lines, col_lines, points, split_y):
    """Add missing outer border lines to every table band.

    Lines are ``[x1, y1, x2, y2]`` lists.  Within each band (the range between
    two consecutive ``split_y`` values) the outermost rows and columns are
    compared: when both outer columns stick out above the top row (or below
    the bottom row) a new row is added at the further of the two ends, and
    when both outer rows stick out beyond the left or right column a new
    column is added.  Existing lines whose end is no further than one cell
    height / width from the new border are stretched to it.

    Side effects: ``split_y`` is sorted in place, the line lists inside
    ``row_lines`` / ``col_lines`` are modified in place, and debug values are
    printed.  ``image`` and ``points`` are accepted for interface
    compatibility but not used.

    :return: ``(new_row_lines, new_col_lines, all_longer_row_lines,
        all_longer_col_lines)``; the last two hold the (possibly stretched)
        lines of every band that has both rows and columns.  Four empty lists
        are returned when ``split_y`` has fewer than two values.
    """
    print("split_y", split_y)

    # y coordinates of the band separators
    if len(split_y) < 2:
        return [], [], [], []
    split_y.sort()

    # All bands are built before any line is modified so that stretching a
    # line in one band cannot change which bands it belongs to.
    regions = _outline2_split_regions(row_lines, col_lines, split_y)

    new_row_lines = []
    new_col_lines = []
    all_longer_row_lines = []
    all_longer_col_lines = []
    for rows, cols in regions:
        if not rows or not cols:
            continue
        up_line, bottom_line = rows[0], rows[-1]
        left_line, right_line = cols[0], cols[-1]

        # Size of a single cell.
        box_height = _outline2_box_height(rows)
        if len(cols) > 1:
            box_width = abs(cols[1][2] - cols[0][2])
        else:
            box_width = 10
        print("box_height", box_height, "box_width", box_width)

        # Outer columns sticking out above the top row -> add a row there.
        over_left = up_line[1] - left_line[1]
        over_right = up_line[1] - right_line[1]
        if (over_left >= 10 and over_right >= 2) or (over_left >= 2 and over_right >= 10):
            new_col_y = left_line[1] if over_left >= over_right else right_line[1]
            new_row_lines.append([left_line[0], new_col_y, right_line[0], new_col_y])
            # Connect the shorter columns to the new row.
            _outline2_snap(cols, 1, new_col_y, box_height, min)

        # Outer columns sticking out below the bottom row.
        over_left = left_line[3] - bottom_line[3]
        over_right = right_line[3] - bottom_line[3]
        if (over_left >= 10 and over_right >= 2) or (over_left >= 2 and over_right >= 10):
            new_col_y = left_line[3] if over_left >= over_right else right_line[3]
            new_row_lines.append([left_line[2], new_col_y, right_line[2], new_col_y])
            _outline2_snap(cols, 3, new_col_y, box_height, max)

        # Outer rows sticking out left of the left column -> add a column.
        over_up = left_line[0] - up_line[0]
        over_bottom = left_line[0] - bottom_line[0]
        if (over_up >= 10 and over_bottom >= 2) or (over_up >= 2 and over_bottom >= 10):
            new_row_x = up_line[0] if over_up >= over_bottom else bottom_line[0]
            new_col_lines.append([new_row_x, up_line[1], new_row_x, bottom_line[1]])
            # Connect the shorter rows to the new column.
            _outline2_snap(rows, 0, new_row_x, box_width, min)

        # Outer rows sticking out right of the right column.
        over_up = up_line[2] - right_line[2]
        over_bottom = bottom_line[2] - right_line[2]
        if (over_up >= 10 and over_bottom >= 2) or (over_up >= 2 and over_bottom >= 10):
            new_row_x = up_line[2] if over_up >= over_bottom else bottom_line[2]
            new_col_lines.append([new_row_x, up_line[3], new_row_x, bottom_line[3]])
            _outline2_snap(rows, 2, new_row_x, box_width, max)

        all_longer_row_lines += rows
        all_longer_col_lines += cols

    return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
def _outline_split_regions(row_lines, col_lines, split_y):
    """Group the lines into the bands delimited by consecutive ``split_y`` values.

    Returns one ``(rows, cols)`` pair per band; rows are sorted by
    ``(y1, x1)`` and columns by ``x1``.
    """
    regions = []
    for last_y, y in zip(split_y, split_y[1:]):
        rows = [row for row in row_lines if last_y <= row[3] <= y]
        cols = [col for col in col_lines
                if last_y <= col[1] <= y or last_y <= col[3] <= y
                or col[1] < last_y < y < col[3]]
        rows.sort(key=lambda line: (line[1], line[0]))
        cols.sort(key=lambda line: line[0])
        regions.append((rows, cols))
    return regions


def _outline_box_height(rows, default):
    """Return the most frequent row spacing of at least 10 (ties: the smallest).

    ``default`` is returned when there is a single row or when no spacing
    reaches 10.
    """
    counts = {}
    for upper, lower in zip(rows, rows[1:]):
        height = abs(int(upper[3] - lower[3]))
        if height >= 10:
            counts[height] = counts.get(height, 0) + 1
    if not counts:
        return default
    return max(counts, key=lambda height: (counts[height], -height))


def _outline_snap(lines, index, target, limit, pick):
    """Move coordinate ``index`` of every line within ``limit`` of ``target`` onto it."""
    for line in lines:
        if abs(target - line[index]) <= limit:
            line[index] = pick([target, line[index]])


def fix_outline(image, row_lines, col_lines, points, split_y, scale=25):
    """Add missing outer border lines to every table band.

    Lines are ``[x1, y1, x2, y2]`` lists.  Within each band (the range between
    two consecutive ``split_y`` values) the outermost rows and columns are
    compared: when both outer columns stick out above the top row (or below
    the bottom row) by at least a threshold, a new row is added at the
    further of the two ends; likewise a new column is added when both outer
    rows stick out beyond the left or right column.  Existing lines whose end
    is no further than one cell height / width from the new border are
    stretched to it.

    Side effects: ``split_y`` is sorted in place and the line lists inside
    ``row_lines`` / ``col_lines`` are modified in place.  ``points`` is
    accepted for interface compatibility but not used.

    :param image: object exposing ``shape``; only ``shape[0]`` and
        ``shape[1]`` are read, to derive the minimum lengths
    :param scale: divisor applied to the image dimensions to obtain the
        minimum lengths (never below 10)
    :return: ``(new_row_lines, new_col_lines, all_longer_row_lines,
        all_longer_col_lines)``; the last two hold the (possibly stretched)
        lines of every band that has both rows and columns.  Four empty lists
        are returned when ``split_y`` has fewer than two values.
    """
    log("into fix_outline")
    # NOTE(review): x_min_len is derived from shape[0] and y_min_len from
    # shape[1]; for numpy images shape[0] is the number of rows - confirm
    # that this pairing is intended.
    x_min_len = max(10, int(image.shape[0] / scale))
    y_min_len = max(10, int(image.shape[1] / scale))

    # y coordinates of the band separators
    if len(split_y) < 2:
        return [], [], [], []
    split_y.sort()

    # All bands are built before any line is modified so that stretching a
    # line in one band cannot change which bands it belongs to.
    regions = _outline_split_regions(row_lines, col_lines, split_y)

    new_row_lines = []
    new_col_lines = []
    all_longer_row_lines = []
    all_longer_col_lines = []
    for rows, cols in regions:
        if not rows or not cols:
            continue
        up_line, bottom_line = rows[0], rows[-1]
        left_line, right_line = cols[0], cols[-1]

        # Size of a single cell.
        box_height = _outline_box_height(rows, y_min_len)
        if len(cols) > 1:
            box_width = abs(cols[1][2] - cols[0][2])
        else:
            box_width = x_min_len

        # How far a border line must stick out before it is repaired.
        if box_height >= 2 * y_min_len:
            fix_h_len = y_min_len
        else:
            fix_h_len = box_height * 2 / 3
        if box_width >= 2 * x_min_len:
            fix_w_len = x_min_len
        else:
            fix_w_len = box_width * 2 / 3

        # Outer columns sticking out above the top row -> add a row there.
        over_left = up_line[1] - left_line[1]
        over_right = up_line[1] - right_line[1]
        if over_left >= fix_h_len and over_right >= fix_h_len:
            new_col_y = left_line[1] if over_left >= over_right else right_line[1]
            new_row_lines.append([left_line[0], new_col_y, right_line[0], new_col_y])
            # Connect the shorter columns to the new row.
            _outline_snap(cols, 1, new_col_y, box_height, min)

        # Outer columns sticking out below the bottom row.
        over_left = left_line[3] - bottom_line[3]
        over_right = right_line[3] - bottom_line[3]
        if over_left >= fix_h_len and over_right >= fix_h_len:
            new_col_y = left_line[3] if over_left >= over_right else right_line[3]
            new_row_lines.append([left_line[2], new_col_y, right_line[2], new_col_y])
            _outline_snap(cols, 3, new_col_y, box_height, max)

        # Outer rows sticking out left of the left column -> add a column.
        over_up = left_line[0] - up_line[0]
        over_bottom = left_line[0] - bottom_line[0]
        if over_up >= fix_w_len and over_bottom >= fix_w_len:
            new_row_x = up_line[0] if over_up >= over_bottom else bottom_line[0]
            new_col_lines.append([new_row_x, up_line[1], new_row_x, bottom_line[1]])
            # Connect the shorter rows to the new column.
            _outline_snap(rows, 0, new_row_x, box_width, min)

        # Outer rows sticking out right of the right column.
        over_up = up_line[2] - right_line[2]
        over_bottom = bottom_line[2] - right_line[2]
        if over_up >= fix_w_len and over_bottom >= fix_w_len:
            new_row_x = up_line[2] if over_up >= over_bottom else bottom_line[2]
            new_col_lines.append([new_row_x, up_line[3], new_row_x, bottom_line[3]])
            _outline_snap(rows, 2, new_row_x, box_width, max)

        all_longer_row_lines += rows
        all_longer_col_lines += cols

    return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
def fix_table(row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Derive cell bounding boxes from grouped intersection points.

    ``row_point_list`` / ``col_point_list`` are lists of point groups (points
    are ``(x, y)`` pairs); ``row_lines`` / ``col_lines`` are
    ``[x1, y1, x2, y2]`` lines.  For every point of a row and every later
    point in that row, the column group located at the later point's x
    (+/- 3) is searched for the first point below the current one such that
    all four edges of the resulting rectangle are backed by a line covering
    at least a third of the edge.  The first such rectangle found for a
    point is recorded.

    The point groups are sorted in place as a side effect.

    :return: list of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y``
        has fewer than two values
    """
    # y coordinates of the band separators
    if len(split_y) < 2:
        return []

    def by_x(pt):
        return (pt[0], pt[1])

    def by_y(pt):
        return (pt[1], pt[0])

    def near(value, centre):
        return centre - 3 <= value <= centre + 3

    bbox = []
    for region_top, region_bottom in zip(split_y, split_y[1:]):

        def outside(start, end):
            # True when the item does not lie strictly inside the band.
            return start <= region_top or end >= region_bottom

        region_rows = [ln for ln in row_lines if not outside(ln[1], ln[1])]
        region_cols = [ln for ln in col_lines if not outside(ln[1], ln[3])]

        for row in row_point_list:
            row.sort(key=by_x)
            if outside(row[0][1], row[0][1]):
                continue

            for j in range(len(row) - 1):
                current_point = row[j]
                for next_point in row[j + 1:]:
                    # Locate the column group the next point belongs to.
                    next_col = []
                    for col in col_point_list:
                        col.sort(key=by_y)
                        if outside(col[0][1], col[-1][1]):
                            continue
                        if near(next_point[0], col[0][0]):
                            next_col = col
                            break
                    next_col.sort(key=by_y)

                    found = False
                    for point1 in next_col:
                        # Only points clearly below the current row qualify.
                        if point1[1] <= current_point[1] + 3:
                            continue

                        left, top = current_point[0], current_point[1]
                        right, bottom = point1[0], point1[1]
                        min_width = (right - left) / 3
                        min_height = (bottom - top) / 3

                        def row_backs(edge_y):
                            # Some row on edge_y overlaps >= 1/3 of the width.
                            return any(
                                near(edge_y, ln[1])
                                and min([ln[2], right]) - max([ln[0], left]) >= min_width
                                for ln in region_rows
                            )

                        def col_backs(edge_x):
                            # Some column on edge_x overlaps >= 1/3 of the height.
                            return any(
                                near(edge_x, ln[0])
                                and min([ln[3], bottom]) - max([ln[1], top]) >= min_height
                                for ln in region_cols
                            )

                        if row_backs(top) and row_backs(bottom) \
                                and col_backs(left) and col_backs(right):
                            bbox.append([(left, top), (right, bottom)])
                            found = True
                            break

                    # Otherwise fall through to the next point of the row.
                    if found:
                        break
    return bbox
def _drop_close_neighbours(ordered_points, point_groups, threshold):
    """Keep one point of every pair of adjacent, close points.

    ``ordered_points`` is scanned once; two neighbours are "close" when both
    coordinate differences are at most ``threshold``.  Of a close pair the
    point whose x matches the first point of the larger group(s) in
    ``point_groups`` wins (the earlier point on a tie).
    """
    kept = []
    dropped = []
    last_index = len(ordered_points) - 1
    for i, point1 in enumerate(ordered_points):
        if point1 in dropped:
            continue
        if i == last_index:
            kept.append(point1)
            break

        point2 = ordered_points[i + 1]
        if abs(point1[0] - point2[0]) > threshold or abs(point1[1] - point2[1]) > threshold:
            kept.append(point1)
            continue

        # Keep the point that shares its coordinate with more other points.
        count1 = 0
        count2 = 0
        for group in point_groups:
            if not group:
                continue
            if point1[0] == group[0][0]:
                count1 += len(group)
            elif point2[0] == group[0][0]:
                count2 += len(group)

        if count1 >= count2:
            kept.append(point1)
            dropped.append(point2)
        else:
            # point2 wins.  It is added when the scan reaches it, not here,
            # otherwise it would end up in the result twice.
            dropped.append(point1)
    return kept


def delete_close_points(point_list, row_point_list, col_point_list, threshold=5):
    """Merge intersection points that lie within ``threshold`` of each other.

    Two passes are made: the first over the points ordered by ``(y, x)``
    using ``col_point_list`` to decide which point of a close pair survives,
    the second over the survivors ordered by ``(x, y)`` using
    ``row_point_list``.

    ``point_list`` is sorted in place by the first pass.

    :param point_list: ``(x, y)`` points
    :param row_point_list: groups of points, consulted in the second pass
    :param col_point_list: groups of points, consulted in the first pass
    :param threshold: maximum coordinate difference for two points to be
        considered the same
    :return: the de-duplicated points, ordered by ``(x, y)``
    """
    point_list.sort(key=lambda x: (x[1], x[0]))
    survivors = _drop_close_neighbours(point_list, col_point_list, threshold)

    # NOTE(review): like the first pass, the second pass matches the point's
    # x (index 0) against the first point of each group, although it works on
    # row groups; possibly index 1 was meant - confirm against the callers.
    survivors.sort(key=lambda x: (x[0], x[1]))
    return _drop_close_neighbours(survivors, row_point_list, threshold)
- def get_bbox2(image_np, points):
- # Debug helper: groups `points` into rows/columns, removes near-duplicate points with
- # delete_close_points, shows the intermediate images with cv2.imshow, and pairs points into
- # [(x1, y1), (x2, y2)] boxes which are sorted, drawn on image_np and returned.
- # NOTE(review): get_points_row/get_points_col are called here as (points) and (points, 5), while
- # the definitions visible later in this file take (points, split_y, threshold=5) - confirm that
- # this function still runs.
- # NOTE(review): next_col_point is assigned but never read, and next_point is only bound when the
- # last search loop finds a match - verify against the intended nesting.
- # # group the points by row
- # row_point_list = []
- # row_point = []
- # points.sort(key=lambda x: (x[0], x[1]))
- # for p in points:
- # if len(row_point) == 0:
- # x = p[0]
- # if x-5 <= p[0] <= x+5:
- # row_point.append(p)
- # else:
- # row_point_list.append(row_point)
- # row_point = []
- # # group the points by column
- # col_point_list = []
- # col_point = []
- # points.sort(key=lambda x: (x[1], x[0]))
- # for p in points:
- # if len(col_point) == 0:
- # y = p[1]
- # if y-5 <= p[1] <= y+5:
- # col_point.append(p)
- # else:
- # col_point_list.append(col_point)
- # col_point = []
- row_point_list = get_points_row(points)
- col_point_list = get_points_col(points)
- print("len(points)", len(points))
- for point in points:
- cv2.circle(image_np, point, 1, (0, 255, 0), 1)
- cv2.imshow("points_deleted", image_np)
- points = delete_close_points(points, row_point_list, col_point_list)
- print("len(points)", len(points))
- for point in points:
- cv2.circle(image_np, point, 1, (255, 0, 0), 3)
- cv2.imshow("points_deleted", image_np)
- cv2.waitKey(0)
- row_point_list = get_points_row(points, 5)
- col_point_list = get_points_col(points, 5)
- print("len(row_point_list)", len(row_point_list))
- for row in row_point_list:
- print("row", len(row))
- print("col_point_list", len(col_point_list))
- for col in col_point_list:
- print("col", len(col))
- bbox = []
- for i in range(len(row_point_list)):
- if i == len(row_point_list) - 1:
- break
- # walk every point of each row and look for the next point in its column and in its row
- current_row = row_point_list[i]
- for j in range(len(current_row)):
- current_point = current_row[j]
- if j == len(current_row) - 1:
- break
- next_row_point = current_row[j+1]
- # locate the column holding the current point and take the next point in that column
- current_col = col_point_list[j]
- for k in range(len(current_col)):
- if current_col[k][1] > current_point[1] + 10:
- next_col_point = current_col[k]
- break
- next_row = row_point_list[k]
- for k in range(len(next_row)):
- if next_row[k][0] >= next_row_point[0] + 5:
- next_point = next_row[k]
- break
- # build the bbox
- bbox.append([(current_point[0], current_point[1]), (next_point[0], next_point[1])])
- # bbox = []
- # for p in points:
- # # print("p", p)
- # p_row = []
- # p_col = []
- # for row in row_point_list:
- # if p[0] == row[0][0]:
- # for p1 in row:
- # if abs(p[1]-p1[1]) <= 5:
- # continue
- # p_row.append([p1, abs(p[1]-p1[1])])
- # p_row.sort(key=lambda x: x[1])
- # for col in col_point_list:
- # if p[1] == col[0][1]:
- # for p2 in col:
- # if abs(p[0]-p2[0]) <= 5:
- # continue
- # p_col.append([p2, abs(p[0]-p2[0])])
- # p_col.sort(key=lambda x: x[1])
- # if len(p_row) == 0 or len(p_col) == 0:
- # continue
- # break_flag = 0
- # for i in range(len(p_row)):
- # for j in range(len(p_col)):
- # # print(p_row[i][0])
- # # print(p_col[j][0])
- # another_point = (p_col[j][0][0], p_row[i][0][1])
- # # print("another_point", another_point)
- # if abs(p[0]-another_point[0]) <= 5 or abs(p[1]-another_point[1]) <= 5:
- # continue
- # if p[0] >= another_point[0] or p[1] >= another_point[1]:
- # continue
- # if another_point in points:
- # box = [p, another_point]
- # box.sort(key=lambda x: x[0])
- # if box not in bbox:
- # bbox.append(box)
- # break_flag = 1
- # break
- # if break_flag:
- # break
- #
- # # delete duplicate
- # delete_bbox = []
- # for i in range(len(bbox)):
- # for j in range(i+1, len(bbox)):
- # if bbox[i][0] == bbox[j][0]:
- # if bbox[i][1][0] - bbox[j][1][0] <= 3 \
- # and bbox[i][1][1] - bbox[j][1][1] <= 3:
- # delete_bbox.append(bbox[j])
- # if bbox[i][1] == bbox[j][1]:
- # if bbox[i][0][0] - bbox[j][0][0] <= 3 \
- # and bbox[i][0][1] - bbox[j][0][1] <= 3:
- # delete_bbox.append(bbox[j])
- # # delete too small area
- # # for box in bbox:
- # # if box[1][0] - box[0][0] <=
- # for d_box in delete_bbox:
- # if d_box in bbox:
- # bbox.remove(d_box)
- # print bbox
- # order the boxes by top-left x, top-left y, then bottom-right x, bottom-right y
- bbox.sort(key=lambda x: (x[0][0], x[0][1], x[1][0], x[1][1]))
- # origin bbox
- # origin_bbox = []
- # for box in bbox:
- # origin_bbox.append([(box[0][0], box[0][1] - 40), (box[1][0], box[1][1] - 40)])
- # for box in origin_bbox:
- # cv2.rectangle(origin_image, box[0], box[1], (0, 0, 255), 2, 8)
- # cv2.imshow('AlanWang', origin_image)
- # cv2.waitKey(0)
- for box in bbox:
- cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8)
- cv2.imshow('bboxes', image_np)
- cv2.waitKey(0)
- # for point in points:
- # print(point)
- # cv2.circle(image_np, point, 1, (0, 0, 255), 3)
- # cv2.imshow('points', image_np)
- # cv2.waitKey(0)
- return bbox
def get_bbox1(image_np, points, split_y):
    """Build cell boxes by pairing each point with its nearest right/lower neighbours.

    For every region between two consecutive ``split_y`` values, each point
    is combined with the closest point to its right (x) and the closest point
    below it (y) to form ``[(x1, y1), (x2, y2)]``. Boxes that lie inside
    another box are dropped, the remaining boxes are drawn on ``image_np``
    and shown in a blocking ``cv2`` window.

    Args:
        image_np: image the resulting boxes are drawn on (modified in place).
        points: list of (x, y) crossover points.
        split_y: y coordinates delimiting the regions.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; ``[]`` when fewer than two
        split lines are given.
    """
    if len(split_y) < 2:
        return []

    # Group into rows/columns and drop near-duplicate crossover points.
    # The grouping helpers require ``split_y``; it was previously omitted.
    row_point_list = get_points_row(points, split_y)
    col_point_list = get_points_col(points, split_y)
    print("len(row_point_list)", len(row_point_list))
    print("len(col_point_list)", len(col_point_list))
    points = delete_close_points(points, row_point_list, col_point_list)

    threshold = 10
    bbox = []
    for i in range(1, len(split_y)):
        upper, lower = split_y[i - 1], split_y[i]
        region = [p for p in points if upper < p[1] < lower]
        for point1 in region:
            distance_x = 10000
            distance_y = 10000
            x = 0
            y = 0
            for point2 in region:
                dx = point2[0] - point1[0]
                dy = point2[1] - point1[1]
                # Nearest point to the right / below.
                # NOTE(review): the offset on the other axis is only bounded
                # from above (no abs()), as in the original - confirm.
                if 2 < dx < distance_x and dy <= threshold:
                    distance_x = dx
                    x = point2[0]
                if 2 < dy < distance_y and dx <= threshold:
                    distance_y = dy
                    y = point2[1]
            if not x or not y:
                continue
            bbox.append([(point1[0], point1[1]), (x, y)])

    # Drop boxes contained in another box. The containment flag used to be
    # computed but never consulted, so nothing was ever removed.
    kept = []
    for i, box1 in enumerate(bbox):
        contained = False
        for j, box2 in enumerate(bbox):
            if i == j:
                continue
            if box2 == box1:
                # Exact duplicates: keep only the first occurrence.
                if j < i:
                    contained = True
                    break
                continue
            if box2[0][0] <= box1[0][0] <= box1[1][0] <= box2[1][0] and \
                    box2[0][1] <= box1[0][1] <= box1[1][1] <= box2[1][1]:
                contained = True
                break
        if not contained:
            kept.append(box1)
    bbox = kept

    # Display the result.
    for box in bbox:
        cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8)
    cv2.imshow('bboxes', image_np)
    cv2.waitKey(0)
    return bbox
- def get_bbox0(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
- # Builds [(x1, y1), (x2, y2)] boxes for each region between consecutive split_y values: a row
- # point is paired with points of the column that holds its right-hand neighbour, and a candidate
- # is kept when row_lines cover both its top and bottom edge (within 3 px).
- # The row and column point lists are sorted in place. image_np and col_lines are not read here.
- # y coordinates of the split lines
- if len(split_y) < 2:
- return []
- # compute rows/columns and drop near-duplicate crossover points
- # row_point_list = get_points_row(points)
- # col_point_list = get_points_col(points)
- # points = delete_close_points(points, row_point_list, col_point_list)
- # row_point_list = get_points_row(points)
- # col_point_list = get_points_col(points)
- # collect the bboxes
- bbox = []
- # print("get_bbox split_y", split_y)
- # for every point, find the points nearest to it in x and in y
- for i in range(1, len(split_y)):
- # loop over every row
- for row in row_point_list:
- row.sort(key=lambda x: (x[0], x[1]))
- # skip rows outside the current region
- if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]:
- continue
- # loop over the points in the row
- for j in range(len(row)):
- if j == len(row) - 1:
- break
- current_point = row[j]
- next_point_in_row = row[j+1]
- # look up the column that holds the next point
- next_col = []
- for col in col_point_list:
- col.sort(key=lambda x: (x[1], x[0]))
- # skip columns outside the current region
- if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]:
- continue
- if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3:
- next_col = col
- break
- # try to pair the current point with the points of that column
- for point1 in next_col:
- # skip points on the same row
- if current_point[1]-3 <= point1[1] <= current_point[1]+3:
- continue
- if point1[1] <= current_point[1]-3:
- continue
- # candidate bbox
- candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]]
- # decide whether the bbox exists by checking that lines contain its edges
- contain_flag1 = 0
- contain_flag2 = 0
- for row1 in row_lines:
- # skip lines outside the current region
- if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]:
- continue
- # top edge of the bbox: same y
- if not contain_flag1:
- if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3:
- # candidate x1..x2 must be contained in the row line
- if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3:
- contain_flag1 = 1
- # bottom edge of the bbox: same y
- if not contain_flag2:
- if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3:
- # candidate x1..x2 must be contained in the row line
- if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3:
- contain_flag2 = 1
- # the bbox was found and it exists
- if contain_flag1 and contain_flag2:
- bbox.append([(candidate_bbox[0], candidate_bbox[1]),
- (candidate_bbox[2], candidate_bbox[3])])
- break
- return bbox
- def get_bbox3(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
- # Builds [(x1, y1), (x2, y2)] boxes for each region between consecutive split_y values. For a row
- # point, the later points of the same row are tried in turn; each selects a column whose points
- # give candidate bottom-right corners. A candidate is kept when row_lines support its top and
- # bottom edges and col_lines support its left and right edges (see the per-edge tests below).
- # The row and column point lists are sorted in place. image_np is not read here.
- # y coordinates of the split lines
- if len(split_y) < 2:
- return []
- # collect the bboxes
- bbox = []
- # for every point, find the points nearest to it in x and in y
- for i in range(1, len(split_y)):
- # loop over every row
- for row in row_point_list:
- row.sort(key=lambda x: (x[0], x[1]))
- # skip rows outside the current region
- if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]:
- continue
- # print("len(row)", len(row))
- # print("row", row)
- # loop over the points in the row
- for j in range(len(row)):
- if j == len(row) - 1:
- break
- current_point = row[j]
- # print("current_point", current_point)
- next_point_in_row_list = row[j+1:]
- # loop over the following points of this row
- for next_point_in_row in next_point_in_row_list:
- # whether a bbox was found with this point; if not, try the next point of the row
- not_found = 1
- # look up the column that holds the next point
- next_col = []
- for col in col_point_list:
- col.sort(key=lambda x: (x[1], x[0]))
- # skip columns outside the current region
- if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]:
- continue
- if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3:
- next_col = col
- break
- # try to pair the current point with the points of that column
- next_col.sort(key=lambda x: (x[1], x[0]))
- for point1 in next_col:
- # skip points on the same row
- if current_point[1]-3 <= point1[1] <= current_point[1]+3:
- continue
- if point1[1] <= current_point[1]-3:
- continue
- # candidate bbox
- candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]]
- # print("candidate_bbox", candidate_bbox)
- # decide whether the bbox exists: check its top and bottom edges against the rows
- contain_flag1 = 0
- contain_flag2 = 0
- for row1 in row_lines:
- # skip lines outside the current region
- if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]:
- continue
- # top edge of the bbox: same y
- if not contain_flag1:
- if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3:
- # the part of the line that falls inside the cell
- row1_break = (max([row1[0], candidate_bbox[0]]),
- row1[1],
- min([row1[2], candidate_bbox[2]]),
- row1[3])
- if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3:
- contain_flag1 = 1
- # # candidate x1..x2 must be contained in the row line
- # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3:
- # contain_flag1 = 1
- #
- # # check whether the line has an endpoint inside the cell
- # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \
- # or candidate_bbox[0] < row1[2] < candidate_bbox[2]:
- # # lines may be slightly short, so test that the length exceeds half of the cell
- # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3:
- # contain_flag1 = 1
- # bottom edge of the bbox: same y
- if not contain_flag2:
- if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3:
- # the part of the line that falls inside the cell
- row1_break = (max([row1[0], candidate_bbox[0]]),
- row1[1],
- min([row1[2], candidate_bbox[2]]),
- row1[3])
- if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3:
- contain_flag2 = 1
- # # candidate x1..x2 must be contained in the row line
- # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3:
- # contain_flag2 = 1
- #
- # # check whether the line has an endpoint inside the cell
- # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \
- # or candidate_bbox[0] < row1[2] < candidate_bbox[2]:
- # # lines may be slightly short, so test that the length exceeds half of the cell
- # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3:
- # contain_flag2 = 1
- # decide whether the bbox exists: check its left and right edges against the cols
- contain_flag3 = 0
- contain_flag4 = 0
- for col1 in col_lines:
- # skip lines outside the current region
- if col1[1] <= split_y[i-1] or col1[3] >= split_y[i]:
- continue
- # left edge of the bbox: same x
- if not contain_flag3:
- if col1[0]-3 <= candidate_bbox[0] <= col1[0]+3:
- # the part of the line that falls inside the cell
- col1_break = (col1[0],
- max([col1[1], candidate_bbox[1]]),
- col1[2],
- min([col1[3], candidate_bbox[3]]))
- if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3:
- contain_flag3 = 1
- # # candidate y1..y2 must be contained in the col line
- # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3:
- # contain_flag3 = 1
- #
- # # check whether the line has an endpoint inside the cell
- # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \
- # or candidate_bbox[1] < col1[3] < candidate_bbox[3]:
- # # lines may be slightly short, so test that the length exceeds half of the cell
- # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3:
- # contain_flag3 = 1
- # right edge of the bbox: same x
- if not contain_flag4:
- if col1[0]-3 <= candidate_bbox[2] <= col1[0]+3:
- # the part of the line that falls inside the cell
- # col1_break = (col1[0],
- # max([col1[1], candidate_bbox[1]]),
- # col1[2],
- # min([col1[3], candidate_bbox[3]]))
- # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3:
- # contain_flag4 = 1
- # accept if the upper third or the lower third of the candidate edge is contained in the col
- candidate_bbox_line1 = [candidate_bbox[1],
- candidate_bbox[1] + (candidate_bbox[3]-candidate_bbox[1])/3]
- candidate_bbox_line2 = [candidate_bbox[3] - (candidate_bbox[3]-candidate_bbox[1])/3,
- candidate_bbox[3]]
- if col1[1] <= candidate_bbox_line1[0] <= candidate_bbox_line1[1] <= col1[3] \
- or col1[1] <= candidate_bbox_line2[0] <= candidate_bbox_line2[1] <= col1[3]:
- # print("candidate_bbox", candidate_bbox)
- # print("col1", col1)
- contain_flag4 = 1
- # # candidate y1..y2 must be contained in the col line
- # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3:
- # contain_flag4 = 1
- #
- # # check whether the line has an endpoint inside the cell
- # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \
- # or candidate_bbox[1] < col1[3] < candidate_bbox[3]:
- # # lines may be slightly short, so test that the length exceeds half of the cell
- # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3:
- # contain_flag4 = 1
- # the bbox was found and it exists
- if contain_flag1 and contain_flag2 and contain_flag3 and contain_flag4:
- bbox.append([(candidate_bbox[0], candidate_bbox[1]),
- (candidate_bbox[2], candidate_bbox[3])])
- not_found = 0
- # print("exist candidate_bbox", candidate_bbox)
- # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4)
- break
- # else:
- # print("candidate_bbox", candidate_bbox)
- # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4)
- if not not_found:
- break
- return bbox
def get_bbox(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Build cell boxes by pairing each row point with a point of a lower row.

    The points and lines are first partitioned into the regions delimited by
    consecutive ``split_y`` values. Inside a region every point (except the
    last one of its row) is paired with the first point of a following row
    that lies to its right and forms a box accepted by ``check_bbox``.

    Args:
        image_np: unused; kept for interface compatibility.
        row_point_list: list of rows, each a list of (x, y) points. Rows that
            fall in a region are sorted in place by (y, x).
        col_point_list: unused; kept for interface compatibility.
        split_y: y coordinates delimiting the regions.
        row_lines: horizontal lines as ``[x1, y1, x2, y2]``.
        col_lines: vertical lines as ``[x1, y1, x2, y2]``.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; ``[]`` when fewer than two
        split lines are given.
    """
    if len(split_y) < 2:
        return []

    bbox_list = []
    for region_index in range(1, len(split_y)):
        last_y = split_y[region_index - 1]
        y = split_y[region_index]

        # Partition points and lines into the current region.
        split_row_point_list = []
        for row in row_point_list:
            if last_y <= row[0][1] <= y:
                # NOTE(review): sorting by (y, x) is kept as-is; rows are
                # sorted by (x, y) elsewhere in the file - confirm intent.
                row.sort(key=lambda p: (p[1], p[0]))
                split_row_point_list.append(row)
        split_row_lines = [line for line in row_lines if last_y <= line[1] <= y]
        split_col_lines = [line for line in col_lines if last_y <= line[1] <= y]

        # The inner loops use their own index names; the previous code reused
        # the region index ``i`` for the row index.
        for row_index in range(len(split_row_point_list) - 1):
            row = split_row_point_list[row_index]
            for point1 in row[:-1]:
                # The diagonal partner is searched in the next row first and
                # then in the rows after it, stopping at the first hit.
                found = False
                for next_row in split_row_point_list[row_index + 1:]:
                    if found:
                        break
                    for point2 in next_row:
                        if abs(point1[0] - point2[0]) <= 2:
                            continue
                        if point2[0] < point1[0]:
                            continue
                        bbox = [point1[0], point1[1], point2[0], point2[1]]
                        # Ignore degenerate (too narrow / too flat) boxes.
                        if abs(bbox[0] - bbox[2]) <= 10:
                            continue
                        if abs(bbox[1] - bbox[3]) <= 10:
                            continue
                        # All four edges must be backed by detected lines.
                        if check_bbox(bbox, split_row_lines, split_col_lines):
                            bbox_list.append([(bbox[0], bbox[1]), (bbox[2], bbox[3])])
                            found = True
                            break
    return bbox_list
def check_bbox(bbox, rows, cols, threshold=5):
    """Return True when ``bbox`` is a real, undivided table cell.

    A box qualifies when each of its four edges is backed by detected lines
    and no other line crosses its interior.

    Args:
        bbox: ``[x1, y1, x2, y2]`` candidate box.
        rows: horizontal lines as ``[x1, y1, x2, y2]``.
        cols: vertical lines as ``[x1, y1, x2, y2]``.
        threshold: tolerance (px) used when matching a line to an edge and
            when deciding whether a line lies strictly inside the box.

    Returns:
        bool: True if all four edges exist and the interior is empty.
    """
    def contains(line, segment, axis):
        # True when ``segment`` (start, end along ``axis``) lies within ``line``.
        return line[axis] <= segment[0] <= segment[1] <= line[axis + 2]

    def check(check_line, lines, limit_axis, axis):
        # An edge exists when a line at the edge position covers the whole
        # edge or one half of it, or when one line covers a third at one end
        # and another line at the same position covers the quarter at the
        # opposite end.
        start, end = check_line[0], check_line[1]
        length = end - start
        middle = (start + end) / 2
        first_half = [start, middle]
        second_half = [middle, end]
        first_third = [start, start + length / 3]
        last_third = [end - length / 3, end]
        first_quarter = [start, start + length / 4]
        last_quarter = [end - length / 4, end]

        # Rows must share the edge's y, cols must share the edge's x.
        aligned = [line for line in lines
                   if abs(line[1 - axis] - limit_axis) <= threshold]
        for line1 in aligned:
            if (contains(line1, check_line, axis)
                    or contains(line1, first_half, axis)
                    or contains(line1, second_half, axis)):
                return 1
            if contains(line1, first_third, axis):
                other_end = last_quarter
            elif contains(line1, last_third, axis):
                other_end = first_quarter
            else:
                continue
            # The complementary line must sit at the edge position as well.
            # The previous code re-tested line1 here, so any line anywhere
            # in the region could complete the edge.
            if any(contains(line2, other_end, axis) for line2 in aligned):
                return 1
        return 0

    up_down_line = [bbox[0], bbox[2]]
    up_y, down_y = bbox[1], bbox[3]
    left_right_line = [bbox[1], bbox[3]]
    left_x, right_x = bbox[0], bbox[2]

    # All four edges must exist.
    if not (check(up_down_line, rows, up_y, 0)
            and check(up_down_line, rows, down_y, 0)
            and check(left_right_line, cols, left_x, 1)
            and check(left_right_line, cols, right_x, 1)):
        return False

    # Reject the box when another line runs through its interior. A line
    # that starts or ends inside the box counts once it spans at least half
    # of the box extent; the previous code compared that span against the
    # midpoint coordinate ``(a + b) / 2`` instead of the half length.
    half_height = (left_right_line[1] - left_right_line[0]) / 2
    for col in cols:
        if not (left_x + threshold <= col[0] <= right_x - threshold):
            continue
        if col[1] <= left_right_line[0] <= left_right_line[1] <= col[3]:
            return False
        elif left_right_line[0] <= col[1] <= left_right_line[1]:
            if left_right_line[1] - col[1] >= half_height:
                return False
        elif left_right_line[0] <= col[3] <= left_right_line[1]:
            if col[3] - left_right_line[0] >= half_height:
                return False

    half_width = (up_down_line[1] - up_down_line[0]) / 2
    for row in rows:
        if not (up_y + threshold <= row[1] <= down_y - threshold):
            continue
        if row[0] <= up_down_line[0] <= up_down_line[1] <= row[2]:
            return False
        elif up_down_line[0] <= row[0] <= up_down_line[1]:
            if up_down_line[1] - row[0] >= half_width:
                return False
        elif up_down_line[0] <= row[2] <= up_down_line[1]:
            if row[2] - up_down_line[0] >= half_width:
                return False

    return True
def add_continue_bbox(bboxes, split_y=None):
    """Fill vertical gaps between consecutive boxes with additional boxes.

    The boxes are sorted by their top-left (x, y). Walking that order, when
    the tracked box and the current box are in the same ``split_y`` region
    and the current box starts below the tracked box's bottom edge (by more
    than 2 px), a filler box spanning the gap is created with the tracked
    box's x range.

    Args:
        bboxes: list of ``[(x1, y1), (x2, y2)]`` boxes; sorted in place.
        split_y: y coordinates delimiting the regions. When omitted, the
            module-level ``split_y`` is used (the previous implicit
            behaviour); if that does not exist either, nothing is added.

    Returns:
        The boxes, plus any filler boxes. When fillers were added the result
        is de-duplicated and ordered by top-left (y, x).
    """
    if split_y is None:
        split_y = globals().get("split_y", [])
    if not bboxes:
        # Previously raised IndexError on bboxes[0].
        return bboxes

    add_bbox_list = []
    bboxes.sort(key=lambda b: (b[0][0], b[0][1]))
    last_bbox = bboxes[0]

    for region_index in range(1, len(split_y)):
        y = split_y[region_index]
        last_y = split_y[region_index - 1]
        for bbox in bboxes[1:]:
            # Only compare boxes whose bottom edges are both in this region.
            if not (last_y <= bbox[1][1] <= y and last_y <= last_bbox[1][1] <= y):
                continue
            if abs(last_bbox[1][1] - bbox[0][1]) <= 2:
                # Boxes touch: continue from the lower one.
                last_bbox = bbox
            elif last_bbox[1][1] > bbox[0][1]:
                # Current box starts above the tracked bottom: restart from it.
                last_bbox = bbox
            else:
                # NOTE(review): last_bbox is not advanced after a filler is
                # added, as in the original - confirm that is intended.
                add_bbox_list.append([(last_bbox[0][0], last_bbox[1][1]),
                                      (last_bbox[1][0], bbox[0][1])])

    print("add_bbox_list", add_bbox_list)
    if add_bbox_list:
        # De-duplicate by value without the str()/eval() round trip.
        unique = {}
        for box in bboxes + add_bbox_list:
            unique.setdefault(tuple(tuple(corner) for corner in box), box)
        bboxes = list(unique.values())
        bboxes.sort(key=lambda b: (b[0][1], b[0][0]))
    return bboxes
def points_to_line(points_lines, axis):
    """Convert groups of points into axis-aligned line segments.

    Args:
        points_lines: list of point groups, each a list of (x, y) points.
        axis: 0 to build horizontal lines (extent along x, position is the
            mean y); any truthy value to build vertical lines (extent along
            y, position is the mean x).

    Returns:
        List of ``[x1, y1, x2, y2]`` lines, one per non-empty group. The
        position is the integer-truncated mean of the cross-axis coordinate.
        Empty groups are skipped (they previously raised IndexError).
    """
    new_line_list = []
    for line in points_lines:
        if not line:
            continue
        along = [point[axis] for point in line]
        average = int(sum(point[1 - axis] for point in line) / len(line))
        low, high = min(along), max(along)
        if axis:
            new_line_list.append([average, low, average, high])
        else:
            new_line_list.append([low, average, high, average])
    return new_line_list
def get_bbox_by_contours(image_np):
    """Draw the minimum enclosing circle of the first contour found in the image.

    The image is converted to grayscale, binarised at 127, and its contours
    are extracted with ``cv2.RETR_LIST``. The minimum enclosing circle of the
    first contour is drawn in white on a copy of the input.

    Args:
        image_np: BGR image as accepted by ``cv2.cvtColor``.

    Returns:
        A copy of ``image_np`` with the circle drawn; an unmodified copy when
        no contour is found. (The function previously returned None.)
    """
    img_gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
    _, img_bin = cv2.threshold(img_gray, 127, 255, cv2.THRESH_BINARY)

    # Connected-contour analysis. findContours returns (image, contours,
    # hierarchy) in OpenCV 3 and (contours, hierarchy) in OpenCV 4; the
    # contours are always the second-to-last element.
    found = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    contours = found[-2]

    img_result = image_np.copy()
    if len(contours) == 0:
        return img_result

    # Minimum enclosing circle: centre and radius. The previous code called
    # minEnclosingTriangle, whose (area, triangle) result is not a circle.
    (center_x, center_y), radius = cv2.minEnclosingCircle(contours[0])
    center = (int(center_x), int(center_y))

    cv2.circle(img_result, center, int(radius), (255, 255, 255), 2)
    return img_result
def get_points_col(points, split_y, threshold=5):
    """Group crossover points into columns (points sharing roughly the same x).

    ``points`` is sorted in place by (x, y). Points lying strictly between
    two consecutive ``split_y`` values are visited region by region; a point
    joins the current column while its x is within ``threshold`` of the
    column's anchor x, otherwise it starts a new column.

    Args:
        points: list of (x, y) points; sorted in place.
        split_y: y coordinates delimiting the regions.
        threshold: maximum x distance from the column anchor.

    Returns:
        List of columns, each a list of points sorted by (y, x). Returns
        ``[]`` for empty ``points`` (previously IndexError).
    """
    if not points:
        return []

    points.sort(key=lambda p: (p[0], p[1]))
    columns = []
    current = []
    anchor_x = points[0][0]
    for i in range(1, len(split_y)):
        upper, lower = split_y[i - 1], split_y[i]
        for p in points:
            if p[1] <= upper or p[1] >= lower:
                continue
            if anchor_x - threshold <= p[0] <= anchor_x + threshold:
                current.append(p)
                continue
            if current:
                current.sort(key=lambda q: (q[1], q[0]))
                columns.append(current)
            current = [p]
            anchor_x = p[0]
    if current:
        # The trailing column is sorted like every other one; it used to be
        # returned in (x, y) order.
        current.sort(key=lambda q: (q[1], q[0]))
        columns.append(current)
    return columns
def get_points_row(points, split_y, threshold=5):
    """Group crossover points into rows (points sharing roughly the same y).

    ``points`` is sorted in place by (y, x). Points lying strictly between
    two consecutive ``split_y`` values are visited region by region; a point
    joins the current row while its y is within ``threshold`` of the row's
    anchor y, otherwise it starts a new row.

    Args:
        points: list of (x, y) points; sorted in place.
        split_y: y coordinates delimiting the regions.
        threshold: maximum y distance from the row anchor.

    Returns:
        List of rows, each a list of points sorted by (x, y). Returns ``[]``
        for empty ``points`` (previously IndexError).
    """
    if not points:
        return []

    points.sort(key=lambda p: (p[1], p[0]))
    rows = []
    current = []
    anchor_y = points[0][1]
    # Regions are the gaps between consecutive split lines, so iteration
    # starts at 1; starting at 0 paired split_y[-1] with split_y[0].
    for i in range(1, len(split_y)):
        upper, lower = split_y[i - 1], split_y[i]
        for p in points:
            if p[1] <= upper or p[1] >= lower:
                continue
            if anchor_y - threshold <= p[1] <= anchor_y + threshold:
                current.append(p)
                continue
            if current:
                current.sort(key=lambda q: (q[0], q[1]))
                rows.append(current)
            current = [p]
            anchor_y = p[1]
    if current:
        # The trailing row is sorted like every other one; it used to be
        # returned in (y, x) order.
        current.sort(key=lambda q: (q[0], q[1]))
        rows.append(current)
    return rows
def get_outline_point(points, split_y):
    """Return the first and last point of every region, in (y, x) order.

    ``points`` is sorted in place by (y, x). For each pair of consecutive
    ``split_y`` values, the points lying strictly between them are collected;
    a region with at least one point contributes ``[first, last]``.

    Returns ``[]`` when fewer than two split lines are given.
    """
    if len(split_y) < 2:
        return []

    def reading_order(point):
        return (point[1], point[0])

    points.sort(key=reading_order)

    corners = []
    for upper, lower in zip(split_y, split_y[1:]):
        inside = sorted(
            (point for point in points if upper < point[1] < lower),
            key=reading_order,
        )
        if inside:
            corners.append([inside[0], inside[-1]])
    return corners
- # def merge_row(row_lines):
- # for row in row_lines:
- # for row1 in row_lines:
def get_best_predict_size(image_np):
    """Pick the supported prediction size closest to the image's height and width.

    Each dimension is matched independently against the fixed size list; on
    a tie the larger size (listed first) wins.

    Returns:
        tuple: ``(best_height, best_width)``.
    """
    candidates = (1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128)

    def closest(extent):
        # min() keeps the first of several equally close candidates.
        return min(candidates, key=lambda size: abs(extent - size))

    height, width = image_np.shape[0], image_np.shape[1]
    return closest(height), closest(width)
def choose_longer_row(lines):
    """Collapse pairs of near-coincident horizontal lines to the longer one.

    Each not-yet-consumed line is compared with the later, not-yet-consumed
    lines whose y is within 5 px. The first one found where either line's x
    range spans the other's ends the search: the spanning line is kept and
    both are marked consumed. A line without such a partner is kept as is.
    """
    survivors = []
    consumed = []
    for index, first in enumerate(lines):
        if first in consumed:
            continue
        winner = first
        partner = None
        for second in lines[index + 1:]:
            if second in consumed:
                continue
            if not (second[1] - 5 <= first[1] <= second[1] + 5):
                continue
            if first[0] <= second[0] and first[2] >= second[2]:
                partner = second
                break
            if second[0] <= first[0] and second[2] >= first[2]:
                winner = second
                partner = second
                break
        survivors.append(winner)
        consumed.append(first)
        if partner is not None:
            consumed.append(partner)
    return survivors
def choose_longer_col(lines):
    """Collapse pairs of near-coincident vertical lines to the longer one.

    Each not-yet-consumed line is compared with the later, not-yet-consumed
    lines whose x is within 5 px. The first one found where either line's y
    range spans the other's ends the search: the spanning line is kept and
    both are marked consumed. A line without such a partner is kept as is.
    """
    survivors = []
    consumed = []
    for index, first in enumerate(lines):
        if first in consumed:
            continue
        winner = first
        partner = None
        for second in lines[index + 1:]:
            if second in consumed:
                continue
            if not (second[0] - 5 <= first[0] <= second[0] + 5):
                continue
            if first[1] <= second[1] and first[3] >= second[3]:
                partner = second
                break
            if second[1] <= first[1] and second[3] >= first[3]:
                winner = second
                partner = second
                break
        survivors.append(winner)
        consumed.append(first)
        if partner is not None:
            consumed.append(partner)
    return survivors
def delete_contain_bbox(bboxes):
    """Remove boxes that enclose another box, keeping the smaller one.

    Only pairs sharing the same x range (compared on y) or, failing that,
    the same y range (compared on x) are considered. The enclosing box of
    such a pair is removed from ``bboxes`` in place; the same list is
    returned. Progress is reported with print().
    """
    doomed = []
    for index, bbox1 in enumerate(bboxes):
        left1, top1 = bbox1[0][0], bbox1[0][1]
        right1, bottom1 = bbox1[1][0], bbox1[1][1]
        for bbox2 in bboxes[index + 1:]:
            left2, top2 = bbox2[0][0], bbox2[0][1]
            right2, bottom2 = bbox2[1][0], bbox2[1][1]
            if left1 == left2 and right1 == right2:
                # Same x range: the box with the wider y range is the outer one.
                if top1 <= top2 <= bottom2 <= bottom1:
                    doomed.append(bbox1)
                elif top2 <= top1 <= bottom1 <= bottom2:
                    doomed.append(bbox2)
            elif top1 == top2 and bottom1 == bottom2:
                # Same y range: the box with the wider x range is the outer one.
                if left1 <= left2 <= right2 <= right1:
                    print("3", bbox1, bbox2)
                    doomed.append(bbox1)
                elif left2 <= left1 <= right1 <= right2:
                    print("4", bbox1, bbox2)
                    doomed.append(bbox2)

    print("delete_contain_bbox len(bboxes)", len(bboxes))
    print("delete_contain_bbox len(delete_bbox)", len(doomed))
    for outer in doomed:
        if outer in bboxes:
            bboxes.remove(outer)
    print("delete_contain_bbox len(bboxes)", len(bboxes))
    return bboxes
- if __name__ == '__main__':
- # Manual debug run: reads one test image, runs table_line to get row/column lines, then shows
- # each intermediate stage (lines, crossover points, repaired border, outline points, bboxes)
- # in blocking cv2 windows and writes the current canvas to temp.jpg.
- # NOTE(review): model, table_line, get_points, get_split_line and fix_outline are not defined in
- # this part of the file - presumably they come from earlier in the module; confirm.
- # p = "开标记录表3_page_0.png"
- # p = "train_data/label_1.jpg"
- # p = "test_files/train_463.jpg"
- p = "test_files/8.png"
- # p = "test_files/无边框3.jpg"
- # p = "test_files/part1.png"
- # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \
- # "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png"
- # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
- # "\\00fb43acbc7e11eb836000163e0ae709.png"
- # p = "test_files/table.jpg"
- # p = "data_process/create_data/0.jpg"
- # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png"
- # p = "../format_conversion/1.png"
- img = cv2.imread(p)
- t = time.time()
- # NOTE(review): the weights path is an empty string here - set it before running.
- model.load_weights("")
- best_h, best_w = get_best_predict_size(img)
- print(img.shape)
- print((best_h, best_w))
- # row_boxes, col_boxes = table_line(img[..., ::-1], model, size=(512, 1024), hprob=0.5, vprob=0.5)
- # row_boxes, col_boxes, img = table_line(img[..., ::-1], model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
- row_boxes, col_boxes, img = table_line(img, model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
- print("len(row_boxes)", len(row_boxes))
- # NOTE(review): this prints the col_boxes list itself although the label says len(...).
- print("len(col_boxes)", col_boxes)
- # create a blank (white) image
- test_img = np.zeros((img.shape), np.uint8)
- test_img.fill(255)
- for box in row_boxes+col_boxes:
- cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
- cv2.imshow("test_image", test_img)
- cv2.waitKey(0)
- cv2.imwrite("temp.jpg", test_img)
- # compute crossover points and split lines
- crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1]))
- print("len(col_boxes)", len(col_boxes))
- split_lines, split_y = get_split_line(crossover_points, col_boxes, img)
- print("split_y", split_y)
- # for point in crossover_points:
- # cv2.circle(test_img, point, 1, (0, 255, 0), 3)
- # cv2.imshow("point image1", test_img)
- # cv2.waitKey(0)
- # compute rows/columns and drop near-duplicate crossover points
- row_point_list = get_points_row(crossover_points, split_y, 0)
- col_point_list = get_points_col(crossover_points, split_y, 0)
- crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list)
- row_point_list = get_points_row(crossover_points, split_y)
- col_point_list = get_points_col(crossover_points, split_y)
- for point in crossover_points:
- cv2.circle(test_img, point, 1, (0, 0, 255), 3)
- cv2.imshow("point image1", test_img)
- cv2.waitKey(0)
- print("len(row_boxes)", len(row_boxes))
- print("len(col_boxes)", len(col_boxes))
- # repair the outer border
- new_row_boxes, new_col_boxes, long_row_boxes, long_col_boxes = \
- fix_outline(img, row_boxes, col_boxes, crossover_points, split_y)
- if new_row_boxes or new_col_boxes:
- if long_row_boxes:
- print("long_row_boxes", long_row_boxes)
- row_boxes = long_row_boxes
- if long_col_boxes:
- print("long_col_boxes", long_col_boxes)
- col_boxes = long_col_boxes
- if new_row_boxes:
- row_boxes += new_row_boxes
- print("new_row_boxes", new_row_boxes)
- if new_col_boxes:
- print("new_col_boxes", new_col_boxes)
- col_boxes += new_col_boxes
- # print("len(row_boxes)", len(row_boxes))
- # print("len(col_boxes)", len(col_boxes))
- # row_boxes += new_row_boxes
- # col_boxes += new_col_boxes
- # row_boxes = choose_longer_row(row_boxes)
- # col_boxes = choose_longer_col(col_boxes)
- # create a blank (white) image
- test_img = np.zeros((img.shape), np.uint8)
- test_img.fill(255)
- for box in row_boxes+col_boxes:
- cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
- cv2.imshow("test_image2", test_img)
- cv2.waitKey(0)
- # show the lines that were added by the repair
- for row in new_row_boxes:
- cv2.line(test_img, (int(row[0]), int(row[1])),
- (int(row[2]), int(row[3])), (0, 0, 255), 1)
- for col in new_col_boxes:
- cv2.line(test_img, (int(col[0]), int(col[1])),
- (int(col[2]), int(col[3])), (0, 0, 255), 1)
- cv2.imshow("fix_outline", test_img)
- cv2.waitKey(0)
- cv2.imwrite("temp.jpg", test_img)
- # recompute crossover points and split lines after the border repair
- print("crossover_points", len(crossover_points))
- crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1]))
- print("crossover_points new", len(crossover_points))
- split_lines, split_y = get_split_line(crossover_points, col_boxes, img)
- # compute rows/columns and drop near-duplicate crossover points
- row_point_list = get_points_row(crossover_points, split_y, 0)
- col_point_list = get_points_col(crossover_points, split_y, 0)
- print(len(crossover_points), len(row_point_list), len(col_point_list))
- crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list)
- print(len(crossover_points), len(row_point_list), len(col_point_list))
- row_point_list = get_points_row(crossover_points, split_y)
- col_point_list = get_points_col(crossover_points, split_y)
- for point in crossover_points:
- cv2.circle(test_img, point, 1, (0, 255, 0), 3)
- cv2.imshow("point image2", test_img)
- cv2.waitKey(0)
- # get the two outline points of each table (first and last point in (y, x) order)
- outline_point = get_outline_point(crossover_points, split_y)
- # print(outline_point)
- for outline in outline_point:
- cv2.circle(test_img, outline[0], 1, (255, 0, 0), 5)
- cv2.circle(test_img, outline[1], 1, (255, 0, 0), 5)
- cv2.imshow("outline point", test_img)
- cv2.waitKey(0)
- # collect the bboxes
- # get_bbox(img, crossover_points, split_y)
- # for point in crossover_points:
- # cv2.circle(test_img, point, 1, (0, 255, 0), 3)
- # cv2.imshow("point image3", test_img)
- # cv2.waitKey(0)
- # split_y = []
- # for outline in outline_point:
- # split_y.extend([outline[0][1]-5, outline[1][1]+5])
- print("len(row_boxes)", len(row_boxes))
- print("len(col_boxes)", len(col_boxes))
- bboxes = get_bbox(img, row_point_list, col_point_list, split_y, row_boxes, col_boxes)
- # display
- for box in bboxes:
- # print(box[0], box[1])
- # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]):
- # continue
- cv2.rectangle(test_img, box[0], box[1], (0, 0, 255), 2, 8)
- cv2.imshow('bboxes', test_img)
- cv2.waitKey(0)
- # img = draw_lines(img, row_boxes+col_boxes, color=(255, 0, 0), lineW=2)
- # img = draw_boxes(img, rowboxes+colboxes, color=(0, 0, 255))
- print(time.time()-t, len(row_boxes), len(col_boxes))
- cv2.imwrite('temp.jpg', test_img)
- # cv2.imshow('main', img)
- # cv2.waitKey(0)
|