123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550 |
- import torch
- import sys
- import os
- sys.path.insert(0,"../")
- from torch.utils.data.dataset import Dataset
- from PIL import Image
- from random import randint,random
- from torch import nn,optim
- from torch.nn import functional as F
- import numpy as np
- import traceback
- from RotateMatch.rotate import *
- import math
- from multiprocessing import Process,Queue,RLock
- device = "cuda" if torch.cuda.is_available() else "cpu"
- from random import random
- import time
- import pickle
def save(object_to_save, path):
    """Pickle ``object_to_save`` into the file at ``path``.

    Args:
        object_to_save: any picklable object.
        path: destination file; it is overwritten if it already exists.

    Returns:
        ``path``. The docstring always promised the saved path, but the
        original implementation returned None.
    """
    with open(path, 'wb') as f:
        pickle.dump(object_to_save, f)
    return path
def load(path):
    """Unpickle and return the object stored in the file at ``path``.

    Args:
        path: file previously written with ``pickle`` (e.g. by ``save``).

    Returns:
        The deserialized object.

    Note: unpickling can execute arbitrary code, so only load files you
    created yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
class RotateDataSet(Dataset):
    """Binary rotation dataset generated on first use from a folder of images.

    Every selected file yields two samples (see ``generateData``):

    * label 1 - the image (optionally pre-turned by 90/180 degrees) is rotated
      by ``degree`` and then by ``randint(358, 363) - degree``;
    * label 0 - the image is rotated by ``randint(5, 355)`` degrees.

    A sample is ``(array, label)``: ``array`` is a channels-first float array
    scaled to [0, 1]; ``label`` is a 0-d int64 numpy array.

    NOTE(review): rotation and cropping are delegated to ``circle_center``,
    ``imageRotate`` and ``tailor`` (RotateMatch.rotate); their exact
    behaviour is not visible here - confirm there.
    """

    # Image folder that used to be hard-coded in __init__; still the default.
    DEFAULT_DATA_PATH = "F:\\Workspace2016\\ContentExtract\\test\\images"

    def __init__(self, train, loadData=True, data_path=None):
        """Select the train or test slice of the image folder.

        Args:
            train: True for the training slice, False for the test slice.
            loadData: if False the folder is not listed and the dataset is
                empty; the same happens (with a printed warning) when the
                folder does not exist.
            data_path: image folder; defaults to ``DEFAULT_DATA_PATH``.
        """
        self.train = train
        # (width, height), the order PIL's Image.resize expects.
        self.input_size = (600, 360)
        # Assigned before it is read: the original called
        # os.path.exists(self.data_path) first and raised AttributeError.
        self.data_path = self.DEFAULT_DATA_PATH if data_path is None else data_path
        # self.data_path = "/home/python/luojiehua/RotateMatch/images"
        self.data_length = 1000
        self.data_begin = 8000

        self.data_queue = Queue()
        self.task_queue = Queue()
        self.train_len = int(self.data_length * 0.9)
        self.test_len = self.data_length - self.train_len

        if loadData and os.path.exists(self.data_path):
            self.list_path = sorted(os.listdir(self.data_path))
        else:
            if loadData:
                print("data path not found:", self.data_path)
            self.list_path = []

        # Sorted files [data_begin, data_begin + train_len) form the training
        # slice; the following test_len files form the test slice.
        start = self.data_begin
        if self.train:
            self.filter_path = self.list_path[start:start + self.train_len]
        else:
            self.filter_path = self.list_path[start + self.train_len:start + self.data_length]

        self.next_flag = False
        self.data = []
        self.imgs = []
        self.count_0 = 0
        self.current_index = -1
        # Number of worker processes used by prepareData; 0 = generate in-process.
        self.process_count = 0

    # ------------------------------------------------------------------ helpers

    def _fit_input_size(self, im):
        """Return ``im`` resized to ``self.input_size`` (a plain copy if it already matches)."""
        if im.size[0] == self.input_size[0] and im.size[1] == self.input_size[1]:
            return im.copy()
        return im.resize(self.input_size)

    @staticmethod
    def _to_chw_array(image, log_shape=False):
        """Convert a PIL image to an RGB float array of shape (3, H, W) scaled to [0, 1]."""
        array = np.array(image.convert("RGB"))
        if log_shape:
            print("array shape", array.shape)
        return array.transpose([2, 0, 1]) / 255

    @staticmethod
    def _random_coarse_turn(im):
        """Rotate ``im`` by 90 degrees with probability 0.3, by 180 with 0.2, else return it unchanged."""
        r = random()
        if r > 0.7:
            return im.rotate(90)
        if r > 0.5:
            return im.rotate(180)
        return im

    @staticmethod
    def _label(value):
        """Wrap a class label as a 0-d int64 array.

        Explicit int64 because dtype 'long' is the platform C long, which is
        32-bit on Windows, while PyTorch classification losses want int64.
        """
        return np.array(value, dtype=np.int64)

    # ---------------------------------------------------------- image transforms

    def readImgs(self, list_data, list_path):
        """Append ``data_path/<name>`` for every file name in ``list_path`` to ``list_data``.

        Only paths are collected; images are opened later in ``generateData``.
        """
        for filename in list_path:
            try:
                print(filename)
                list_data.append(os.path.join(self.data_path, filename))
            except Exception:
                traceback.print_exc()

    def imageTailor(self, im, totailor=True):
        """Resize ``im``, optionally apply a random ``tailor`` cut, return a channels-first array in [0, 1]."""
        r_image = self._fit_input_size(im).convert("RGB")
        if totailor:
            r_image = tailor(r_image, randint(90, 250), randint(5, self.input_size[0] - 5))
        return np.array(r_image).transpose([2, 0, 1]) / 255

    def rotateImage_fast(self, im, center_im, degree):
        """Like ``rotateImage`` but rotate about the caller-supplied ``center_im`` instead of detecting it."""
        r_image = imageRotate(self._fit_input_size(im), center_im, degree)
        return self._to_chw_array(r_image, log_shape=True)

    def image_rotate(self, im, degree, appendColor):
        """Resize ``im`` and rotate it by ``degree`` about the centre from ``circle_center``.

        Returns a PIL image (no array conversion).
        """
        r_image = self._fit_input_size(im)
        # The original author flagged this centre detection as "a bit problematic".
        center_im = circle_center(r_image, appendColor)
        return imageRotate(r_image, center_im, degree)

    def rotateImage(self, im, degree, appendColor):
        """``image_rotate`` followed by conversion to a (3, H, W) float array in [0, 1]."""
        return self._to_chw_array(self.image_rotate(im, degree, appendColor), log_shape=True)

    def imageProcess(self, im, degree, circle=True):
        """Rotate ``im`` about its detected centre, or tailor-cut it when ``degree != 0`` and not ``circle``.

        Returns a (3, H, W) float array in [0, 1], the same layout as
        ``rotateImage``. The original opened an image viewer (``show()``) for
        every sample and reshaped ``getdata()`` output to (width, height, 3),
        which scrambles rows because PIL pixel data is row-major.
        """
        r_image = self._fit_input_size(im)
        if degree == 0 or circle:
            center_im = circle_center(r_image)
            r_image = imageRotate(r_image, center_im, degree)
        else:
            r_image = tailor(r_image, randint(90, 140), randint(20, self.input_size[0] - 20))
        return self._to_chw_array(r_image)

    # ------------------------------------------------------------ sample makers

    def getPosLabel(self, im, appendColor=True):
        """Positive sample: optional 90/180 turn, then rotate by 1-5 or 355-359 degrees; label 1."""
        degree = randint(1, 5) if random() > 0.5 else randint(355, 359)
        self.next_flag = False
        label = 1
        im = self._random_coarse_turn(im)
        print(label)
        return self.rotateImage(im, degree, appendColor), self._label(label)

    def getNegLabel(self, im, appendColor=True):
        """Negative sample: rotate by 5-355 degrees; label 0."""
        degree = randint(5, 355)
        label = 0
        self.next_flag = True
        print(label)
        return self.rotateImage(im, degree, appendColor), self._label(label)

    def getPosLabel_rotate(self, im, appendColor=True):
        """Positive sample built from two rotations whose degrees sum to 358-363; label 1."""
        degree = randint(1, 360)
        self.next_flag = False
        label = 1
        im = self._random_coarse_turn(im)
        print(label)
        im = self.image_rotate(im, degree, appendColor)
        # Rotate "back": the two angles sum to 358..363, i.e. roughly -2..+3
        # degrees net (assuming both rotations use the same centre).
        im = self.rotateImage(im, randint(358, 363) - degree, appendColor)
        return im, self._label(label)

    def getNegLabel_rotate(self, im, appendColor=True):
        """Negative sample; identical to ``getNegLabel``."""
        return self.getNegLabel(im, appendColor)

    def generateData(self, list_data, list_imgs):
        """Append one positive and one negative sample per image path to ``list_data``.

        Images that fail to load or transform are reported and skipped.
        """
        print("generate data")
        for _filepath in list_imgs:
            try:
                print(_filepath)
                # Context manager closes the file handle PIL keeps open.
                with Image.open(_filepath, "r") as im:
                    _pos = self.getPosLabel_rotate(im, True)
                    _neg = self.getNegLabel_rotate(im, True)
                list_data.append(_pos)
                list_data.append(_neg)
            except Exception:
                traceback.print_exc()
        # save(list_data,"rotate_train_%d_%d.pk"%(self.data_begin,self.data_length))

    # ------------------------------------------------------- multiprocess path

    def start_prepare(self):
        """Run ``process_count`` workers over ``task_queue`` and collect their samples into ``self.data``.

        Blocks until every worker has finished. The original looped forever
        (``while 1`` without a break) and never returned.
        """
        import queue

        workers = [
            Process(target=self.process_prepare, args=(self.task_queue, self.data_queue))
            for _ in range(self.process_count)
        ]
        for worker in workers:
            worker.start()
        # Drain while workers run: a process that put items on a
        # multiprocessing.Queue waits at exit until they are flushed into the
        # pipe, which needs a reader, so joining first can deadlock.
        while any(worker.is_alive() for worker in workers):
            try:
                self.data.append(self.data_queue.get(timeout=1))
            except queue.Empty:
                pass
        # Pick up anything flushed after the last liveness check.
        while True:
            try:
                self.data.append(self.data_queue.get(timeout=0.1))
            except queue.Empty:
                break
        for worker in workers:
            worker.join()
        print("process done")

    def process_prepare(self, task_queue, _queue):
        """Worker loop: turn queued image paths into samples until the task queue is empty.

        With equal probability a sample is label 1 (degree 0) or label 0
        (degree in [10, 350]); it is produced by ``imageProcess`` and put on
        ``_queue``. Failing images are reported and skipped.
        """
        import queue

        while True:
            try:
                _filepath = task_queue.get(block=True, timeout=1)
            except queue.Empty:
                # All tasks are queued before the workers start, so an empty
                # queue means the work is done (qsize() is unsupported on macOS).
                return
            try:
                with Image.open(_filepath, "r") as im_src:
                    # ``random`` is the function imported from the random
                    # module; the original random.random() raised
                    # AttributeError and the worker silently exited.
                    if random() > 0.5:
                        degree, label = 0, 1
                        self.next_flag = False
                    else:
                        degree, label = randint(10, 350), 0
                        self.next_flag = True
                    # imageProcess already scales to [0, 1]; the original
                    # divided by 255 a second time.
                    im = self.imageProcess(im_src, degree, circle=True)
                _queue.put((im, self._label(label)))
            except Exception:
                traceback.print_exc()

    # --------------------------------------------------------- Dataset protocol

    def __getitem__(self, index):
        """Return sample ``index`` as ``(array, label)``, generating all samples on first access."""
        self.prepareData()
        # Both preparation paths fill self.data (start_prepare drains the
        # worker queue into it), so the queue must not be read again here.
        return self.data[index]

    def prepareData(self):
        """Generate all samples once; a no-op while ``self.imgs`` is non-empty.

        NOTE(review): the original indentation was lost; generation is assumed
        to happen only right after the paths are read, otherwise every
        __len__/__getitem__ call would append another full set of samples.
        """
        if self.imgs:
            return
        self.readImgs(self.imgs, self.filter_path)
        if self.process_count == 0:
            self.generateData(self.data, self.imgs)
        else:
            for _im in self.imgs:
                self.task_queue.put(_im)
            self.start_prepare()

    def __len__(self):
        """Number of generated samples (two per successfully processed image)."""
        self.prepareData()
        return len(self.data)
class CNNNet(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a linear 2-way classifier.

    ``fc1`` takes 3300 features, which is what the conv/pool stack produces
    for a 3 x 360 x 600 input (15 channels x 11 x 20). ``forward`` returns
    softmax probabilities of shape (batch, 2).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute names become state_dict keys (checkpoints are loaded by
        # name) and registration order fixes parameter order: keep both.
        self.con1 = nn.Conv2d(3, 5, 5, bias=False)
        self.maxpool_3 = nn.MaxPool2d(3)
        self.con2 = nn.Conv2d(5, 10, 5, bias=False)
        self.con3 = nn.Conv2d(10, 15, 5, bias=False)
        self.fc1 = nn.Linear(3300, 2)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Cast ``x`` to float, move it to ``device`` and return class probabilities."""
        features = x.float().to(device)
        for conv in (self.con1, self.con2, self.con3):
            features = self.maxpool_3(F.relu(conv(features)))
        flat = features.view(features.size(0), -1)
        return self.softmax(self.fc1(flat))
