12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
7327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Thu Sep 9 23:11:51 2020
- table line detect
- @author: chineseocr
- """
- import copy
- import logging
- import tensorflow as tf
- import tensorflow.keras.backend as K
- from tensorflow.keras.models import Model
- from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D
- from tensorflow.keras.layers import LeakyReLU
- from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_line, draw_boxes
- import numpy as np
- import cv2
- import time
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')  # import-time side effect: configures the root logger at INFO level
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Return the Dice coefficient of two tensors.

    Both inputs are flattened with the Keras backend before the overlap is
    measured.

    :param y_true: ground-truth tensor
    :param y_pred: predicted tensor
    :param smooth: constant added to numerator and denominator
    :return: (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess)
    return (2. * overlap + smooth) / (total + smooth)
def dice_coef_loss():
    """Build a loss callable that returns the negated Dice coefficient.

    :return: function ``(y_true, y_pred) -> -dice_coef(y_true, y_pred)``
    """
    def dice_coef_loss_fixed(y_true, y_pred):
        score = dice_coef(y_true, y_pred)
        return -score

    return dice_coef_loss_fixed
def focal_loss(gamma=3., alpha=.5):
    """Build a focal-loss callable for 0/1 targets.

    :param gamma: exponent applied to the modulating factor
    :param alpha: weight of the positive term; the negative term gets 1 - alpha
    :return: function ``(y_true, y_pred) -> scalar loss``
    """
    # Experiment log kept from the original author (gamma, alpha, setup, metrics):
    # 3 0.85 2000e acc-0.6 p-0.99 r-0.99 val_acc-0.56 val_p-0.86 val_r-0.95
    # 2 0.85 double_gpu acc-
    # 3 0.25 gpu 50e acc-0.5 p-0.99 r-0.99 val_acc-0.45 val_p-0.96 val_r-0.88
    # 2 0.25 gpu acc-
    # 3 0.5 double_gpu acc-0.6 p-0.99 r-0.99 val_acc-0.60 val_p-0.93 val_r-0.93
    def focal_loss_fixed(y_true, y_pred):
        # Where the label is 1 keep the prediction, elsewhere use 1 so the
        # positive term contributes nothing; symmetric with 0 for the negatives.
        is_positive = tf.equal(y_true, 1)
        pt_1 = tf.where(is_positive, y_pred, tf.ones_like(y_pred))
        is_negative = tf.equal(y_true, 0)
        pt_0 = tf.where(is_negative, y_pred, tf.zeros_like(y_pred))

        positive_term = K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1))
        negative_term = K.sum((1-alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
        return -positive_term-negative_term

    return focal_loss_fixed
def _conv_bn_lrelu(tensor, filters):
    """Apply a bias-free 3x3 'same' convolution, batch normalisation and LeakyReLU(0.1)."""
    tensor = Conv2D(filters, (3, 3), padding='same', use_bias=False)(tensor)
    tensor = BatchNormalization()(tensor)
    return LeakyReLU(alpha=0.1)(tensor)


def table_net(input_shape=(1152, 896, 3), num_classes=1):
    """Build the U-Net style segmentation model used for table-line detection.

    The encoder has six stages (16 to 512 filters), each made of two
    conv/BN/LeakyReLU units followed by 2x2 max pooling. The centre has two
    1024-filter units. The decoder mirrors the encoder: each stage upsamples
    by 2, concatenates the matching encoder output on the channel axis and
    applies three conv/BN/LeakyReLU units. A 1x1 sigmoid convolution produces
    ``num_classes`` output channels.

    Layers are instantiated in exactly the same order as in the unrolled
    original so that auto-generated layer names stay the same.

    :param input_shape: shape of one input sample (height, width, channels)
    :param num_classes: number of output channels of the final 1x1 convolution
    :return: an uncompiled ``Model``
    """
    encoder_filters = (16, 32, 64, 128, 256, 512)
    inputs = Input(shape=input_shape)

    # Encoder: remember the pre-pooling activations for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = _conv_bn_lrelu(x, filters)
        x = _conv_bn_lrelu(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

    # Centre
    x = _conv_bn_lrelu(x, 1024)
    x = _conv_bn_lrelu(x, 1024)

    # Decoder: deepest skip first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = concatenate([skip, x], axis=3)
        for _ in range(3):
            x = _conv_bn_lrelu(x, filters)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=classify)
- model = table_net((None, None, 3), 2)  # module-level model built at import time: variable height/width, 3 input channels, 2 output channels
def draw_pixel(pred, prob=0.2, is_test=1):
    """Show the prediction as a colour image (debug helper).

    A pixel whose channel 0 exceeds ``prob`` is drawn as (0, 0, 255); otherwise
    one whose channel 1 exceeds ``prob`` is drawn as (255, 0, 0); everything
    else is white.

    :param pred: nested sequence / array indexed as pred[row][col][channel]
    :param prob: score threshold
    :param is_test: when falsy the function returns immediately without plotting
    :return: None
    """
    if not is_test:
        return

    # Imported lazily so matplotlib is only needed when plotting is requested.
    import matplotlib.pyplot as plt

    def colour_of(scores):
        if scores[0] > prob:
            return (0, 0, 255)
        if scores[1] > prob:
            return (255, 0, 0)
        return (255, 255, 255)

    canvas = [[colour_of(scores) for scores in row] for row in pred]
    plt.axis('off')
    plt.imshow(np.array(canvas))
    plt.show()
    return
def _grow_bbox(bbox, point, margin):
    """Return bbox enlarged to contain point if point lies within margin of it, else None."""
    x, y = point
    if (bbox[0] - margin <= x <= bbox[2] + margin
            and bbox[1] - margin <= y <= bbox[3] + margin):
        return [min(x, bbox[0]), min(y, bbox[1]), max(x, bbox[2]), max(y, bbox[3])]
    return None


def _cluster_point(clusters, point, margin):
    """Attach point to the most recently created cluster it touches, or start a new cluster."""
    for cluster in reversed(clusters):
        grown = _grow_bbox(cluster["bbox"], point, margin)
        if grown is not None:
            cluster["points"].append(point)
            cluster["bbox"] = grown
            return
    x, y = point
    clusters.append({"points": [point], "bbox": [x, y, x, y]})


def _merge_segments(segments, vertical, line_width, cell_width):
    """Join segments that share a track (within line_width) and nearly touch end to end."""
    # Index of the constant coordinate and of the start/end coordinates.
    fixed, lo, hi = (0, 1, 3) if vertical else (1, 0, 2)
    merged = []
    for seg in segments:
        for idx, other in enumerate(merged):
            same_track = abs(seg[fixed] - other[fixed]) < line_width
            touching = (abs(seg[lo] - other[hi]) < cell_width
                        or abs(seg[hi] - other[lo]) < cell_width)
            if same_track and touching:
                joined = list(seg)
                joined[lo] = min(seg[lo], other[lo])
                joined[hi] = max(seg[hi], other[hi])
                # Replace the absorbed segment instead of keeping both.
                merged[idx] = joined
                break
        else:
            merged.append(seg)
    return merged


def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_len=10,
                 cell_width=13):
    """Convert a two-channel score map into axis-aligned line segments.

    The map is sampled every 5th pixel (first sample at index 4). Samples whose
    channel-1 score exceeds ``prob`` are clustered into vertical lines, and
    samples whose channel-0 score exceeds ``prob`` into horizontal lines.
    Columns (rows) with fewer than ``min_len`` pixels above ``prob`` are
    skipped. Each cluster becomes one segment through the centre of its
    bounding box, extended by ``padding`` at both ends, and nearby collinear
    segments are merged.

    Fixes over the previous version:
    * the merge test is ``same track and (end touches start or start touches
      end)``; the old un-parenthesised ``a and b or c`` merged segments on
      unrelated tracks;
    * a merged segment replaces the one it absorbed instead of being added
      next to it;
    * an empty ``pred`` returns ``[]`` instead of raising ``IndexError``.

    :param pred: numpy array indexed as pred[row, col, channel] with 2 channels
    :param sourceP_LB: when True, y is reported as ``height - 1 - row``
        (presumably a bottom-left origin -- confirm with callers)
    :param prob: score threshold
    :param line_width: clustering margin and maximum track offset for merging
    :param padding: amount each segment is extended at both ends
    :param min_len: minimum pixel count per column/row and minimum bbox extent
    :param cell_width: maximum end-to-end gap between segments to merge
    :return: list of ``[x0, y0, x1, y1]``, vertical segments first
    """
    started = time.time()
    height = len(pred)
    if height == 0:
        return []
    width = len(pred[0])
    stride = 5
    first = stride - 1

    def to_y(row_index):
        return height - 1 - row_index if sourceP_LB else row_index

    # Vertical lines: scan row by row so each column cluster grows along y.
    col_hits = np.sum((pred[..., 1] > prob).astype(int), axis=0)
    clust_vertical = []
    for h_index in range(first, height, stride):
        y = to_y(h_index)
        for w_index in range(first, width, stride):
            if col_hits[w_index] < min_len:
                continue
            _, v_score = pred[h_index][w_index]
            if v_score > prob:
                _cluster_point(clust_vertical, (w_index, y), line_width)

    # Horizontal lines: scan column by column so each row cluster grows along x.
    row_hits = np.sum((pred[..., 0] > prob).astype(int), axis=1)
    clust_horizontal = []
    for w_index in range(first, width, stride):
        for h_index in range(first, height, stride):
            if row_hits[h_index] < min_len:
                continue
            h_score, _ = pred[h_index][w_index]
            if h_score > prob:
                _cluster_point(clust_horizontal, (w_index, to_y(h_index)), line_width)

    def big_enough(bbox):
        return bbox[2] - bbox[0] >= min_len or bbox[3] - bbox[1] >= min_len

    tmp_vertical = []
    for cluster in clust_vertical:
        box = cluster["bbox"]
        if big_enough(box):
            centre_x = (box[0] + box[2]) / 2
            tmp_vertical.append([centre_x, box[1] - padding, centre_x, box[3] + padding])

    tmp_horizontal = []
    for cluster in clust_horizontal:
        box = cluster["bbox"]
        if big_enough(box):
            centre_y = (box[1] + box[3]) / 2
            tmp_horizontal.append([box[0] - padding, centre_y, box[2] + padding, centre_y])

    # Merge lines
    tmp_vertical.sort(key=lambda seg: seg[3], reverse=True)
    tmp_horizontal.sort(key=lambda seg: seg[0])
    final_vertical = _merge_segments(tmp_vertical, True, line_width, cell_width)
    final_horizontal = _merge_segments(tmp_horizontal, False, line_width, cell_width)

    list_line = final_vertical + final_horizontal
    logging.info("points2lines cost %.2fs", time.time() - started)
    return list_line
def get_line_from_binary_image(image_np, point_value=1, axis=0):
    """Turn runs of set pixels in a 2-D image into two-point line segments.

    Pixels with value >= ``point_value`` are grouped by row (``axis=0``) or by
    column (``axis=1``); every maximal run of consecutive pixels inside a
    group becomes one segment. Only horizontal and vertical lines are
    supported, and the image is expected to be binarised.

    Fixes over the previous version: the first pixel of every row/column after
    the first is no longer dropped, no spurious one-pixel segment is emitted
    for the first pixel of a group, and an image without matching pixels
    returns ``[]`` instead of raising ``IndexError``.

    :param image_np: 2-D numpy image
    :param point_value: minimum pixel value that counts as part of a line
    :param axis: 0 to extract runs along x within each row, 1 to extract runs
        along y within each column
    :return: list of ``[x_start, y_start, x_end, y_end]``
    """
    group = 1 - axis  # coordinate shared by all pixels of one row/column
    ys, xs = np.where(image_np >= point_value)
    points = sorted(zip(xs.tolist(), ys.tolist()),
                    key=lambda p: (p[group], p[axis]))

    lines = []
    run_start = run_end = None
    for point in points:
        continues_run = (
            run_end is not None
            and point[group] == run_end[group]
            and point[axis] - run_end[axis] == 1
        )
        if continues_run:
            run_end = point
            continue
        if run_start is not None:
            lines.append([run_start[0], run_start[1], run_end[0], run_end[1]])
        run_start = run_end = point

    # Close the last open run.
    if run_start is not None:
        lines.append([run_start[0], run_start[1], run_end[0], run_end[1]])
    return lines
- def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):  # resize img to size, run model.predict, turn the score map into lines and post-process them; returns rows + cols, or [] when no cross points are found
- logging.info("into table_line, prob is " + str(prob))
- sizew, sizeh = size
- img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
- pred = model.predict(np.array([img_new]))
- pred = pred[0]
- draw_pixel(pred, prob, is_test)
- # horizontal-line prediction
- # row_pred = pred[..., 0] > hprob
- # row_pred = row_pred.astype(np.uint8)
- # # vertical-line prediction
- # col_pred = pred[..., 1] > vprob
- # col_pred = col_pred.astype(np.uint8)
- # # show the model output
- # cv2.imshow("predict", (col_pred+row_pred)*255)
- # cv2.waitKey(0)
- _time = time.time()  # NOTE(review): _time is not read again in this function
- list_line = points2lines(pred, False, prob=prob)
- mat_plot(list_line, "points2lines", is_test)
- # drop short lines
- # print(img_new.shape)
- list_line = delete_short_lines(list_line, img_new.shape)
- mat_plot(list_line, "delete_short_lines", is_test)
- # drop lines that cross no other line
- list_line = delete_no_cross_lines(list_line)
- mat_plot(list_line, "delete_no_cross_lines", is_test)
- # split into horizontal (rows) and vertical (cols) lines
- list_rows = []
- list_cols = []
- for line in list_line:
- if line[0] == line[2]:
- list_cols.append(line)
- elif line[1] == line[3]:
- list_rows.append(line)
- # merge staggered lines
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- # compute cross points and split lines
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- if not cross_points:
- return []
- # remove useless outer lines
- # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
- # mat_plot(list_rows+list_cols, "delete_outline", is_test)
- # split lines between multiple tables
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- # repair the outline
- new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
- split_y)
- # if any lines were added
- if new_rows or new_cols:  # NOTE(review): indentation is lost in this copy, so the extent of this block cannot be confirmed here
- # extended lines connected to the added lines
- if long_rows:
- list_rows = long_rows
- if long_cols:
- list_cols = long_cols
- # newly added lines
- if new_rows:
- list_rows += new_rows
- if new_cols:
- list_cols += new_cols
- list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
- # recompute cross points and split lines after repairing the outline
- cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
- cv_plot(cross_points, img_new.shape, 0, is_test)
- split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
- print("fix new split_y", split_y)
- print("fix new split_lines", split_lines)
- # repair missing inner lines
- # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
- # if not cross_points:
- # return []
- mat_plot(list_rows+list_cols, "fix_outline", is_test)
- split_lines_show = []
- for _l in split_lines:
- split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
- mat_plot(split_lines_show+list_cols,
- "split_lines", is_test)
- # repair the four table corners
- list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
- mat_plot(list_rows+list_cols, "fix_corner", is_test)
- # repair missing inner lines
- list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
- mat_plot(list_rows+list_cols, "fix_inner", is_test)
- # merge staggered lines
- list_rows = merge_line(list_rows, axis=0)
- list_cols = merge_line(list_cols, axis=1)
- mat_plot(list_rows+list_cols, "merge_line", is_test)
- list_line = list_rows + list_cols
- # plot the processed lines
- mat_plot(list_line, "all", is_test)
- return list_line
def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
    """Detect row and column boxes with ``get_table_line`` on a thresholded prediction.

    The image is resized to ``size``, passed through ``model.predict`` and the
    two output channels are thresholded (channel 0 with ``hprob``, channel 1
    with ``vprob``). Row and column boxes are extracted, adjusted against each
    other with ``line_to_line`` and boxes within 5 pixels of the image border
    are discarded.

    NOTE(review): ``draw_pixel(pred)`` is called with its default ``is_test``
    and the result of ``points2lines(pred)`` is discarded; ``alph`` is unused.
    All of that is kept as in the original.

    :param img: input image accepted by ``cv2.resize``
    :param model: object exposing ``predict``
    :param size: (width, height) the image is resized to
    :param hprob: threshold for channel 0
    :param vprob: threshold for channel 1
    :param row: ``lineW`` passed to ``get_table_line`` for rows
    :param col: ``lineW`` passed to ``get_table_line`` for columns
    :param alph: unused
    :return: (rowboxes, colboxes, resized image)
    """
    target_w, target_h = size
    img_new = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_AREA)
    pred = model.predict(np.array([img_new]))[0]
    draw_pixel(pred)

    started = time.time()
    points2lines(pred)
    logging.info("points2lines takes %ds"%(time.time()-started))

    col_mask = (pred[..., 1] > vprob).astype(int)
    row_mask = (pred[..., 0] > hprob).astype(int)
    colboxes = get_table_line(col_mask, axis=1, lineW=col)
    rowboxes = get_table_line(row_mask, axis=0, lineW=row)

    # Adjust every row against every column and vice versa, in place.
    for r_idx in range(len(rowboxes)):
        for c_idx in range(len(colboxes)):
            rowboxes[r_idx] = line_to_line(rowboxes[r_idx], colboxes[c_idx], 10)
            colboxes[c_idx] = line_to_line(colboxes[c_idx], rowboxes[r_idx], 10)

    border = 5

    def strip_border(lines, pos, limit, order):
        """Drop lines hugging the border and order each kept line's end points."""
        kept = []
        for line in lines:
            if line[pos] <= border or limit - line[pos] <= border:
                continue
            if line[order] > line[order + 2]:
                line = [line[2], line[3], line[0], line[1]]
            kept.append(line)
        return kept

    rowboxes = strip_border(rowboxes, 1, size[1], 0)
    colboxes = strip_border(colboxes, 0, size[0], 1)
    return rowboxes, colboxes, img_new
def fix_in_split_lines(_rows, _cols, _img):
    """Pull lines that hug the top/bottom image edge 6 pixels inwards.

    Added lines that sit on the edge make it impossible to obtain ``split_y``
    and therefore to partition the tables (original author's note), so rows
    and column end points within 5 pixels of the top or bottom are moved.

    The lists and the lines inside them are modified in place and returned.
    The former debug ``print`` is now a ``logging.debug`` call, matching the
    logging used by the rest of the module.

    :param _rows: list of ``[x0, y0, x1, y1]`` horizontal lines
    :param _cols: list of ``[x0, y0, x1, y1]`` vertical lines
    :param _img: object whose ``shape[0]`` is the image height
    :return: ``(_rows, _cols)``
    """
    height = _img.shape[0]
    top, bottom = 6, height - 6

    for _row in _rows:
        if _row[1] >= height - 5:
            _row[1] = bottom
            _row[3] = bottom
            logging.debug("fix_in_split_lines moved row off the bottom edge: %s", _row)
        if _row[1] <= 5:
            _row[1] = top
            _row[3] = top

    for _col in _cols:
        if _col[3] >= height - 5:
            _col[3] = bottom
        if _col[1] <= 5:
            _col[1] = top

    return _rows, _cols
def mat_plot(list_line, name="", is_test=1):
    """Plot line segments in a new matplotlib figure (debug helper).

    :param list_line: iterable of ``(x0, y0, x1, y1)``
    :param name: figure title
    :param is_test: when falsy nothing is plotted
    :return: None
    """
    if not is_test:
        return

    # Imported lazily so matplotlib is only needed when plotting is requested.
    from matplotlib import pyplot as plt

    plt.figure()
    plt.title(name)
    for segment in list_line:
        start_x, start_y, end_x, end_y = segment
        plt.plot([start_x, end_x], [start_y, end_y])
    plt.show()
def cv_plot(_list, img_shape, line_or_point=1, is_test=1):
    """
    Debug helper: draw lines or points on a white canvas and show it.

    The call blocks in ``cv2.waitKey(0)`` until a key is pressed.

    :param _list: lines (x0, y0, x1, y1) or points (x, y), per ``line_or_point``
    :param img_shape: shape of the canvas to create
    :param line_or_point: truthy -> draw lines, falsy -> draw points
    :param is_test: when equal to 0 the function does nothing
    :return: None
    """
    if is_test == 0:
        return

    canvas = np.full(img_shape, 255, np.uint8)
    colour = (255, 0, 0)

    if line_or_point:
        for seg in _list:
            start = (int(seg[0]), int(seg[1]))
            end = (int(seg[2]), int(seg[3]))
            cv2.line(canvas, start, end, colour)
    else:
        for pt in _list:
            cv2.circle(canvas, (int(pt[0]), int(pt[1])), 1, colour, 2)

    cv2.imshow("cv_plot", canvas)
    cv2.waitKey(0)
def delete_no_cross_lines(list_lines):
    """
    Keep only the lines that cross at least one other line.

    Every pair (earlier, later) of the input is tested once.  For each
    earlier line its crossing partners are appended first, in input order,
    followed by the earlier line itself; duplicates are never appended twice.

    :param list_lines: list of lines [x0, y0, x1, y1]
    :return: new list containing the lines that take part in a crossing
    """
    def crosses(a, b):
        # Based on https://www.zhihu.com/question/381406535/answer/1095948349
        # A crossing always involves one horizontal and one vertical line.
        vertical_then_horizontal = a[0] == a[2] and b[1] == b[3]
        horizontal_then_vertical = a[1] == a[3] and b[0] == b[2]
        if not (vertical_then_horizontal or horizontal_then_vertical):
            return 0
        if a[0] <= b[0] <= a[2] and b[1] <= a[1] <= b[3]:
            return 1
        if b[0] <= a[0] <= b[2] and a[1] <= b[1] <= a[3]:
            return 1
        return 0

    kept = []
    for idx, first in enumerate(list_lines):
        partners = [other for other in list_lines[idx + 1:] if crosses(first, other)]
        for other in partners:
            if other not in kept:
                kept.append(other)
        if partners and first not in kept:
            kept.append(first)
    return kept
def delete_short_lines(list_lines, image_shape, scale=40):
    """
    Drop lines shorter than a minimum length derived from the image shape.

    The minimum for lines with differing x (treated as horizontal) is
    ``max(5, int(image_shape[0] / scale))``; for lines with equal x
    (vertical) it is ``max(5, int(image_shape[1] / scale))``.

    :param list_lines: list of lines [x0, y0, x1, y1]
    :param image_shape: indexable with at least two numeric entries
    :param scale: divisor applied to the shape entries
    :return: new list with the sufficiently long lines, in input order
    """
    x_min_len = max(5, int(image_shape[0] / scale))
    y_min_len = max(5, int(image_shape[1] / scale))

    def long_enough(seg):
        if seg[0] == seg[2]:
            return abs(seg[3] - seg[1]) >= y_min_len
        return abs(seg[2] - seg[0]) >= x_min_len

    return [seg for seg in list_lines if long_enough(seg)]
def get_outline(points, image_np):
    """
    Draw a black rectangle around the extent of ``points`` on an image copy.

    The x extent is read from the first and last entries of ``points`` as
    passed in (presumably already ordered by x - verify against callers);
    ``points`` is then sorted in place by (y, x) to read the y extent.

    :param points: non-empty list of (x, y) points; sorted in place
    :param image_np: image to copy and draw on
    :return: copy of ``image_np`` with the rectangle drawn, padded by 5 px
    """
    # x range relies on the incoming order, so read it before re-sorting.
    left = points[0][0]
    right = points[-1][0]

    points.sort(key=lambda p: (p[1], p[0]))
    top = points[0][1]
    bottom = points[-1][1]

    outline_img = np.copy(image_np)
    corner_a = (left - 5, top - 5)
    corner_b = (right + 5, bottom + 5)
    cv2.rectangle(outline_img, corner_a, corner_b, (0, 0, 0), 2)
    return outline_img
def get_split_line(points, col_lines, image_np, threshold=5):
    """
    Find the y coordinates that separate vertically independent regions.

    Points are walked in (y, x) order.  A vertical gap between two
    consecutive points becomes a split when no column line spans it (with a
    3 px tolerance); the split is recorded 5 px inside each side of the gap.
    The top and bottom of the point cloud (padded by ``threshold`` and
    clamped to the image height) are added as well, then candidates closer
    than 20 px to the previously kept one are dropped.

    :param points: non-empty list of (x, y) points; sorted in place by (y, x)
    :param col_lines: vertical lines [x0, y0, x1, y1]
    :param image_np: object whose ``shape`` gives (height, width, ...)
    :param threshold: tolerance around an existing split / padding of extent
    :return: (split_line, split_line_y) - horizontal segments
             [(0, y), (width, y)] sorted by y (always including y=0 and
             y=height), and the sorted list of kept split y values
    """
    points.sort(key=lambda p: (p[1], p[0]))

    split_line_y = []
    prev_y = None
    seen_first = False
    for pt in points:
        cur_y = pt[1]

        # Once a split exists, only start judging again below it.
        if split_line_y:
            limit = split_line_y[-1] + threshold
            if cur_y <= limit or prev_y <= limit:
                prev_y = cur_y
                continue

        if not seen_first:
            prev_y = cur_y
            seen_first = True
            continue

        # A single column covering the gap means it is not a split.
        covered = any(
            prev_y >= col[1] - 3 and cur_y <= col[3] + 3
            for col in col_lines
        )
        if not covered:
            split_line_y.append(prev_y + 5)
            split_line_y.append(cur_y - 5)

        prev_y = cur_y

    # Add the closing splits at the top and bottom of the point cloud.
    y_min = points[0][1]
    y_max = points[-1][1]
    img_h = image_np.shape[0]
    img_w = image_np.shape[1]
    split_line_y.append(0 if y_min - threshold < 0 else y_min - threshold)
    split_line_y.append(img_h if y_max + threshold > img_h else y_max + threshold)

    # Remove splits that are too close to the previously kept one.
    kept = []
    last_kept = -20
    for y in sorted(set(split_line_y)):
        if y - last_kept >= 20:
            kept.append(y)
            last_kept = y

    # Build the split segments, always including the image top and bottom.
    split_line = [[(0, y), (img_w, y)] for y in kept]
    split_line.append([(0, 0), (img_w, 0)])
    split_line.append([(0, img_h), (img_w, img_h)])
    split_line.sort(key=lambda seg: seg[0][1])

    return split_line, kept
def get_points(row_lines, col_lines, image_size):
    """
    Compute the intersection pixels of row lines and column lines.

    Rows and columns are rasterised onto two separate blank images (each
    line lengthened by 3 px at both ends along its own direction) and the
    images are AND-ed; every non-zero pixel is returned as a point.

    :param row_lines: horizontal lines [x0, y0, x1, y1]
    :param col_lines: vertical lines [x0, y0, x1, y1]
    :param image_size: shape of the scratch images; the two-way unpacking of
                       ``np.where`` below requires it to be 2-D
    :return: list of (x, y) tuples sorted by (x, y)
    """
    thresh = 3
    white = (255, 255, 255)

    # Blank scratch images, one per direction.
    row_img = np.zeros(image_size, np.uint8)
    col_img = np.zeros(image_size, np.uint8)

    # Draw the lines, slightly lengthened.
    for row in row_lines:
        start = (int(row[0] - thresh), int(row[1]))
        end = (int(row[2] + thresh), int(row[3]))
        cv2.line(row_img, start, end, white, 1)
    for col in col_lines:
        start = (int(col[0]), int(col[1] - thresh))
        end = (int(col[2]), int(col[3] + thresh))
        cv2.line(col_img, start, end, white, 1)

    # Pixels set in both images are the intersections.
    point_img = np.bitwise_and(row_img, col_img)
    ys, xs = np.where(point_img > 0)

    points = list(zip(xs, ys))
    points.sort(key=lambda p: (p[0], p[1]))
    return points