class TorchTrainer:
    """Minimal epoch loop: train, evaluate on ``test_loader``, checkpoint on improvement."""

    def __init__(self, epochs, net, optimizer, loss_fn, train_loader, test_loader, model_ckpt):
        self.epochs = epochs
        self.net = net
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Filename prefix for checkpoints written by saveModel.
        self.model_ckpt = model_ckpt
        self.best_acc = None
        self.best_loss = None

    @staticmethod
    def _prepare_target(y, out):
        """Move labels to the device of the network output and cast them to int64.

        The original only cast on CPU and only moved on CUDA, so non-int64
        labels (e.g. int32 on Windows) broke the loss on GPU.
        """
        return y.to(device=out.device, dtype=torch.long)

    def test(self):
        """Evaluate on ``test_loader``.

        Returns:
            ``(accuracy, average per-sample loss)``; ``(0.0, 0.0)`` for an
            empty dataset. Assumes ``loss_fn`` returns the batch mean (the
            PyTorch default reduction); each batch is weighted by its size.
            The original divided the sum of batch means by the sample count.
        """
        total_correct = 0
        total_loss = 0.0
        total_num = len(self.test_loader.dataset)
        was_training = self.net.training
        self.net.eval()
        try:
            # No autograd graph is needed for evaluation.
            with torch.no_grad():
                for x, y in self.test_loader:
                    out = self.net(x)
                    y = self._prepare_target(y, out)
                    total_correct += out.argmax(dim=1).eq(y).sum().item()
                    total_loss += self.loss_fn(out, y).item() * y.size(0)
        finally:
            self.net.train(was_training)
        denom = max(total_num, 1)
        return total_correct / denom, total_loss / denom

    def train(self):
        """Run ``self.epochs`` epochs, evaluating and checkpointing after each.

        Returns:
            List of per-batch training losses (previously built but discarded).
        """
        history = []
        print("start training")
        train_len = len(self.train_loader.dataset)
        denom = max(train_len, 1)
        for epoch in range(self.epochs):
            self.net.train()
            epoch_loss = 0.0
            total_correct = 0
            print("train len:%d" % (train_len))
            for batch_idx, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                out = self.net(x)
                y = self._prepare_target(y, out)
                loss = self.loss_fn(out, y)
                total_correct += out.argmax(dim=1).eq(y).sum().item()
                loss.backward()
                self.optimizer.step()
                batch_loss = loss.item()
                history.append(batch_loss)
                # Weight by batch size so the epoch figure is a per-sample mean.
                epoch_loss += batch_loss * y.size(0)
                if batch_idx % 2 == 0:
                    print("epoch:%d batch_idx:%d loss:%.3f" % (epoch, batch_idx, batch_loss))
            # The original labelled this training loss "val_loss".
            print("epoch %d train_acc%.3f train_loss%.3f" % (epoch, total_correct / denom, epoch_loss / denom))
            val_acc, val_loss = self.test()
            print("epoch %d val_acc%.3f val_loss%.3f" % (epoch, val_acc, val_loss))
            self.saveModel(epoch, val_acc, val_loss)
        return history

    def saveModel(self, epoch, val_acc, val_loss):
        """Save the whole network when accuracy does not drop and loss strictly improves.

        The first call always saves. The file name is
        ``<model_ckpt>_epoch<N>_acc<acc>_loss<loss>.pt``.
        """
        improved = self.best_acc is None or (
            val_acc >= self.best_acc and val_loss < self.best_loss
        )
        if not improved:
            return
        self.best_acc = val_acc
        self.best_loss = val_loss
        torch.save(self.net, "%s_epoch%d_acc%.3f_loss%.3f.pt" % (self.model_ckpt, epoch, self.best_acc, self.best_loss))