def get_minAreaRect(image):
    """
    Return ``cv2.minAreaRect`` of the foreground pixels of ``image``.

    The image is converted with COLOR_BGR2GRAY, inverted and Otsu-
    thresholded; the (row, col) indices of all non-zero pixels are passed
    to ``cv2.minAreaRect``.

    :param image: colour image accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
    :return: the result of ``cv2.minAreaRect``
    """
    inverted = cv2.bitwise_not(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    _, binary = cv2.threshold(inverted, 0, 255,
                              cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    coords = np.column_stack(np.nonzero(binary > 0))
    return cv2.minAreaRect(coords)
def get_contours(image):
    """
    Debug helper: draw the external contours of ``image`` onto it and show it.

    ``image`` is modified in place by ``cv2.drawContours``; the call then
    blocks in ``cv2.waitKey(0)`` until a key is pressed.

    :param image: colour image accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
    :return: None
    """
    grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(grey, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(image, contours, -1, (0, 0, 255), 3)
    cv2.imshow("get contours", image)
    cv2.waitKey(0)
def merge_line(lines, axis, threshold=5):
    """
    Merge fragments of one straight line that the model predicted as several
    slightly offset lines.

    Two lines are grouped when their cross-axis coordinates differ by at
    most ``threshold`` and the candidate's start or end falls inside the
    span already covered by the group.  Each group collapses to one line
    positioned at the integer mean of the members' cross-axis coordinate
    and spanning from the smallest start to the largest end.

    :param lines: list of lines [x0, y0, x1, y1]; sorted in place
    :param axis: 0 for horizontal lines, 1 for vertical lines
    :param threshold: maximum cross-axis pixel offset between merged lines
    :return: list of merged lines
    """
    cross = 1 - axis

    # For any line, look for partners further along: rows downward,
    # columns to the right.
    lines.sort(key=lambda x: (x[axis], x[cross]))

    groups = []
    # Membership is by equality, so exact duplicates of a consumed line are
    # skipped as well.
    used_lines = []
    for line1 in lines:
        if line1 in used_lines:
            continue

        group = [line1]
        used_lines.append(line1)
        for line2 in lines:
            if line2 in used_lines:
                continue
            if not line1[cross] - threshold <= line2[cross] <= line1[cross] + threshold:
                continue

            # Span currently covered by the group.  Real min/max are used
            # rather than fixed sentinels so the overlap test stays correct
            # for any coordinate range.
            span_start = min(member[axis] for member in group)
            span_end = max(member[axis + 2] for member in group)

            # Does the candidate overlap the group's span?
            if span_start <= line2[axis] <= span_end \
                    or span_start <= line2[axis + 2] <= span_end:
                group.append(line2)
                used_lines.append(line2)
        groups.append(group)

    # Collapse every group into a single line.
    result_lines = []
    for group in groups:
        # Mean cross-axis position of the members.
        axis_average = int(sum(member[cross] for member in group) / len(group))
        # Extreme ends of the group.
        axis_start = min(member[axis] for member in group)
        axis_end = max(member[axis + 2] for member in group)

        if axis:
            result_lines.append([axis_average, axis_start, axis_average, axis_end])
        else:
            result_lines.append([axis_start, axis_average, axis_end, axis_average])
    return result_lines
def fix_inner2(row_points, col_points, row_lines, col_lines, threshold=3):
    """
    Stretch inner lines that stop short of a neighbouring intersection.

    For every point, its successor in the same row, its successor in the
    same column and the diagonal point (successor in the row of the column
    successor) are looked up.  A row/column line lying on the point (within
    ``threshold``) that ends or starts inside the cell, and already covers
    at least one third of it, is extended to the cell boundary.  Lines and
    the point lists are modified in place.

    :param row_points: list of lists of (x, y) points, one list per row
    :param col_points: list of lists of (x, y) points, one list per column
    :param row_lines: horizontal lines as mutable [x0, y0, x1, y1] lists
    :param col_lines: vertical lines as mutable [x0, y0, x1, y1] lists
    :param threshold: tolerance when matching a point to a line
    :return: (row_lines, col_lines), the same objects after adjustment
    """
    for row in row_points:
        row.sort(key=lambda p: (p[1], p[0]))
        last_index = len(row) - 1
        for j in range(len(row)):
            # Current point.
            point = row[j]

            # Next point in the same row, if any.
            next_row_point = row[j + 1] if j < last_index else []

            if next_row_point:
                cell_width = next_row_point[0] - point[0]
                for line in row_lines:
                    if not line[1] - threshold <= point[1] <= line[1] + threshold:
                        continue
                    if line[0] <= point[0] <= next_row_point[0] <= line[2]:
                        # The line already spans the whole cell.
                        continue
                    if point[0] <= line[2] < next_row_point[0]:
                        if line[2] - point[0] >= 1 / 3 * cell_width:
                            line[2] = next_row_point[0]
                    if point[0] < line[0] <= next_row_point[0]:
                        if next_row_point[0] - line[0] >= 1 / 3 * cell_width:
                            line[0] = point[0]

            # Next point in the same column, if any.
            next_col_point = []
            for col in col_points:
                if point not in col:
                    continue
                col.sort(key=lambda p: (p[0], p[1]))
                position = col.index(point)
                if position < len(col) - 1:
                    next_col_point = col[position + 1]
                break

            # Diagonal point: successor, within its own row, of the column
            # successor.
            next_row_next_col_point = []
            if next_col_point:
                for row2 in row_points:
                    if next_col_point not in row2:
                        continue
                    row2.sort(key=lambda p: (p[1], p[0]))
                    position = row2.index(next_col_point)
                    if position < len(row2) - 1:
                        next_row_next_col_point = row2[position + 1]
                    break

                # Column successor exists but has no successor in its row:
                # synthesise the diagonal from the row successor, if present.
                if not next_row_next_col_point and next_row_point:
                    next_row_next_col_point = [next_row_point[0], next_col_point[1]]

            if next_col_point:
                cell_height = next_col_point[1] - point[1]
                for line in col_lines:
                    if not line[0] - threshold <= point[0] <= line[0] + threshold:
                        continue
                    if line[1] <= point[1] <= next_col_point[1] <= line[3]:
                        continue
                    if point[1] <= line[3] < next_col_point[1]:
                        if line[3] - point[1] >= 1 / 3 * cell_height:
                            line[3] = next_col_point[1]
                    if point[1] < line[1] <= next_col_point[1]:
                        if next_col_point[1] - line[1] >= 1 / 3 * cell_height:
                            line[1] = point[1]

            if next_row_next_col_point:
                diagonal = next_row_next_col_point
                for line in col_lines:
                    if not line[0] - threshold <= diagonal[0] <= line[0] + threshold:
                        continue
                    if line[1] <= point[1] <= diagonal[1] <= line[3]:
                        continue
                    if point[1] < line[1] <= diagonal[1]:
                        if diagonal[1] - line[1] >= 1 / 3 * (diagonal[1] - point[1]):
                            line[1] = point[1]
    return row_lines, col_lines
def fix_inner1(row_lines, col_lines, points, split_y):
    """
    Propose extra intersection points for lines that stop short of a
    neighbouring perpendicular line.

    The lines and points are partitioned by consecutive ``split_y`` pairs.
    Within each partition, for every endpoint of every line the nearest
    perpendicular ("assist") line that does not already cross the line is
    located; if the dangling part of the line measures between one third
    and the whole of the distance from its nearest real intersection to the
    assist line, the point where the line would meet the assist line is
    added.

    :param row_lines: horizontal lines [x0, y0, x1, y1]
    :param col_lines: vertical lines [x0, y0, x1, y1]
    :param points: list of (x, y) intersection points
    :param split_y: ascending y coordinates delimiting the partitions
    :return: new list, ``points`` followed by the proposed (x, y) points
    """
    def fix(fix_lines, assist_lines, split_points, axis):
        # axis is the running direction of fix_lines: 0 for rows (x),
        # 1 for columns (y); cross is the perpendicular coordinate index.
        cross = 1 - axis
        new_points = []
        for line1 in fix_lines:
            min_assist_line = [[], []]
            min_distance = [1000, 1000]
            if_find = [0, 0]

            # Intersections lying on line1.  The two endpoints need not be
            # among them: an endpoint is not necessarily an intersection.
            fix_line_points = [
                point for point in split_points
                if abs(point[cross] - line1[cross]) <= 2
                and line1[axis] <= point[axis] <= line1[axis + 2]
            ]

            # For each endpoint find the nearest assist line that does not
            # cross line1.
            line1_point = [line1[:2], line1[2:]]
            for i in range(2):
                point = line1_point[i]
                for line2 in assist_lines:
                    if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
                        # The endpoint already sits on an assist line.
                        if line1[cross] <= point[cross] <= line2[cross + 2]:
                            if_find[i] = 1
                            break
                    else:
                        if abs(point[axis] - line2[axis]) < min_distance[i] \
                                and line2[cross] <= point[cross] <= line2[cross + 2]:
                            if line1[axis] <= line2[axis] <= line1[axis + 2]:
                                # line2 crosses line1 itself; not a candidate.
                                continue
                            # NOTE(review): the distance stored is measured
                            # from line1's start even for the end vertex -
                            # kept as in the original, confirm intent.
                            min_distance[i] = abs(line1[axis] - line2[axis])
                            min_assist_line[i] = line2

            # The dangling part (endpoint to nearest intersection) must be
            # at least 1/3 of, and at most, the cell edge (intersection to
            # assist line).
            for i in range(2):
                assist = min_assist_line[i]
                if not assist:
                    continue

                # Intersection on line1 closest to the assist line.
                nearest = []
                nearest_gap = 1000
                for point in fix_line_points:
                    gap = abs(point[axis] - assist[axis])
                    if gap < nearest_gap:
                        nearest_gap = gap
                        nearest = point
                if not nearest:
                    continue

                bbox_len = abs(nearest[axis] - assist[axis])
                line_distance = abs(nearest[axis] - line1_point[i][axis])
                if bbox_len / 3 <= line_distance <= bbox_len:
                    # Always emit the point as (x, y), whichever direction
                    # is being fixed.
                    if axis == 1:
                        add_point = (line1_point[i][cross], assist[axis])
                    else:
                        add_point = (assist[axis], line1_point[i][cross])
                    new_points.append(add_point)
        return new_points

    new_points = []
    for last_y, y in zip(split_y, split_y[1:]):
        # Partition lines and points by the current split band first.
        split_row_lines = [row for row in row_lines if last_y <= row[1] <= y]
        split_col_lines = [col for col in col_lines if last_y <= col[1] <= y]
        split_points = [point for point in points if last_y <= point[1] <= y]

        new_points += fix(split_col_lines, split_row_lines, split_points, axis=1)
        new_points += fix(split_row_lines, split_col_lines, split_points, axis=0)

    return points + new_points
def fix_inner(row_lines, col_lines, points, split_y):
    """
    Extend inner lines that stop short of a neighbouring perpendicular line.

    The lines and points are partitioned by consecutive ``split_y`` pairs.
    Within each partition, for every endpoint of every line the nearest
    perpendicular ("assist") line that does not already cross the line is
    located; if the dangling part of the line measures between one third
    and the whole of the distance from its nearest real intersection to the
    assist line, the line is extended to meet the assist line.  Extended
    lines replace the originals inside ``row_lines`` / ``col_lines``.

    :param row_lines: horizontal lines [x0, y0, x1, y1]; updated in place
    :param col_lines: vertical lines [x0, y0, x1, y1]; updated in place
    :param points: list of (x, y) intersection points
    :param split_y: ascending y coordinates delimiting the partitions
    :return: (row_lines, col_lines), the same list objects
    """
    def fix(fix_lines, assist_lines, split_points, axis):
        # axis is the running direction of fix_lines: 0 for rows (x),
        # 1 for columns (y); cross is the perpendicular coordinate index.
        cross = 1 - axis
        new_points = []
        for line1 in fix_lines:
            min_assist_line = [[], []]
            min_distance = [1000, 1000]
            if_find = [0, 0]

            # Intersections lying on line1.  The two endpoints need not be
            # among them: an endpoint is not necessarily an intersection.
            fix_line_points = [
                point for point in split_points
                if abs(point[cross] - line1[cross]) <= 2
                and line1[axis] <= point[axis] <= line1[axis + 2]
            ]

            # For each endpoint find the nearest assist line that does not
            # cross line1.
            line1_point = [line1[:2], line1[2:]]
            for i in range(2):
                point = line1_point[i]
                for line2 in assist_lines:
                    if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
                        # The endpoint already sits on an assist line.
                        if line1[cross] <= point[cross] <= line2[cross + 2]:
                            if_find[i] = 1
                            break
                    else:
                        if abs(point[axis] - line2[axis]) < min_distance[i] \
                                and line2[cross] <= point[cross] <= line2[cross + 2]:
                            if line1[axis] <= line2[axis] <= line1[axis + 2]:
                                # line2 crosses line1 itself; not a candidate.
                                continue
                            # NOTE(review): the distance stored is measured
                            # from line1's start even for the end vertex -
                            # kept as in the original, confirm intent.
                            min_distance[i] = abs(line1[axis] - line2[axis])
                            min_assist_line[i] = line2

            # Intersection on line1 closest to each endpoint's assist line.
            min_distance = [1000, 1000]
            min_col_point = [[], []]
            for i in range(2):
                if min_assist_line[i]:
                    for point in fix_line_points:
                        gap = abs(point[axis] - min_assist_line[i][axis])
                        if gap < min_distance[i]:
                            min_distance[i] = gap
                            min_col_point[i] = point

            # Decide which endpoints may be extended.  When both endpoints
            # picked the same assist line, only the endpoint facing that
            # line is a candidate.
            if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]:
                if min_assist_line[0][axis] < line1_point[0][axis]:
                    candidates = [0]
                elif min_assist_line[1][axis] > line1_point[1][axis]:
                    candidates = [1]
                else:
                    candidates = []
            else:
                candidates = [0, 1]

            for i in candidates:
                # No intersection on line1 (or no assist line): nothing to
                # measure against, so skip instead of indexing an empty list.
                if not min_col_point[i]:
                    continue
                # The dangling part (endpoint to nearest intersection) must
                # be at least 1/3 of, and at most, the cell edge
                # (intersection to assist line).
                bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis])
                line_distance = abs(min_col_point[i][axis] - line1_point[i][axis])
                if bbox_len / 3 <= line_distance <= bbox_len:
                    if axis == 1:
                        add_point = (line1_point[i][cross], min_assist_line[i][axis])
                    else:
                        add_point = (min_assist_line[i][axis], line1_point[i][cross])
                    new_points.append([line1, add_point])
        return new_points

    def extend(lines, line, new_point, axis):
        # Replace `line` inside `lines` by a copy stretched to `new_point`.
        if line not in lines:
            return
        index = lines.index(line)
        start, end = line[:2], line[2:]
        if new_point[axis] >= end[axis]:
            lines[index] = [start[0], start[1], new_point[0], new_point[1]]
        elif new_point[axis] <= start[axis]:
            lines[index] = [new_point[0], new_point[1], end[0], end[1]]

    for last_y, y in zip(split_y, split_y[1:]):
        # Partition lines and points by the current split band first.
        split_row_lines = [row for row in row_lines if last_y <= row[1] <= y]
        split_col_lines = [col for col in col_lines if last_y <= col[1] <= y]
        split_points = [point for point in points if last_y <= point[1] <= y]

        for line, new_point in fix(split_col_lines, split_row_lines, split_points, axis=1):
            extend(col_lines, line, new_point, 1)

        for line, new_point in fix(split_row_lines, split_col_lines, split_points, axis=0):
            extend(row_lines, line, new_point, 0)

    return row_lines, col_lines
def fix_corner(row_lines, col_lines, split_y, threshold=0):
    """
    Close the top-left and top-right corners of every split region.

    For each band between consecutive distinct ``split_y`` values, the
    lines having an endpoint y inside the band are collected.  The topmost
    row is stretched horizontally to the leftmost/rightmost column, and
    those columns get their start y set to the top row's y, whenever the
    corresponding endpoints do not already coincide.  Line lists from the
    input are modified in place.

    :param row_lines: horizontal lines as mutable [x0, y0, x1, y1] lists
    :param col_lines: vertical lines as mutable [x0, y0, x1, y1] lists
    :param split_y: y coordinates delimiting the regions
    :param threshold: tolerance added on both sides of each band
    :return: (new_row_lines, new_col_lines) - concatenation of the lines of
             every band that has both rows and columns; a line touching
             several bands appears once per band
    """
    new_row_lines = []
    new_col_lines = []
    # Without any split coordinate there is no region to process.
    if len(split_y) == 0:
        return new_row_lines, new_col_lines

    last_y = split_y[0]
    for y in split_y:
        if y == last_y:
            continue

        low = last_y - threshold
        high = y + threshold
        split_row_lines = [
            row for row in row_lines
            if low <= row[1] <= high or low <= row[3] <= high
        ]
        # fix corner easily misses lines because of the split lines, hence
        # both endpoints are tested.
        split_col_lines = [
            col for col in col_lines
            if low <= col[1] <= high or low <= col[3] <= high
        ]

        if split_row_lines and split_col_lines:
            split_row_lines.sort(key=lambda x: (x[1], x[0]))
            split_col_lines.sort(key=lambda x: (x[0], x[1]))

            up_line = split_row_lines[0]
            left_line = split_col_lines[0]
            right_line = split_col_lines[-1]

            # Top-left corner.
            if up_line[0:2] != left_line[0:2]:
                up_line[0] = left_line[0]
                left_line[1] = up_line[1]

            # Top-right corner.
            if up_line[2:] != right_line[:2]:
                up_line[2] = right_line[0]
                right_line[1] = up_line[1]

            new_row_lines = new_row_lines + split_row_lines
            new_col_lines = new_col_lines + split_col_lines

        last_y = y

    return new_row_lines, new_col_lines
def delete_outline(row_lines, col_lines, points):
    """
    Remove outer lines that carry fewer than three intersection points.

    After sorting, the first and last row and the first and last column are
    taken as the outline.  A point counts for a row when its y is within
    2 px of the row and its x lies inside the row's span (and likewise for
    columns with x/y swapped).  Outline lines with fewer than three such
    points are removed.  Both lists are sorted and modified in place; an
    empty list simply contributes no outline.

    :param row_lines: horizontal lines [x0, y0, x1, y1]
    :param col_lines: vertical lines [x0, y0, x1, y1]
    :param points: iterable of (x, y) intersection points
    :return: (row_lines, col_lines), the same list objects
    """
    row_lines.sort(key=lambda x: (x[1], x[0]))
    col_lines.sort(key=lambda x: (x[0], x[1]))

    threshold = 2
    # An outline line must contain at least this many intersections.
    min_points = 3

    # `fixed` is the index of the coordinate that is constant along the
    # line: y (1) for rows, x (0) for columns.
    for lines, fixed in ((row_lines, 1), (col_lines, 0)):
        if not lines:
            continue
        span = 1 - fixed
        # Both outline lines are captured before anything is removed.
        for outline in (lines[0], lines[-1]):
            count = sum(
                1 for point in points
                if outline[fixed] - threshold <= point[fixed] <= outline[fixed] + threshold
                and outline[span] <= point[span] <= outline[span + 2]
            )
            # The membership test also covers the case where first and last
            # are the same line and it was already removed.
            if count < min_points and outline in lines:
                lines.remove(outline)

    return row_lines, col_lines
- def fix_outline2(image, row_lines, col_lines, points, split_y):
- print("split_y", split_y)
- # 分割线纵坐标
- if len(split_y) < 2:
- return [], [], [], []
- # elif len(split_y) == 2:
- # split_y = [2000., 2000., 2000., 2000.]
- split_y.sort(key=lambda x: x)
- new_split_y = []
- for i in range(1, len(split_y), 2):
- new_split_y.append(int((split_y[i]+split_y[i-1])/2))
- # # 查看是否正确输出区域分割线
- # for line in split_y:
- # cv2.line(image, (0, int(line)), (int(image.shape[1]), int(line)), (0, 0, 255), 2)
- # cv2.imshow("split_y", image)
- # cv2.waitKey(0)
- # 预测线根据分割线纵坐标分为多个分割区域
- # row_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0]))
- # col_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0]))
- # points.sort(key=lambda x: (x[1], x[0]))
- # row_count = 0
- # col_count = 0
- # point_count = 0
- split_row_list = []
- split_col_list = []
- split_point_list = []
- # for i in range(1, len(split_y)):
- # y = split_y[i]
- # last_y = split_y[i-1]
- # row_lines = row_lines[row_count:]
- # col_lines = col_lines[col_count:]
- # points = points[point_count:]
- # row_count = 0
- # col_count = 0
- # point_count = 0
- #
- # if not row_lines:
- # split_row_list.append([])
- # for row in row_lines:
- # if last_y <= row[3] <= y:
- # row_count += 1
- # else:
- # split_row_list.append(row_lines[:row_count])
- # break
- # if row_count == len(row_lines):
- # split_row_list.append(row_lines[:row_count])
- # break
- #
- # if not col_lines:
- # split_col_list.append([])
- #
- # for col in col_lines:
- # # if last_y <= col[3] <= y:
- # if col[1] <= last_y <= y <= col[3] or last_y <= col[3] <= y:
- # # if last_y <= col[1] <= y or last_y <= col[3] <= y:
- # col_count += 1
- # else:
- # split_col_list.append(col_lines[:col_count])
- # break
- # if col_count == len(col_lines):
- # split_col_list.append(col_lines[:col_count])
- # break
- #
- # if not points:
- # split_point_list.append([])
- # for point in points:
- # if last_y <= point[1] <= y:
- # point_count += 1
- # else:
- # split_point_list.append(points[:point_count])
- # break
- # if point_count == len(points):
- # split_point_list.append(points[:point_count])
- # break
- #
- # # print("len(split_row_list)", len(split_row_list))
- # # print("len(split_col_list)", len(split_col_list))
- # if row_count < len(row_lines) - 1 and col_count < len(col_lines) - 1:
- # row_lines = row_lines[row_count:]
- # split_row_list.append(row_lines)
- # col_lines = col_lines[col_count:]
- # split_col_list.append(col_lines)
- #
- # if point_count < len(points) - 1:
- # points = points[point_count:len(points)]
- # split_point_list.append(points)
- for i in range(1, len(split_y)):
- y = split_y[i]
- last_y = split_y[i-1]
- split_row = []
- for row in row_lines:
- if last_y <= row[3] <= y:
- split_row.append(row)
- split_row_list.append(split_row)
- split_col = []
- for col in col_lines:
- if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]:
- split_col.append(col)
- split_col_list.append(split_col)
- split_point = []
- for point in points:
- if last_y <= point[1] <= y:
- split_point.append(point)
- split_point_list.append(split_point)
- # 预测线取上下左右4个边(会有超出表格部分) [(), ()]
- area_row_line = []
- area_col_line = []
- for area in split_row_list:
- if not area:
- area_row_line.append([])
- continue
- area.sort(key=lambda x: (x[1], x[0]))
- up_line = area[0]
- bottom_line = area[-1]
- area_row_line.append([up_line, bottom_line])
- for area in split_col_list:
- if not area:
- area_col_line.append([])
- continue
- area.sort(key=lambda x: x[0])
- left_line = area[0]
- right_line = area[-1]
- area_col_line.append([left_line, right_line])
- # 线交点根据分割线纵坐标分为多个分割区域
- # points.sort(key=lambda x: (x[1], x[0]))
- # point_count = 0
- # split_point_list = []
- # for y in new_split_y:
- # points = points[point_count:len(points)]
- # point_count = 0
- # for point in points:
- # if point[1] <= y:
- # point_count += 1
- # else:
- # split_point_list.append(points[:point_count])
- # break
- # if point_count == len(points):
- # split_point_list.append(points[:point_count])
- # break
- # if point_count < len(points) - 1:
- # points = points[point_count:len(points)]
- # split_point_list.append(points)
- # print("len(split_point_list)", len(split_point_list))
- # 取每个分割区域的4条线(无超出表格部分)
- area_row_line2 = []
- area_col_line2 = []
- for area in split_point_list:
- if not area:
- area_row_line2.append([])
- area_col_line2.append([])
- continue
- area.sort(key=lambda x: (x[0], x[1]))
- left_up = area[0]
- right_bottom = area[-1]
- up_line = [left_up[0], left_up[1], right_bottom[0], left_up[1]]
- bottom_line = [left_up[0], right_bottom[1], right_bottom[0], right_bottom[1]]
- left_line = [left_up[0], left_up[1], left_up[0], right_bottom[1]]
- right_line = [right_bottom[0], left_up[1], right_bottom[0], right_bottom[1]]
- area_row_line2.append([up_line, bottom_line])
- area_col_line2.append([left_line, right_line])
- # 判断超出部分的长度,超出一定长度就补线
- new_row_lines = []
- new_col_lines = []
- longer_row_lines = []
- longer_col_lines = []
- all_longer_row_lines = []
- all_longer_col_lines = []
- # print("split_y", split_y)
- # print("split_row_list", split_row_list, len(split_row_list))
- # print("split_row_list", split_col_list, len(split_col_list))
- # print("area_row_line", area_row_line, len(area_row_line))
- # print("area_col_line", area_col_line, len(area_col_line))
- for i in range(len(area_row_line)):
- if not area_row_line[i] or not area_col_line[i]:
- continue
- up_line = area_row_line[i][0]
- up_line2 = area_row_line2[i][0]
- bottom_line = area_row_line[i][1]
- bottom_line2 = area_row_line2[i][1]
- left_line = area_col_line[i][0]
- left_line2 = area_col_line2[i][0]
- right_line = area_col_line[i][1]
- right_line2 = area_col_line2[i][1]
- # 计算单格高度宽度
- if len(split_row_list[i]) > 1:
- height_dict = {}
- for j in range(len(split_row_list[i])):
- if j + 1 > len(split_row_list[i]) - 1:
- break
- height = abs(int(split_row_list[i][j][3] - split_row_list[i][j+1][3]))
- if height in height_dict.keys():
- height_dict[height] = height_dict[height] + 1
- else:
- height_dict[height] = 1
- height_list = [[x, height_dict[x]] for x in height_dict.keys()]
- height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True)
- # print("height_list", height_list)
- box_height = height_list[0][0]
- else:
- box_height = 10
- if len(split_col_list[i]) > 1:
- box_width = abs(split_col_list[i][1][2] - split_col_list[i][0][2])
- else:
- box_width = 10
- print("box_height", box_height, "box_width", box_width)
- # cv2.line(image, (int(up_line[0]), int(up_line[1])),
- # (int(up_line[2]), int(up_line[3])),
- # (255, 255, 0), 2)
- # cv2.line(image, (int(right_line[0]), int(right_line[1])),
- # (int(right_line[2]), int(right_line[3])),
- # (0, 255, 255), 2)
- # cv2.imshow("right_line", image)
- # cv2.waitKey(0)
- # 补左右两条竖线超出来的线的row
- if (up_line[1] - left_line[1] >= 10 and up_line[1] - right_line[1] >= 2) or \
- (up_line[1] - left_line[1] >= 2 and up_line[1] - right_line[1] >= 10):
- if up_line[1] - left_line[1] >= up_line[1] - right_line[1]:
- new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]])
- new_col_y = left_line[1]
- # 补了row,要将其他短的col连到row上
- for j in range(len(split_col_list[i])):
- col = split_col_list[i][j]
- # 且距离不能相差大于一格
- # print("abs(new_col_y - col[1])", abs(new_col_y - col[1]))
- if abs(new_col_y - col[1]) <= box_height:
- split_col_list[i][j][1] = min([new_col_y, col[1]])
- longer_col_lines.append([col[0], min([new_col_y, col[1]]), col[2], col[3]])
- else:
- new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]])
- new_col_y = right_line[1]
- # 补了row,要将其他短的col连到row上
- for j in range(len(split_col_list[i])):
- # 需判断该线在这个区域中
- # if up_line2[1]-3 <= col[1] <= col[3] <= bottom_line2[1]+3:
- col = split_col_list[i][j]
- # 且距离不能相差太大
- # print("abs(new_col_y - col[1])", abs(new_col_y - col[1]))
- if abs(new_col_y - col[1]) <= box_height:
- split_col_list[i][j][1] = min([new_col_y, col[1]])
- if (left_line[3] - bottom_line[3] >= 10 and right_line[3] - bottom_line[3] >= 2) or \
- (left_line[3] - bottom_line[3] >= 2 and right_line[3] - bottom_line[3] >= 10):
- if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]:
- new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]])
- new_col_y = left_line[3]
- # 补了row,要将其他短的col连到row上
- for j in range(len(split_col_list[i])):
- col = split_col_list[i][j]
- # 且距离不能相差太大
- if abs(new_col_y - col[3]) <= box_height:
- split_col_list[i][j][3] = max([new_col_y, col[3]])
- else:
- new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]])
- new_col_y = right_line[3]
- # 补了row,要将其他短的col连到row上
- for j in range(len(split_col_list[i])):
- col = split_col_list[i][j]
- # 且距离不能相差太大
- if abs(new_col_y - col[3]) <= box_height:
- split_col_list[i][j][3] = max([new_col_y, col[3]])
- # 补上下两条横线超出来的线的col
- if (left_line[0] - up_line[0] >= 10 and left_line[0] - bottom_line[0] >= 2) or \
- (left_line[0] - up_line[0] >= 2 and left_line[0] - bottom_line[0] >= 10):
- if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]:
- new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]])
- new_row_x = up_line[0]
- # 补了col,要将其他短的row连到col上
- for j in range(len(split_row_list[i])):
- row = split_row_list[i][j]
- # 且距离不能相差太大
- if abs(new_row_x - row[0]) <= box_width:
- split_row_list[i][j][0] = min([new_row_x, row[0]])
- else:
- new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]])
- new_row_x = bottom_line[0]
- # 补了col,要将其他短的row连到col上
- for j in range(len(split_row_list[i])):
- row = split_row_list[i][j]
- # 且距离不能相差太大
- if abs(new_row_x - row[0]) <= box_width:
- split_row_list[i][j][0] = min([new_row_x, row[0]])
- if (up_line[2] - right_line[2] >= 10 and bottom_line[2] - right_line[2] >= 2) or \
- (up_line[2] - right_line[2] >= 2 and bottom_line[2] - right_line[2] >= 10):
- if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]:
- new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]])
- new_row_x = up_line[2]
- # 补了col,要将其他短的row连到col上
- for j in range(len(split_row_list[i])):
- row = split_row_list[i][j]
- # 且距离不能相差太大
- if abs(new_row_x - row[2]) <= box_width:
- split_row_list[i][j][2] = max([new_row_x, row[2]])
- else:
- new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]])
- new_row_x = bottom_line[2]
- # 补了col,要将其他短的row连到col上
- for j in range(len(split_row_list[i])):
- # 需判断该线在这个区域中
- # if up_line2[1]-3 <= row[1] <= bottom_line2[1]+3:
- row = split_row_list[i][j]
- # 且距离不能相差太大
- if abs(new_row_x - row[2]) <= box_width:
- split_row_list[i][j][2] = max([new_row_x, row[2]])
- all_longer_row_lines += split_row_list[i]
- all_longer_col_lines += split_col_list[i]
- # print("all_longer_row_lines", len(all_longer_row_lines), i)
- # print("all_longer_col_lines", len(all_longer_col_lines), i)
- # print("new_row_lines", len(new_row_lines), i)
- # print("new_col_lines", len(new_col_lines), i)
- # 删除表格内部的补线
- # temp_list = []
- # for row in new_row_lines:
- # if up_line[1]-5 <= row[1] <= bottom_line[1]+5:
- # continue
- # temp_list.append(row)
- # print("fix_outline", new_row_lines)
- # new_row_lines = temp_list
- # print("fix_outline", new_row_lines)
- # temp_list = []
- # for col in new_col_lines:
- # if left_line[0]-5 <= col[0] <= right_line[0]+5:
- # continue
- # temp_list.append(col)
- #
- # new_col_lines = temp_list
- # print("fix_outline", new_col_lines)
- # print("fix_outline", new_row_lines)
- # 删除重复包含的补线
- # temp_list = []
- # for row in new_row_lines:
- # if up_line[1]-5 <= row[1] <= bottom_line[1]+5:
- # continue
- # temp_list.append(row)
- # new_row_lines = temp_list
- # 展示上下左右边框线
- # for i in range(len(area_row_line)):
- # print("row1", area_row_line[i])
- # print("row2", area_row_line2[i])
- # print("col1", area_col_line[i])
- # print("col2", area_col_line2[i])
- # cv2.line(image, (int(area_row_line[i][0][0]), int(area_row_line[i][0][1])),
- # (int(area_row_line[i][0][2]), int(area_row_line[i][0][3])), (0, 255, 0), 2)
- # cv2.line(image, (int(area_row_line2[i][1][0]), int(area_row_line2[i][1][1])),
- # (int(area_row_line2[i][1][2]), int(area_row_line2[i][1][3])), (0, 0, 255), 2)
- # cv2.imshow("fix_outline", image)
- # cv2.waitKey(0)
- # 展示所有线
- # for line in all_longer_col_lines:
- # cv2.line(image, (int(line[0]), int(line[1])),
- # (int(line[2]), int(line[3])),
- # (0, 255, 0), 2)
- # cv2.imshow("fix_outline", image)
- # cv2.waitKey(0)
- # for line in all_longer_row_lines:
- # cv2.line(image, (int(line[0]), int(line[1])),
- # (int(line[2]), int(line[3])),
- # (0, 0, 255), 2)
- # cv2.imshow("fix_outline", image)
- # cv2.waitKey(0)
- return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
def _fix_outline_box_height(rows, default):
    """Return the most frequent gap (>= 10) between consecutive rows.

    ``rows`` must already be ordered top to bottom. Ties between equally
    frequent gaps are resolved in favour of the smaller gap. ``default`` is
    returned when there are fewer than two rows or no gap reaches 10.
    """
    counts = {}
    for upper, lower in zip(rows, rows[1:]):
        gap = abs(int(upper[3] - lower[3]))
        if gap >= 10:
            counts[gap] = counts.get(gap, 0) + 1
    if not counts:
        # Previously this case indexed into an empty list and raised IndexError.
        return default
    return max(counts, key=lambda gap: (counts[gap], -gap))