def trainTorch(checkpoint="model/rotate_epoch3_acc0.970_loss0.017.pt", epochs=100, lr=0.01,
               batch_size=20, model_ckpt="model/rotate"):
    """Fine-tune a CNNNet, initialised from ``checkpoint``, on RotateDataSet.

    Args:
        checkpoint: file holding either a pickled ``nn.Module`` (as written by
            ``TorchTrainer.saveModel``) or a plain state_dict.
        epochs: number of training epochs.
        lr: SGD learning rate.
        batch_size: batch size for both train and test loaders.
        model_ckpt: filename prefix for checkpoints saved during training.
    """
    map_location = torch.device(device)
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # full nn.Module objects. Only load checkpoints you trust.
        loaded = torch.load(checkpoint, map_location=map_location, weights_only=False)
    except TypeError:
        # Older torch (< 1.13) has no weights_only argument.
        loaded = torch.load(checkpoint, map_location=map_location)
    state_dict = loaded.state_dict() if isinstance(loaded, nn.Module) else loaded

    cnnnet = CNNNet()
    cnnnet.load_state_dict(state_dict, True)
    cnnnet = cnnnet.to(device)
    optimizer = optim.SGD(cnnnet.parameters(), lr=lr)
    train_loader = torch.utils.data.DataLoader(RotateDataSet(train=True), batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(RotateDataSet(train=False), batch_size=batch_size)

    def loss_fn(probs, target):
        # CNNNet.forward already applies softmax; nn.CrossEntropyLoss would
        # apply log-softmax a second time. NLL on log(probs) is the proper
        # cross-entropy; the clamp avoids log(0).
        return F.nll_loss(torch.log(probs.clamp(min=1e-12)), target)

    trainer = TorchTrainer(epochs, cnnnet, optimizer, loss_fn, train_loader, test_loader, model_ckpt)
    trainer.train()
- import keras
- from keras.layers import *
def getKerasModel(input_shape=(600, 360, 3)):
    """Build and compile a small two-class Keras CNN.

    Args:
        input_shape: per-sample input shape in Keras' default channels-last
            layout; defaults to the previously hard-coded (600, 360, 3).

    Returns:
        A compiled ``keras.models.Model`` with a 2-way softmax output, trained
        with integer labels (sparse categorical cross-entropy).
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = keras.models.Input(shape=input_shape)
    x = Conv2D(5, (3, 3), activation="relu")(inputs)
    x = MaxPool2D((2, 2))(x)
    x = Conv2D(10, (3, 3), activation="relu")(x)
    x = MaxPool2D((5, 5))(x)
    x = Flatten()(x)
    out = Dense(2, activation="softmax")(x)
    model = keras.models.Model(inputs=[inputs], outputs=[out])
    model.summary()
    model.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model
def trainKeras():
    """Train the Keras model from ``getKerasModel()`` on RotateDataSet samples.

    RotateDataSet yields channels-first arrays, while the default Keras model
    expects channels-last input of shape (600, 360, 3). The original passed
    the channels-first arrays straight to ``fit``, which does not match the
    model's input shape.
    """
    train_loader = torch.utils.data.DataLoader(RotateDataSet(train=True), batch_size=60)
    test_loader = torch.utils.data.DataLoader(RotateDataSet(train=False), batch_size=512)

    def _to_keras_arrays(loader):
        """Stack a loader's batches into ``(x, y)`` numpy arrays laid out for getKerasModel()."""
        batches_x, batches_y = [], []
        for x, y in loader:
            # Assumes x is (batch, 3, 360, 600) as built by RotateDataSet
            # (TODO confirm imageRotate keeps the 600x360 size). Axes
            # (0, 3, 2, 1) give (batch, 600, 360, 3): channels last with the
            # spatial axes swapped, matching the default Keras input shape.
            batches_x.append(np.transpose(x.numpy(), (0, 3, 2, 1)))
            batches_y.append(y.numpy())
        if not batches_x:
            raise ValueError("dataset produced no samples")
        return np.concatenate(batches_x), np.concatenate(batches_y)

    train_x, train_y = _to_keras_arrays(train_loader)
    val_x, val_y = _to_keras_arrays(test_loader)
    model = getKerasModel()
    model.fit(train_x, train_y, epochs=100, validation_data=(val_x, val_y),
              callbacks=[keras.callbacks.ModelCheckpoint(filepath="keras")])
if __name__ == "__main__":
    # Script entry point: fine-tune the PyTorch model.
    # Call trainKeras() instead to train the Keras variant.
    trainTorch()
|