def _fix_outline_snap(lines, index, target, tolerance, choose):
    """Pull coordinate ``index`` of every nearby line onto ``target`` in place.

    A line is touched only when its coordinate is within ``tolerance`` of
    ``target``; ``choose`` (``min`` or ``max``) guarantees a line is only ever
    lengthened, never shortened.
    """
    for line in lines:
        if abs(target - line[index]) <= tolerance:
            line[index] = choose(target, line[index])


def fix_outline(image, row_lines, col_lines, points, split_y, scale=25):
    """Add missing table border lines where existing lines overshoot the frame.

    Lines are indexed as ``[x1, y1, x2, y2]``. The image is cut into horizontal
    areas by consecutive values of ``split_y``. Within each area the outermost
    row lines (top/bottom) and column lines (left/right) are compared: when
    both columns stick out beyond a row (or both rows beyond a column) by more
    than a threshold, a new closing line is created and nearby shorter lines
    are stretched to meet it.

    Args:
        image: object exposing ``shape``; only ``shape[0]`` and ``shape[1]``
            are read, to derive minimum lengths.
        row_lines: horizontal lines; the inner lists may be modified in place.
        col_lines: vertical lines; the inner lists may be modified in place.
        points: accepted for interface compatibility; not used by the
            computation.
        split_y: y coordinates delimiting the areas; sorted in place.
        scale: divisor applied to the image dimensions to get minimum lengths.

    Returns:
        ``(new_row_lines, new_col_lines, all_longer_row_lines,
        all_longer_col_lines)``. The first two hold the newly created lines,
        the last two hold the (possibly stretched) lines of every area that
        has at least one row and one column. Four empty lists are returned
        when ``split_y`` has fewer than two entries.
    """
    logging.info("into fix_outline")
    # NOTE(review): shape[0] is normally the image height for array images,
    # yet it feeds the x threshold here (and shape[1] the y threshold).
    # Kept unchanged -- confirm whether this is intentional.
    x_min_len = max(10, int(image.shape[0] / scale))
    y_min_len = max(10, int(image.shape[1] / scale))

    if len(split_y) < 2:
        return [], [], [], []
    split_y.sort()

    # Assign and order the lines of every area before any line is modified:
    # lines are shared between areas, so membership and ordering must not
    # depend on stretching done while processing an earlier area.
    areas = []
    for last_y, y in zip(split_y, split_y[1:]):
        rows = [row for row in row_lines if last_y <= row[3] <= y]
        cols = [
            col for col in col_lines
            if last_y <= col[1] <= y
            or last_y <= col[3] <= y
            or col[1] < last_y < y < col[3]
        ]
        rows.sort(key=lambda line: (line[1], line[0]))
        cols.sort(key=lambda line: line[0])
        areas.append((rows, cols))

    new_row_lines = []
    new_col_lines = []
    all_longer_row_lines = []
    all_longer_col_lines = []
    for rows, cols in areas:
        if not rows or not cols:
            continue
        up_line, bottom_line = rows[0], rows[-1]
        left_line, right_line = cols[0], cols[-1]

        # Estimated size of a single cell.
        box_height = _fix_outline_box_height(rows, y_min_len)
        if len(cols) > 1:
            box_width = abs(cols[1][2] - cols[0][2])
        else:
            box_width = x_min_len

        # How far a line must overshoot before a closing line is added.
        fix_h_len = y_min_len if box_height >= 2 * y_min_len else box_height * 2 / 3
        fix_w_len = x_min_len if box_width >= 2 * x_min_len else box_width * 2 / 3

        # Both outer columns rise above the top row: add a row at the top.
        over_left = up_line[1] - left_line[1]
        over_right = up_line[1] - right_line[1]
        if over_left >= fix_h_len and over_right >= fix_h_len:
            new_y = left_line[1] if over_left >= over_right else right_line[1]
            new_row_lines.append([left_line[0], new_y, right_line[0], new_y])
            _fix_outline_snap(cols, 1, new_y, box_height, min)

        # Both outer columns extend below the bottom row: add a row below.
        over_left = left_line[3] - bottom_line[3]
        over_right = right_line[3] - bottom_line[3]
        if over_left >= fix_h_len and over_right >= fix_h_len:
            new_y = left_line[3] if over_left >= over_right else right_line[3]
            new_row_lines.append([left_line[2], new_y, right_line[2], new_y])
            _fix_outline_snap(cols, 3, new_y, box_height, max)

        # Both outer rows start left of the left column: add a column there.
        over_up = left_line[0] - up_line[0]
        over_bottom = left_line[0] - bottom_line[0]
        if over_up >= fix_w_len and over_bottom >= fix_w_len:
            new_x = up_line[0] if over_up >= over_bottom else bottom_line[0]
            new_col_lines.append([new_x, up_line[1], new_x, bottom_line[1]])
            _fix_outline_snap(rows, 0, new_x, box_width, min)

        # Both outer rows end right of the right column: add a column there.
        over_up = up_line[2] - right_line[2]
        over_bottom = bottom_line[2] - right_line[2]
        if over_up >= fix_w_len and over_bottom >= fix_w_len:
            new_x = up_line[2] if over_up >= over_bottom else bottom_line[2]
            new_col_lines.append([new_x, up_line[3], new_x, bottom_line[3]])
            _fix_outline_snap(rows, 2, new_x, box_width, max)

        all_longer_row_lines += rows
        all_longer_col_lines += cols

    return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
def fix_table(row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Build cell bounding boxes whose four edges are backed by detected lines.

    Points are ``(x, y)`` and lines are indexed as ``[x1, y1, x2, y2]``. For
    every area between consecutive ``split_y`` values, each point of a row
    (except the last) is used as a top-left corner. Points further right in
    the same row are tried in order; for each, the column group at that x is
    scanned top to bottom for the first point below the corner such that all
    four edges of the resulting box overlap a matching line by at least one
    third of the edge length.

    Args:
        row_point_list: groups of points sharing a row; sorted in place by
            ``(x, y)``.
        col_point_list: groups of points sharing a column; sorted in place by
            ``(y, x)``.
        split_y: y coordinates delimiting the areas.
        row_lines: horizontal lines.
        col_lines: vertical lines.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y`` has
        fewer than two entries.
    """
    if len(split_y) < 2:
        return []

    # Sort once up front; previously every group was re-sorted inside the
    # innermost search loop.
    for point_row in row_point_list:
        point_row.sort(key=lambda p: (p[0], p[1]))
    for point_col in col_point_list:
        point_col.sort(key=lambda p: (p[1], p[0]))

    def has_horizontal_edge(lines, y_edge, x1, x2):
        """True if a line within 3px of y_edge covers >= 1/3 of [x1, x2]."""
        needed = (x2 - x1) / 3
        return any(
            line[1] - 3 <= y_edge <= line[1] + 3
            and min(line[2], x2) - max(line[0], x1) >= needed
            for line in lines
        )

    def has_vertical_edge(lines, x_edge, y1, y2):
        """True if a line within 3px of x_edge covers >= 1/3 of [y1, y2]."""
        needed = (y2 - y1) / 3
        return any(
            line[0] - 3 <= x_edge <= line[0] + 3
            and min(line[3], y2) - max(line[1], y1) >= needed
            for line in lines
        )

    def find_box(corner, right_points, point_cols, h_lines, v_lines):
        """Return the first confirmed box with ``corner`` as top-left, or None."""
        for right_point in right_points:
            column = next(
                (col for col in point_cols
                 if col[0][0] - 3 <= right_point[0] <= col[0][0] + 3),
                (),
            )
            for opposite in column:
                # Skip points on the same row as the corner or above it.
                if opposite[1] <= corner[1] + 3:
                    continue
                x1, y1, x2, y2 = corner[0], corner[1], opposite[0], opposite[1]
                if (has_horizontal_edge(h_lines, y1, x1, x2)
                        and has_horizontal_edge(h_lines, y2, x1, x2)
                        and has_vertical_edge(v_lines, x1, y1, y2)
                        and has_vertical_edge(v_lines, x2, y1, y2)):
                    return [(x1, y1), (x2, y2)]
        return None

    bbox = []
    for top, bottom in zip(split_y, split_y[1:]):
        # Only geometry lying strictly inside the area takes part.
        area_rows = [line for line in row_lines if top < line[1] < bottom]
        area_cols = [line for line in col_lines
                     if line[1] > top and line[3] < bottom]
        area_point_cols = [col for col in col_point_list
                           if col and col[0][1] > top and col[-1][1] < bottom]
        for point_row in row_point_list:
            # Empty groups used to raise IndexError; they are skipped now.
            if not point_row:
                continue
            if point_row[0][1] <= top or point_row[0][1] >= bottom:
                continue
            for j, corner in enumerate(point_row[:-1]):
                box = find_box(corner, point_row[j + 1:], area_point_cols,
                               area_rows, area_cols)
                if box is not None:
                    bbox.append(box)
    return bbox
def delete_close_points(point_list, row_point_list, col_point_list, threshold=5):
    """Collapse pairs of neighbouring points that lie within ``threshold``.

    Two passes are made. The first orders the points by ``(y, x)`` and scores
    ties against ``col_point_list``; the second orders the survivors by
    ``(x, y)`` and scores against ``row_point_list``. In each pass a point is
    compared with its successor in that ordering; if both coordinates differ
    by at most ``threshold`` only one of the two is kept: the one whose x
    value matches the x of the first point of more (larger) groups, with the
    earlier point winning ties.

    Args:
        point_list: ``(x, y)`` points; sorted in place by ``(y, x)``.
        row_point_list: point groups used for scoring in the second pass.
        col_point_list: point groups used for scoring in the first pass.
        threshold: maximum per-axis distance for two points to count as close.

    Returns:
        A new list with the surviving points, ordered by ``(x, y)``.
    """

    def collapse(points, groups):
        kept = []
        dropped = []  # a list, not a set: points may be unhashable lists
        last_index = len(points) - 1
        for index, current in enumerate(points):
            if current in dropped:
                continue
            if index == last_index:
                kept.append(current)
                break
            following = points[index + 1]
            if (abs(current[0] - following[0]) > threshold
                    or abs(current[1] - following[1]) > threshold):
                kept.append(current)
                continue
            # NOTE(review): both passes compare the x coordinate with the x of
            # the first point of each group, as the original code did --
            # confirm this is intended for the row pass as well.
            votes_current = 0
            votes_following = 0
            for group in groups:
                if not group:
                    continue
                anchor = group[0][0]
                if current[0] == anchor:
                    votes_current += len(group)
                elif following[0] == anchor:
                    votes_following += len(group)
            if votes_current >= votes_following:
                kept.append(current)
                dropped.append(following)
            else:
                # The successor wins. It is NOT appended here: it becomes
                # ``current`` on the next iteration and is judged there.
                # Appending it now as well produced duplicate points.
                dropped.append(current)
        return kept

    point_list.sort(key=lambda p: (p[1], p[0]))
    survivors = collapse(point_list, col_point_list)
    survivors.sort(key=lambda p: (p[0], p[1]))
    return collapse(survivors, row_point_list)
def get_bbox2(image_np, points):
    """Debug variant: derive boxes from grouped points and display them.

    The points are grouped with ``get_points_row`` / ``get_points_col``, close
    points are merged with ``delete_close_points``, and the groups are rebuilt.
    Each point of a row (except the last, and except the last row) is then
    paired with a point of a lower row to form a box. Intermediate and final
    results are drawn onto ``image_np`` and shown in OpenCV windows; the
    function blocks on ``cv2.waitKey(0)`` twice.

    Args:
        image_np: image passed to the ``cv2`` drawing calls; drawn on in place.
        points: ``(x, y)`` points.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes sorted by their coordinates.
    """
    row_point_list = get_points_row(points)
    col_point_list = get_points_col(points)

    print("len(points)", len(points))
    for point in points:
        cv2.circle(image_np, point, 1, (0, 255, 0), 1)
    cv2.imshow("points_deleted", image_np)

    points = delete_close_points(points, row_point_list, col_point_list)
    print("len(points)", len(points))
    for point in points:
        cv2.circle(image_np, point, 1, (255, 0, 0), 3)
    cv2.imshow("points_deleted", image_np)
    cv2.waitKey(0)

    row_point_list = get_points_row(points, 5)
    col_point_list = get_points_col(points, 5)
    print("len(row_point_list)", len(row_point_list))
    for row in row_point_list:
        print("row", len(row))
    print("col_point_list", len(col_point_list))
    for col in col_point_list:
        print("col", len(col))

    bbox = []
    for current_row in row_point_list[:-1]:
        for j, current_point in enumerate(current_row[:-1]):
            next_row_point = current_row[j + 1]

            # NOTE(review): the j-th point of a row is assumed to lie in the
            # j-th column group, and the index of a point inside its column is
            # reused as a row index below. That only holds for a complete,
            # regular grid -- confirm against get_points_row/get_points_col.
            if j >= len(col_point_list):
                continue
            current_col = col_point_list[j]

            # Index of the first point in this column clearly below the corner.
            below_index = next(
                (k for k, candidate in enumerate(current_col)
                 if candidate[1] > current_point[1] + 10),
                None,
            )
            # Previously a failed search silently reused a stale or unbound
            # index; such corners are skipped now.
            if below_index is None or below_index >= len(row_point_list):
                continue

            next_row = row_point_list[below_index]
            next_point = next(
                (candidate for candidate in next_row
                 if candidate[0] >= next_row_point[0] + 5),
                None,
            )
            # Likewise, a failed search used to reuse the corner found for an
            # earlier point (or raise NameError on the very first one).
            if next_point is None:
                continue

            bbox.append([(current_point[0], current_point[1]),
                         (next_point[0], next_point[1])])

    bbox.sort(key=lambda x: (x[0][0], x[0][1], x[1][0], x[1][1]))

    for box in bbox:
        cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8)
    cv2.imshow('bboxes', image_np)
    cv2.waitKey(0)
    return bbox
def get_bbox1(image_np, points, split_y):
    """Debug variant: build boxes from each point's nearest right/lower neighbour.

    After merging close points, every point inside an area (strictly between
    two consecutive ``split_y`` values) is paired with the x of its nearest
    neighbour to the right in the same row and the y of its nearest neighbour
    below in the same column. The boxes are drawn onto ``image_np`` and shown
    in an OpenCV window; the function blocks on ``cv2.waitKey(0)``.

    Args:
        image_np: image passed to the ``cv2`` drawing calls; drawn on in place.
        points: ``(x, y)`` points.
        split_y: y coordinates delimiting the areas.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y`` has
        fewer than two entries.
    """
    if len(split_y) < 2:
        return []

    # Group the points and drop near-duplicates.
    row_point_list = get_points_row(points)
    col_point_list = get_points_col(points)
    # The label announces a length, so print the length (the whole list was
    # printed before).
    print("len(row_point_list)", len(row_point_list))
    print("len(col_point_list)", len(col_point_list))
    points = delete_close_points(points, row_point_list, col_point_list)

    threshold = 10  # max offset on the other axis to count as same row/column
    bbox = []
    for top, bottom in zip(split_y, split_y[1:]):
        area_points = [p for p in points if top < p[1] < bottom]
        for point1 in area_points:
            distance_x = 10000
            distance_y = 10000
            x = None
            y = None
            for point2 in area_points:
                dx = point2[0] - point1[0]
                dy = point2[1] - point1[1]
                # abs() is required here: without it any point far above
                # (or far to the left) also passed the "same row/column" test.
                if 2 < dx < distance_x and abs(dy) <= threshold:
                    distance_x = dx
                    x = point2[0]
                if 2 < dy < distance_y and abs(dx) <= threshold:
                    distance_y = dy
                    y = point2[1]
            if x is None or y is None:
                continue
            bbox.append([(point1[0], point1[1]), (x, y)])

    # NOTE(review): the original code followed this with a pass meant to
    # remove boxes contained in other boxes, but the flag it computed was
    # never read and every box was kept. That outcome is retained here;
    # decide whether contained boxes should really be dropped.

    # Show the result.
    for box in bbox:
        cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8)
    cv2.imshow('bboxes', image_np)
    cv2.waitKey(0)
    return bbox
def get_bbox0(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Build boxes whose top and bottom edges lie fully on a row line.

    Points are ``(x, y)`` and lines are indexed as ``[x1, y1, x2, y2]``. For
    each area between consecutive ``split_y`` values, every point of a row is
    paired with its immediate right-hand neighbour; the column group at the
    neighbour's x is scanned for the first lower point such that both the top
    and the bottom edge of the box are contained (with 3px slack) in a row
    line of the area. ``image_np`` and ``col_lines`` are not used.

    The row groups are sorted in place by ``(x, y)`` and the column groups by
    ``(y, x)`` as they are visited.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y`` has
        fewer than two entries.
    """
    if len(split_y) < 2:
        return []

    slack = 3
    found = []
    for region_index in range(len(split_y) - 1):
        region_top = split_y[region_index]
        region_bottom = split_y[region_index + 1]

        for point_row in row_point_list:
            point_row.sort(key=lambda p: (p[0], p[1]))
            first_y = point_row[0][1]
            # The row must lie strictly inside the area.
            if first_y <= region_top or first_y >= region_bottom:
                continue

            for corner, neighbour in zip(point_row, point_row[1:]):
                # Locate the column group the right-hand neighbour belongs to.
                column = []
                for candidate_col in col_point_list:
                    candidate_col.sort(key=lambda p: (p[1], p[0]))
                    starts_above = candidate_col[0][1] <= region_top
                    ends_below = candidate_col[-1][1] >= region_bottom
                    if starts_above or ends_below:
                        continue
                    col_x = candidate_col[0][0]
                    if col_x - slack <= neighbour[0] <= col_x + slack:
                        column = candidate_col
                        break

                # Walk down that column until a confirmed box is found.
                for opposite in column:
                    on_same_row = corner[1] - slack <= opposite[1] <= corner[1] + slack
                    is_above = opposite[1] <= corner[1] - slack
                    if on_same_row or is_above:
                        continue

                    x1, y1 = corner[0], corner[1]
                    x2, y2 = opposite[0], opposite[1]

                    top_ok = False
                    bottom_ok = False
                    for segment in row_lines:
                        seg_y = segment[1]
                        if seg_y <= region_top or seg_y >= region_bottom:
                            continue
                        # The box's x range must be contained in the segment.
                        spans = segment[0] - slack <= x1 <= x2 <= segment[2] + slack
                        if not top_ok and seg_y - slack <= y1 <= seg_y + slack and spans:
                            top_ok = True
                        if not bottom_ok and seg_y - slack <= y2 <= seg_y + slack and spans:
                            bottom_ok = True

                    if top_ok and bottom_ok:
                        found.append([(x1, y1), (x2, y2)])
                        break
    return found
def get_bbox3(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Build cell boxes whose four edges are backed by detected lines.

    Points are ``(x, y)`` and lines are indexed as ``[x1, y1, x2, y2]``. For
    every area between consecutive ``split_y`` values, each point of a row
    (except the last) is used as a top-left corner. Points further right in
    the same row are tried in order; for each, the column group at that x is
    scanned top to bottom for the first lower point whose box is confirmed:

    * top, bottom and left edge: a line within 3px of the edge overlaps at
      least one third of the edge length;
    * right edge: a line within 3px of the edge fully contains the upper
      third or the lower third of the edge.

    ``image_np`` is not used.

    Args:
        row_point_list: groups of points sharing a row; sorted in place by
            ``(x, y)``.
        col_point_list: groups of points sharing a column; sorted in place by
            ``(y, x)``.
        split_y: y coordinates delimiting the areas.
        row_lines: horizontal lines.
        col_lines: vertical lines.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y`` has
        fewer than two entries.
    """
    if len(split_y) < 2:
        return []

    # Sort once up front; previously every column group was re-sorted for
    # every candidate point inside the innermost search loop.
    for point_row in row_point_list:
        point_row.sort(key=lambda p: (p[0], p[1]))
    for point_col in col_point_list:
        point_col.sort(key=lambda p: (p[1], p[0]))

    def horizontal_edge_ok(lines, y_edge, x1, x2):
        """True if a line within 3px of y_edge covers >= 1/3 of [x1, x2]."""
        needed = (x2 - x1) / 3
        return any(
            line[1] - 3 <= y_edge <= line[1] + 3
            and min(line[2], x2) - max(line[0], x1) >= needed
            for line in lines
        )

    def left_edge_ok(lines, x_edge, y1, y2):
        """True if a line within 3px of x_edge covers >= 1/3 of [y1, y2]."""
        needed = (y2 - y1) / 3
        return any(
            line[0] - 3 <= x_edge <= line[0] + 3
            and min(line[3], y2) - max(line[1], y1) >= needed
            for line in lines
        )

    def right_edge_ok(lines, x_edge, y1, y2):
        """True if a line within 3px of x_edge contains the top or bottom third."""
        third = (y2 - y1) / 3
        upper = (y1, y1 + third)
        lower = (y2 - third, y2)
        return any(
            line[0] - 3 <= x_edge <= line[0] + 3
            and (line[1] <= upper[0] <= upper[1] <= line[3]
                 or line[1] <= lower[0] <= lower[1] <= line[3])
            for line in lines
        )

    def locate_box(corner, right_points, point_cols, h_lines, v_lines):
        """Return the first confirmed box with ``corner`` as top-left, or None."""
        for right_point in right_points:
            column = next(
                (col for col in point_cols
                 if col[0][0] - 3 <= right_point[0] <= col[0][0] + 3),
                (),
            )
            for opposite in column:
                # Skip points on the same row as the corner or above it.
                if opposite[1] <= corner[1] + 3:
                    continue
                x1, y1, x2, y2 = corner[0], corner[1], opposite[0], opposite[1]
                if (horizontal_edge_ok(h_lines, y1, x1, x2)
                        and horizontal_edge_ok(h_lines, y2, x1, x2)
                        and left_edge_ok(v_lines, x1, y1, y2)
                        and right_edge_ok(v_lines, x2, y1, y2)):
                    return [(x1, y1), (x2, y2)]
        return None

    bbox = []
    for top, bottom in zip(split_y, split_y[1:]):
        # Only geometry lying strictly inside the area takes part.
        area_rows = [line for line in row_lines if top < line[1] < bottom]
        area_cols = [line for line in col_lines
                     if line[1] > top and line[3] < bottom]
        area_point_cols = [col for col in col_point_list
                           if col and col[0][1] > top and col[-1][1] < bottom]
        for point_row in row_point_list:
            # Empty groups used to raise IndexError; they are skipped now.
            if not point_row:
                continue
            if point_row[0][1] <= top or point_row[0][1] >= bottom:
                continue
            for j, corner in enumerate(point_row[:-1]):
                box = locate_box(corner, point_row[j + 1:], area_point_cols,
                                 area_rows, area_cols)
                if box is not None:
                    bbox.append(box)
    return bbox
def get_bbox(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines):
    """Build cell boxes by pairing each point with a diagonal point below it.

    Points are ``(x, y)`` and lines are indexed as ``[x1, y1, x2, y2]``. For
    every area between consecutive ``split_y`` values (bounds inclusive), each
    point of a row group except the last is used as a top-left corner. The
    following row groups are searched in order, point by point, for the first
    point to the right that forms a box more than 10px wide and high which
    ``check_bbox`` accepts against the area's row and column lines. At most
    one box is produced per corner.

    ``image_np`` and ``col_point_list`` are not used by the computation.

    Args:
        row_point_list: groups of points sharing a row; groups falling inside
            an area are sorted in place by ``(y, x)``.
        split_y: y coordinates delimiting the areas.
        row_lines: horizontal lines.
        col_lines: vertical lines.

    Returns:
        List of ``[(x1, y1), (x2, y2)]`` boxes; empty when ``split_y`` has
        fewer than two entries.
    """
    if len(split_y) < 2:
        return []

    def find_diagonal(corner, lower_rows, area_row_lines, area_col_lines):
        """Return the first accepted ``[x1, y1, x2, y2]`` box, or None."""
        for lower_row in lower_rows:
            for candidate in lower_row:
                # The diagonal point must be clearly to the right.
                if abs(corner[0] - candidate[0]) <= 2:
                    continue
                if candidate[0] < corner[0]:
                    continue
                box = [corner[0], corner[1], candidate[0], candidate[1]]
                # Ignore degenerate (too narrow / too flat) boxes.
                if abs(box[0] - box[2]) <= 10:
                    continue
                if abs(box[1] - box[3]) <= 10:
                    continue
                # All four edges must be confirmed by detected lines.
                if check_bbox(box, area_row_lines, area_col_lines):
                    return box
        return None

    bbox_list = []
    for last_y, y in zip(split_y, split_y[1:]):
        # Partition points and lines by area first.
        area_point_rows = []
        for point_row in row_point_list:
            # Empty groups used to raise IndexError; they are skipped now.
            if point_row and last_y <= point_row[0][1] <= y:
                point_row.sort(key=lambda p: (p[1], p[0]))
                area_point_rows.append(point_row)
        area_row_lines = [line for line in row_lines if last_y <= line[1] <= y]
        area_col_lines = [line for line in col_lines if last_y <= line[1] <= y]

        # The original reused the name ``i`` for this loop and for the area
        # loop; distinct names avoid the shadowing.
        for row_index, point_row in enumerate(area_point_rows[:-1]):
            lower_rows = area_point_rows[row_index + 1:]
            for corner in point_row[:-1]:
                box = find_diagonal(corner, lower_rows,
                                    area_row_lines, area_col_lines)
                if box is not None:
                    bbox_list.append([(box[0], box[1]), (box[2], box[3])])
    return bbox_list
def check_bbox(bbox, rows, cols, threshold=5):
    """Decide whether ``bbox`` is a cell bounded by detected table lines.

    The box is accepted when each of its four edges is covered by line
    segments lying on that edge and no other line runs through its interior.

    Args:
        bbox: ``[x1, y1, x2, y2]``; (x1, y1) is the top-left corner.
        rows: horizontal segments, each ``[x_start, y, x_end, y]``.
        cols: vertical segments, each ``[x, y_start, x, y_end]``.
        threshold: tolerance (same unit as the coordinates) used both to
            match a segment to an edge position and to decide how far inside
            the box a line must lie to count as an interior line.

    Returns:
        bool: True when all four edges exist and the interior is empty.
    """

    def edge_exists(edge, lines, edge_pos, axis):
        """Return True when ``edge`` is covered by ``lines`` at ``edge_pos``.

        ``axis`` is 0 for horizontal lines (span on x, position on y) and 1
        for vertical lines (span on y, position on x).
        """
        start, end = edge
        length = end - start
        middle = (start + end) / 2
        first_half = [start, middle]
        second_half = [middle, end]
        first_third = [start, start + length / 3]
        last_third = [end - length / 3, end]
        first_quarter = [start, start + length / 4]
        last_quarter = [end - length / 4, end]

        def contains(line, part):
            return line[axis] <= part[0] <= part[1] <= line[axis + 2]

        # Only lines lying on the edge position (same y for rows, same x for
        # columns) may contribute to the edge.
        on_edge = [line for line in lines
                   if abs(line[1 - axis] - edge_pos) <= threshold]

        for line in on_edge:
            # One line covers the whole edge, or at least one half of it.
            if (contains(line, edge)
                    or contains(line, first_half)
                    or contains(line, second_half)):
                return True
            # Two different lines: one covers a third at one end, another
            # covers the quarter at the opposite end.  The second line must
            # also lie on the edge; the previous code re-tested the first
            # line here, so a line anywhere on the page could complete it.
            if contains(line, first_third):
                if any(contains(other, last_quarter) for other in on_edge):
                    return True
            elif contains(line, last_third):
                if any(contains(other, first_quarter) for other in on_edge):
                    return True
        return False

    left_x, up_y, right_x, down_y = bbox[0], bbox[1], bbox[2], bbox[3]
    up_down_line = [left_x, right_x]
    left_right_line = [up_y, down_y]

    # All four edges of the bbox must exist.
    if not (edge_exists(up_down_line, rows, up_y, 0)
            and edge_exists(up_down_line, rows, down_y, 0)
            and edge_exists(left_right_line, cols, left_x, 1)
            and edge_exists(left_right_line, cols, right_x, 1)):
        return False

    # Reject the bbox when another line lies inside it (not on an edge).
    # NOTE(review): in the partial-overlap tests below the right-hand side is
    # the midpoint coordinate of the span, not half of its length; kept as in
    # the original -- confirm whether "(end - start) / 2" was intended.
    y_mid = (left_right_line[1] + left_right_line[0]) / 2
    for col in cols:
        if not (left_x + threshold <= col[0] <= right_x - threshold):
            continue
        if col[1] <= left_right_line[0] <= left_right_line[1] <= col[3]:
            return False
        elif left_right_line[0] <= col[1] <= left_right_line[1]:
            if left_right_line[1] - col[1] >= y_mid:
                return False
        elif left_right_line[0] <= col[3] <= left_right_line[1]:
            if col[3] - left_right_line[0] >= y_mid:
                return False

    x_mid = (up_down_line[1] + up_down_line[0]) / 2
    for row in rows:
        if not (up_y + threshold <= row[1] <= down_y - threshold):
            continue
        if row[0] <= up_down_line[0] <= up_down_line[1] <= row[2]:
            return False
        elif up_down_line[0] <= row[0] <= up_down_line[1]:
            if up_down_line[1] - row[0] >= x_mid:
                return False
        elif up_down_line[0] <= row[2] <= up_down_line[1]:
            if row[2] - up_down_line[0] >= x_mid:
                return False

    return True
def add_continue_bbox(bboxes, split_y=None):
    """Insert filler boxes into vertical gaps between consecutive bboxes.

    The boxes are ordered column-wise (by left x, then top y).  Within each
    region delimited by two neighbouring ``split_y`` values, consecutive
    boxes are compared: when the next box starts more than 2 units below the
    bottom of the previous one, a box spanning that gap (with the previous
    box's x-range) is added.

    Args:
        bboxes: list of ``[(x1, y1), (x2, y2)]`` boxes.  Sorted in place.
        split_y: y coordinates of the region split lines.  When omitted, the
            module-level ``split_y`` is used, as the original code did.

    Returns:
        The boxes (original plus fillers, without duplicates) sorted by
        (top y, left x).  When nothing is added the input list itself is
        returned.

    Raises:
        NameError: ``split_y`` was not passed and no module-level ``split_y``
            exists.
    """
    if split_y is None:
        # Backward compatibility with callers relying on the global.
        try:
            split_y = globals()["split_y"]
        except KeyError:
            raise NameError("name 'split_y' is not defined") from None

    if not bboxes:
        return bboxes

    bboxes.sort(key=lambda box: (box[0][0], box[0][1]))
    last_bbox = bboxes[0]
    add_bbox_list = []

    for region_top, region_bottom in zip(split_y, split_y[1:]):
        for bbox in bboxes[1:]:
            # Both boxes must end inside the current region.
            if not (region_top <= bbox[1][1] <= region_bottom
                    and region_top <= last_bbox[1][1] <= region_bottom):
                continue
            touching = abs(last_bbox[1][1] - bbox[0][1]) <= 2
            starts_above = last_bbox[1][1] > bbox[0][1]
            if touching or starts_above:
                last_bbox = bbox
            else:
                # NOTE(review): last_bbox is not advanced after a gap is
                # filled (kept from the original) -- confirm this is wanted.
                add_bbox_list.append([(last_bbox[0][0], last_bbox[1][1]),
                                      (last_bbox[1][0], bbox[0][1])])

    print("add_bbox_list", add_bbox_list)

    if add_bbox_list:
        # De-duplicate by coordinates, keeping the first occurrence; this
        # replaces the former str()/eval() round trip.
        unique = {}
        for box in bboxes + add_bbox_list:
            unique.setdefault((tuple(box[0]), tuple(box[1])), box)
        bboxes = list(unique.values())

    bboxes.sort(key=lambda box: (box[0][1], box[0][0]))
    return bboxes
def points_to_line(points_lines, axis):
    """Collapse each group of points into one axis-aligned segment.

    Args:
        points_lines: list of point groups; every point is indexable as
            ``(x, y)``.
        axis: 0 to build horizontal segments (extent along x, position is
            the truncated mean y), 1 to build vertical segments (extent
            along y, position is the truncated mean x).

    Returns:
        list of ``[x1, y1, x2, y2]`` segments, one per group.
    """
    across = 1 - axis
    segments = []
    for group in points_lines:
        # Extent of the group along the segment direction.
        along = sorted(pt[axis] for pt in group)
        low, high = along[0], along[-1]
        # Position of the segment: mean of the other coordinate, truncated.
        position = int(sum(pt[across] for pt in group) / len(group))
        if axis:
            segments.append([position, low, position, high])
        else:
            segments.append([low, position, high, position])
    return segments
def get_bbox_by_contours(image_np):
    """Draw the minimum enclosing circle of the first contour of an image.

    The image is converted to grayscale, binarised at 127, and its contours
    are extracted; the minimum enclosing circle of ``contours[0]`` is drawn
    in white on a copy of the input.

    Args:
        image_np: BGR image (it is passed to ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2GRAY``).

    Returns:
        A copy of ``image_np`` with the circle drawn, or an unmodified copy
        when no contour is found.
    """
    img_gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
    _, img_bin = cv2.threshold(img_gray, 127, 255, cv2.THRESH_BINARY)

    # Contour extraction.  findContours returns (image, contours, hierarchy)
    # in OpenCV 3 but (contours, hierarchy) in OpenCV 2/4; taking the last
    # two items works with every version.
    contours, _hierarchy = cv2.findContours(img_bin,
                                            cv2.RETR_LIST,
                                            cv2.CHAIN_APPROX_SIMPLE)[-2:]

    img_result = image_np.copy()
    if len(contours) == 0:
        return img_result

    # Minimum enclosing circle: ((center_x, center_y), radius).  The former
    # call to minEnclosingTriangle returns (area, triangle) and could not be
    # unpacked into a center and a radius.
    (center_x, center_y), radius = cv2.minEnclosingCircle(contours[0])

    # cv2.circle needs plain integer pixel coordinates.
    cv2.circle(img_result, (int(center_x), int(center_y)), int(radius),
               (255, 255, 255), 2)
    return img_result
def get_points_col(points, split_y, threshold=5):
    """Group crossover points that share (approximately) the same x.

    Points are visited in (x, y) order, once per region delimited by two
    neighbouring ``split_y`` values; points lying on or outside the region
    bounds are ignored.  A point joins the current group while its x is
    within ``threshold`` of the x that opened the group; otherwise the group
    is closed and a new one is started.

    Args:
        points: list of ``(x, y)`` points.  Sorted in place by (x, y).
        split_y: y coordinates of the region split lines.
        threshold: maximum x distance from the group's first point.

    Returns:
        list of groups, each sorted by (y, x).  Empty when ``points`` is
        empty.
    """
    if not points:
        return []

    points.sort(key=lambda pt: (pt[0], pt[1]))

    groups = []
    current = []
    anchor_x = points[0][0]

    def flush():
        # Every emitted group is ordered top-to-bottom.
        if current:
            current.sort(key=lambda pt: (pt[1], pt[0]))
            groups.append(current)

    for region_top, region_bottom in zip(split_y, split_y[1:]):
        for pt in points:
            if pt[1] <= region_top or pt[1] >= region_bottom:
                continue
            if anchor_x - threshold <= pt[0] <= anchor_x + threshold:
                current.append(pt)
            else:
                flush()
                current = [pt]
                anchor_x = pt[0]

    # The trailing group used to be appended unsorted, unlike all others.
    flush()
    return groups
def get_points_row(points, split_y, threshold=5):
    """Group crossover points that share (approximately) the same y.

    Points are visited in (y, x) order, once per region delimited by two
    neighbouring ``split_y`` values; points lying on or outside the region
    bounds are ignored.  A point joins the current group while its y is
    within ``threshold`` of the y that opened the group; otherwise the group
    is closed and a new one is started.

    Args:
        points: list of ``(x, y)`` points.  Sorted in place by (y, x).
        split_y: y coordinates of the region split lines.
        threshold: maximum y distance from the group's first point.

    Returns:
        list of groups, each sorted by (x, y).  Empty when ``points`` is
        empty.
    """
    if not points:
        return []

    points.sort(key=lambda pt: (pt[1], pt[0]))

    groups = []
    current = []
    anchor_y = points[0][1]

    def flush():
        # Every emitted group is ordered left-to-right.
        if current:
            current.sort(key=lambda pt: (pt[0], pt[1]))
            groups.append(current)

    # Regions are pairs of neighbouring split lines, as in get_points_col;
    # the former range(len(split_y)) also paired the last line with the
    # first one (index -1 with index 0).
    for region_top, region_bottom in zip(split_y, split_y[1:]):
        for pt in points:
            if pt[1] <= region_top or pt[1] >= region_bottom:
                continue
            if anchor_y - threshold <= pt[1] <= anchor_y + threshold:
                current.append(pt)
            else:
                flush()
                current = [pt]
                anchor_y = pt[1]

    # The trailing group used to be appended unsorted, unlike all others.
    flush()
    return groups
def get_outline_point(points, split_y):
    """Return the first and last point of every region between split lines.

    Args:
        points: list of ``(x, y)`` points.  Sorted in place by (y, x).
        split_y: y coordinates of the region split lines.

    Returns:
        list of ``[first_point, last_point]`` pairs, one per region that
        contains at least one point strictly between its two split lines;
        "first" and "last" refer to the (y, x) ordering.  Empty when fewer
        than two split lines are given.
    """
    if len(split_y) < 2:
        return []

    points.sort(key=lambda pt: (pt[1], pt[0]))

    corners = []
    for region_top, region_bottom in zip(split_y, split_y[1:]):
        # Filtering the sorted list keeps the (y, x) order.
        inside = [pt for pt in points if region_top < pt[1] < region_bottom]
        if inside:
            corners.append([inside[0], inside[-1]])
    return corners
- # def merge_row(row_lines):
- # for row in row_lines:
- # for row1 in row_lines:
def get_best_predict_size(image_np):
    """Pick the candidate sizes closest to the image height and width.

    Args:
        image_np: object whose ``shape[0]`` is the height and ``shape[1]``
            the width.

    Returns:
        ``(best_height, best_width)``, each taken from the fixed candidate
        list; on a tie the larger candidate (listed first) wins.
    """
    candidates = (1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128)
    height, width = image_np.shape[0], image_np.shape[1]

    def nearest(target):
        # min() returns the first candidate with the smallest distance.
        return min(candidates, key=lambda size: abs(target - size))

    return nearest(height), nearest(width)
def choose_longer_row(lines):
    """Keep the longer of two nearby horizontal lines when one contains the other.

    Two rows are compared when their y (index 1) differs by at most 5.  If
    the x-range (indices 0 and 2) of one fully contains the other's, only
    the containing row is kept and both are excluded from further
    comparison.  Rows that are never paired are kept unchanged.

    Args:
        lines: list of ``[x1, y1, x2, y2]`` rows.

    Returns:
        list of the rows that were kept.
    """
    kept = []
    handled = []
    for index, current in enumerate(lines):
        if current in handled:
            continue
        for other in lines[index + 1:]:
            if other in handled:
                continue
            if not (other[1] - 5 <= current[1] <= other[1] + 5):
                continue
            if current[0] <= other[0] and current[2] >= other[2]:
                longer = current
            elif other[0] <= current[0] and other[2] >= current[2]:
                longer = other
            else:
                continue
            kept.append(longer)
            handled.extend((current, other))
            break
        else:
            # No partner found: the row stands on its own.
            kept.append(current)
            handled.append(current)
    return kept
def choose_longer_col(lines):
    """Keep the longer of two nearby vertical lines when one contains the other.

    Two columns are compared when their x (index 0) differs by at most 5.
    If the y-range (indices 1 and 3) of one fully contains the other's, only
    the containing column is kept and both are excluded from further
    comparison.  Columns that are never paired are kept unchanged.

    Args:
        lines: list of ``[x1, y1, x2, y2]`` columns.

    Returns:
        list of the columns that were kept.
    """
    kept = []
    handled = []
    for index, current in enumerate(lines):
        if current in handled:
            continue
        for other in lines[index + 1:]:
            if other in handled:
                continue
            if not (other[0] - 5 <= current[0] <= other[0] + 5):
                continue
            if current[1] <= other[1] and current[3] >= other[3]:
                longer = current
            elif other[1] <= current[1] and other[3] >= current[3]:
                longer = other
            else:
                continue
            kept.append(longer)
            handled.extend((current, other))
            break
        else:
            # No partner found: the column stands on its own.
            kept.append(current)
            handled.append(current)
    return kept
def delete_contain_bbox(bboxes):
    """Remove boxes that contain another box sharing two of their sides.

    Every pair of boxes is inspected.  When both have the same left and
    right x and one y-range contains the other, the containing (larger) box
    is marked; otherwise, when both have the same top and bottom y and one
    x-range contains the other, the containing box is marked.  Marked boxes
    are then removed from ``bboxes``.

    Args:
        bboxes: list of ``[(x1, y1), (x2, y2)]`` boxes.  Modified in place.

    Returns:
        The same list object, with the marked boxes removed.
    """
    to_delete = []
    for index, first in enumerate(bboxes):
        for second in bboxes[index + 1:]:
            same_x_span = (first[0][0] == second[0][0]
                           and first[1][0] == second[1][0])
            same_y_span = (first[0][1] == second[0][1]
                           and first[1][1] == second[1][1])
            if same_x_span:
                # Same x extent: compare along y.
                if first[0][1] <= second[0][1] <= second[1][1] <= first[1][1]:
                    to_delete.append(first)
                elif second[0][1] <= first[0][1] <= first[1][1] <= second[1][1]:
                    to_delete.append(second)
            elif same_y_span:
                # Same y extent: compare along x.
                if first[0][0] <= second[0][0] <= second[1][0] <= first[1][0]:
                    print("3", first, second)
                    to_delete.append(first)
                elif second[0][0] <= first[0][0] <= first[1][0] <= second[1][0]:
                    print("4", first, second)
                    to_delete.append(second)

    print("delete_contain_bbox len(bboxes)", len(bboxes))
    print("delete_contain_bbox len(delete_bbox)", len(to_delete))

    for box in to_delete:
        if box in bboxes:
            bboxes.remove(box)

    print("delete_contain_bbox len(bboxes)", len(bboxes))
    return bboxes
if __name__ == '__main__':
    # Interactive debug driver: runs the table-line model on one image and
    # shows every post-processing stage in an OpenCV window (each
    # cv2.waitKey(0) blocks until a key is pressed).
    # Alternative test inputs; uncomment one to switch.
    # p = "开标记录表3_page_0.png"
    # p = "train_data/label_1.jpg"
    # p = "test_files/train_463.jpg"
    p = "test_files/8.png"
    # p = "test_files/无边框3.jpg"
    # p = "test_files/part1.png"
    # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \
    # "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png"
    # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
    # "\\00fb43acbc7e11eb836000163e0ae709.png"
    # p = "test_files/table.jpg"
    # p = "data_process/create_data/0.jpg"
    # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png"
    # p = "../format_conversion/1.png"
    img = cv2.imread(p)
    t = time.time()
    # NOTE(review): the weights path is an empty string -- presumably it has
    # to be filled in before this script can run; confirm.
    model.load_weights("")
    best_h, best_w = get_best_predict_size(img)
    print(img.shape)
    print((best_h, best_w))
    # row_boxes, col_boxes = table_line(img[..., ::-1], model, size=(512, 1024), hprob=0.5, vprob=0.5)
    # row_boxes, col_boxes, img = table_line(img[..., ::-1], model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
    row_boxes, col_boxes, img = table_line(img, model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
    print("len(row_boxes)", len(row_boxes))
    # NOTE(review): the label says len(col_boxes) but the list itself is printed.
    print("len(col_boxes)", col_boxes)

    # Create a blank white canvas and draw the detected lines on it
    test_img = np.zeros((img.shape), np.uint8)
    test_img.fill(255)
    for box in row_boxes+col_boxes:
        cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
    cv2.imshow("test_image", test_img)
    cv2.waitKey(0)
    cv2.imwrite("temp.jpg", test_img)

    # Compute the crossover points and the split lines
    crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1]))
    print("len(col_boxes)", len(col_boxes))
    split_lines, split_y = get_split_line(crossover_points, col_boxes, img)
    print("split_y", split_y)
    # for point in crossover_points:
    # cv2.circle(test_img, point, 1, (0, 255, 0), 3)
    # cv2.imshow("point image1", test_img)
    # cv2.waitKey(0)

    # Group the points into rows/columns and drop crossover points that are
    # too close to each other, then regroup with the default threshold
    row_point_list = get_points_row(crossover_points, split_y, 0)
    col_point_list = get_points_col(crossover_points, split_y, 0)
    crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list)
    row_point_list = get_points_row(crossover_points, split_y)
    col_point_list = get_points_col(crossover_points, split_y)
    for point in crossover_points:
        cv2.circle(test_img, point, 1, (0, 0, 255), 3)
    cv2.imshow("point image1", test_img)
    cv2.waitKey(0)
    print("len(row_boxes)", len(row_boxes))
    print("len(col_boxes)", len(col_boxes))

    # Repair the table outline
    new_row_boxes, new_col_boxes, long_row_boxes, long_col_boxes = \
        fix_outline(img, row_boxes, col_boxes, crossover_points, split_y)
    if new_row_boxes or new_col_boxes:
        # Replace the lines by their extended versions, then append the
        # newly added lines
        if long_row_boxes:
            print("long_row_boxes", long_row_boxes)
            row_boxes = long_row_boxes
        if long_col_boxes:
            print("long_col_boxes", long_col_boxes)
            col_boxes = long_col_boxes
        if new_row_boxes:
            row_boxes += new_row_boxes
            print("new_row_boxes", new_row_boxes)
        if new_col_boxes:
            print("new_col_boxes", new_col_boxes)
            col_boxes += new_col_boxes
    # NOTE(review): the statements below are kept at script level, i.e. they
    # run whether or not the outline repair added lines -- confirm that the
    # recomputation is not meant to be conditional on the repair.
    # print("len(row_boxes)", len(row_boxes))
    # print("len(col_boxes)", len(col_boxes))
    # row_boxes += new_row_boxes
    # col_boxes += new_col_boxes
    # row_boxes = choose_longer_row(row_boxes)
    # col_boxes = choose_longer_col(col_boxes)

    # Create a blank white canvas and draw the repaired lines on it
    test_img = np.zeros((img.shape), np.uint8)
    test_img.fill(255)
    for box in row_boxes+col_boxes:
        cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
    cv2.imshow("test_image2", test_img)
    cv2.waitKey(0)

    # Show the lines added by the repair in red
    for row in new_row_boxes:
        cv2.line(test_img, (int(row[0]), int(row[1])),
                 (int(row[2]), int(row[3])), (0, 0, 255), 1)
    for col in new_col_boxes:
        cv2.line(test_img, (int(col[0]), int(col[1])),
                 (int(col[2]), int(col[3])), (0, 0, 255), 1)
    cv2.imshow("fix_outline", test_img)
    cv2.waitKey(0)
    cv2.imwrite("temp.jpg", test_img)

    # Recompute the crossover points and split lines after the repair
    print("crossover_points", len(crossover_points))
    crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1]))
    print("crossover_points new", len(crossover_points))
    split_lines, split_y = get_split_line(crossover_points, col_boxes, img)

    # Group the points into rows/columns and drop close crossover points
    row_point_list = get_points_row(crossover_points, split_y, 0)
    col_point_list = get_points_col(crossover_points, split_y, 0)
    print(len(crossover_points), len(row_point_list), len(col_point_list))
    crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list)
    print(len(crossover_points), len(row_point_list), len(col_point_list))
    row_point_list = get_points_row(crossover_points, split_y)
    col_point_list = get_points_col(crossover_points, split_y)
    for point in crossover_points:
        cv2.circle(test_img, point, 1, (0, 255, 0), 3)
    cv2.imshow("point image2", test_img)
    cv2.waitKey(0)

    # Get the two outline points (first and last crossover point) of each table
    outline_point = get_outline_point(crossover_points, split_y)
    # print(outline_point)
    for outline in outline_point:
        cv2.circle(test_img, outline[0], 1, (255, 0, 0), 5)
        cv2.circle(test_img, outline[1], 1, (255, 0, 0), 5)
    cv2.imshow("outline point", test_img)
    cv2.waitKey(0)

    # Get the bboxes
    # get_bbox(img, crossover_points, split_y)
    # for point in crossover_points:
    # cv2.circle(test_img, point, 1, (0, 255, 0), 3)
    # cv2.imshow("point image3", test_img)
    # cv2.waitKey(0)
    # split_y = []
    # for outline in outline_point:
    # split_y.extend([outline[0][1]-5, outline[1][1]+5])
    print("len(row_boxes)", len(row_boxes))
    print("len(col_boxes)", len(col_boxes))
    bboxes = get_bbox(img, row_point_list, col_point_list, split_y, row_boxes, col_boxes)

    # Display the resulting bboxes
    for box in bboxes:
        # print(box[0], box[1])
        # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]):
        # continue
        cv2.rectangle(test_img, box[0], box[1], (0, 0, 255), 2, 8)
    cv2.imshow('bboxes', test_img)
    cv2.waitKey(0)
    # img = draw_lines(img, row_boxes+col_boxes, color=(255, 0, 0), lineW=2)
    # img = draw_boxes(img, rowboxes+colboxes, color=(0, 0, 255))
    print(time.time()-t, len(row_boxes), len(col_boxes))
    cv2.imwrite('temp.jpg', test_img)
    # cv2.imshow('main', img)
    # cv2.waitKey(0)
|