import torch import sys import os sys.path.insert(0,"../") from torch.utils.data.dataset import Dataset from PIL import Image from random import randint,random from torch import nn,optim from torch.nn import functional as F import numpy as np import traceback from RotateMatch.rotate import * import math from multiprocessing import Process,Queue,RLock device = "cuda" if torch.cuda.is_available() else "cpu" from random import random import time import pickle def save(object_to_save, path): ''' 保存对象 @Arugs: object_to_save: 需要保存的对象 @Return: 保存的路径 ''' with open(path, 'wb') as f: pickle.dump(object_to_save, f) def load(path): ''' 读取对象 @Arugs: path: 读取的路径 @Return: 读取的对象 ''' with open(path, 'rb') as f: object1 = pickle.load(f) return object1 class RotateDataSet(Dataset): def __init__(self,train,loadData=True): self.train = train self.input_size = (600,360) if loadData and os.path.exists(self.data_path): self.data_path = "F:\\Workspace2016\\ContentExtract\\test\\images" # self.data_path = "/home/python/luojiehua/RotateMatch/images" self.data_length = 1000 self.data_begin = 8000 self.data_queue = Queue() self.task_queue = Queue() self.train_len = int(self.data_length*0.9) self.test_len = self.data_length-self.train_len self.list_path = os.listdir(self.data_path) self.list_path.sort(key=lambda x:x) if self.train: self.filter_path = self.list_path[self.data_begin:self.data_begin+self.train_len] else: self.filter_path = self.list_path[self.data_begin+self.train_len:self.data_begin+self.data_length] self.next_flag = False self.data = [] self.imgs = [] self.count_0 = 0 self.current_index = -1 self.process_count = 0 if not self.train: self.process_count = 0 def readImgs(self,list_data,list_path): for filename in list_path: try: print(filename) _filepath = os.path.join(self.data_path,filename) # im:Image.Image = Image.open(_filepath,"r") list_data.append(_filepath) except Exception as e: traceback.print_exc() def imageTailor(self,im,totailor=True): # print("size",im.size,self.input_size) if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]: r_image = im.copy() else: r_image = im.resize(self.input_size) r_image = r_image.convert("RGB") if totailor: r_image = tailor(r_image,randint(90,250),randint(5,self.input_size[0]-5)) # r_image.show() # _array = np.array(list(r_image.getdata())) _array = np.array(r_image) # print("array shape",_array.shape) _array = _array.transpose([2,0,1]) return _array/255 def rotateImage_fast(self,im,center_im,degree): if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]: r_image = im.copy() else: r_image = im.resize(self.input_size) r_image = imageRotate(r_image,center_im,degree) r_image = r_image.convert("RGB") # r_image.show() _array = np.array(r_image) print("array shape",_array.shape) _array = _array.transpose([2,0,1]) return _array/255 def image_rotate(self,im,degree,appendColor): if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]: r_image = im.copy() else: r_image = im.resize(self.input_size) center_im = circle_center(r_image,appendColor)#有点问题 r_image = imageRotate(r_image,center_im,degree) return r_image def rotateImage(self,im,degree,appendColor): if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]: r_image = im.copy() else: r_image = im.resize(self.input_size) center_im = circle_center(r_image,appendColor)#有点问题 r_image = imageRotate(r_image,center_im,degree) r_image = r_image.convert("RGB") # r_image.show() _array = np.array(r_image) print("array shape",_array.shape) _array = _array.transpose([2,0,1]) return _array/255 def imageProcess(self,im,degree,circle=True): if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]: r_image = im.copy() else: r_image = im.resize(self.input_size) if degree==0: center_im = circle_center(r_image)#有点问题 r_image = imageRotate(r_image,center_im,degree) pass else: if circle: center_im = circle_center(r_image) r_image = imageRotate(r_image,center_im,degree) else: r_image = tailor(r_image,randint(90,140),randint(20,self.input_size[0]-20)) # background_pin = background.load() r_image.show() _array = np.array(list(r_image.getdata())) _array.resize((*self.input_size,3)) return _array/255 def getPosLabel(self,im,appendColor=True): if random()>0.5: degree = randint(1,5) else: degree = randint(355,359) self.next_flag = False label = 1 _r = random() if _r>0.7: im = im.rotate(90) elif _r>0.5: im = im.rotate(180) print(label) # im = self.imageProcess(im,degree,circle=True) # im = self.imageTailor(im,False) im = self.rotateImage(im,degree,appendColor) return im,np.array(label,dtype='long') def getNegLabel(self,im,appendColor=True): _r = random() degree = randint(5,355) label = 0 self.next_flag = True if _r>0.5: circle = False print(label) # im = self.imageProcess(im,degree,circle=False) # im = self.imageTailor(im,True) im = self.rotateImage(im,degree,appendColor) return im,np.array(label,dtype='long') def getPosLabel_rotate(self,im,appendColor=True): degree = randint(1,360) self.next_flag = False label = 1 _r = random() if _r>0.7: im = im.rotate(90) elif _r>0.5: im = im.rotate(180) print(label) # im = self.imageProcess(im,degree,circle=True) # im = self.imageTailor(im,False) im = self.image_rotate(im,degree,appendColor) im = self.rotateImage(im,randint(358,363)-degree,appendColor) return im,np.array(label,dtype='long') def getNegLabel_rotate(self,im,appendColor=True): _r = random() degree = randint(5,355) label = 0 self.next_flag = True if _r>0.5: circle = False print(label) # im = self.imageProcess(im,degree,circle=False) # im = self.imageTailor(im,True) im = self.rotateImage(im,degree,appendColor) return im,np.array(label,dtype='long') def generateData(self,list_data,list_imgs): print("generate data") for _filepath in list_imgs: try: print(_filepath) im:Image.Image = Image.open(_filepath,"r") # _pos = self.getPosLabel(im,False) # _neg = self.getNegLabel(im,False) # list_data.append((_pos)) # list_data.append((_neg)) _pos = self.getPosLabel_rotate(im,True) _neg = self.getNegLabel_rotate(im,True) list_data.append((_pos)) list_data.append((_neg)) except Exception as e: traceback.print_exc() # save(list_data,"rotate_train_%d_%d.pk"%(self.data_begin,self.data_length)) def start_prepare(self): list_process = [] for _ in range(self.process_count): p = Process(target=self.process_prepare,args=([self.task_queue,self.data_queue])) list_process.append(p) for p in list_process: p.start() while 1: for p in list_process: print(p.is_alive()) time.sleep(5) print("process done") while 1: try: _data = self.data_queue.get() self.data.append(_data) except Exception as e: pass def process_prepare(self,task_queue,_queue): while 1: try: print(task_queue.qsize()) if task_queue.qsize()==0: return _filepath = task_queue.get(block=True,timeout=1) im_src:Image.Image = Image.open(_filepath,"r") if random.random()>0.5: degree = 0 self.next_flag = False label = 1 circle = True else: degree = randint(10,350) label = 0 self.next_flag = True if random.random()>0.5: circle = True else: circle = False im = self.imageProcess(im_src,degree,circle=True) im = im/255 _queue.put((im,np.array(label,dtype='long'))) del im_src except Exception as e: break pass def __getitem__(self, index): # if index==0: # self.count_0+=1 # self.count_0%=5 # if self.count_0==0: # self.data = [] # self.generateData(self.data,self.filter_path) self.prepareData() if self.process_count>0: return self.data_queue.get() return self.data[index] def prepareData(self): if len(self.imgs)==0: self.readImgs(self.imgs,self.filter_path) if self.process_count==0: self.generateData(self.data,self.imgs) else: for _im in self.imgs: self.task_queue.put(_im) self.start_prepare() pass def __len__(self): self.prepareData() return len(self.data) class CNNNet(nn.Module): def __init__(self,*args,**kwargs): super(CNNNet,self).__init__(*args,**kwargs) self.con1 = nn.Conv2d(3,5,5,bias=False) self.maxpool_3 = nn.MaxPool2d(3) self.con2 = nn.Conv2d(5,10,5,bias=False) self.con3 = nn.Conv2d(10,15,5,bias=False) self.fc1 = nn.Linear(3300,2) self.softmax = torch.nn.Softmax(dim=1) def forward(self,x): batch_size = x.size(0) x = x.float() x = x.to(device) x = self.maxpool_3(F.relu(self.con1(x))) x = self.maxpool_3(F.relu(self.con2(x))) x = self.maxpool_3(F.relu(self.con3(x))) x = x.view(batch_size,-1) out = self.fc1(x) out = self.softmax(out) return out class TorchTrainer(): def __init__(self,epochs,net,optimizer,loss_fn,train_loader,test_loader,model_ckpt): self.epochs = epochs self.net = net self.optimizer = optimizer self.loss_fn = loss_fn self.train_loader = train_loader self.test_loader = test_loader self.model_ckpt = model_ckpt self.best_acc = None self.best_loss = None def test(self): ## 准确度测试 total_correct = 0 total_loss = 0 total_num = len(self.test_loader.dataset) for x,y in self.test_loader: # x = x.view(x.size(0),28*28) out = self.net(x) if device=="cuda": y = y.to(device) else: y = y.to(torch.long) # out = torch.nn.Softmax()(out) pred = out.to(device).argmax(dim = 1) # print("=1",y) # print("=2",out) correct = pred.eq(y).sum().float().item() # .float之后还是tensor类型,要拿到数据需要使用item() total_correct += correct y = y.to(torch.long) loss = self.loss_fn(out, y) _loss = loss.item() total_loss += _loss acc = total_correct/total_num avg_loss = total_loss/total_num return acc,avg_loss def train(self): import sys train_loss = [] print("start training") for epoch in range(self.epochs): epoch_loss = 0 train_len = len(self.train_loader.dataset) print("train len:%d"%(train_len)) batch_len = math.floor(len(self.train_loader.dataset)/self.train_loader.batch_size) total_correct = 0 for batch_idx, (x, y) in enumerate(self.train_loader): # x:[b,1,28,28],y:[512] # [b,1,28,28] => [b,784] # x = x.view(x.size(0), 28 * 28) # =>[b,10] # out = net(x) # 清零梯度 self.optimizer.zero_grad() out = self.net(x) # loss = mse(out,y_onehot) if device=="cuda": y = y.to(device) else: y = y.to(torch.long) loss = self.loss_fn(out, y) pred = out.argmax(dim=1) correct = pred.eq(y).sum().float().item() total_correct += correct # 计算梯度 loss.backward() # w' = w -lr*grad # 更新梯度,得到新的[w1,b1,w2,b2,w3,b3] self.optimizer.step() _loss = loss.item() epoch_loss += _loss train_loss.append(_loss) if batch_idx % 2 == 0: print("epoch:%d batch_idx:%d loss:%.3f"%(epoch, batch_idx, loss.item())) print("epcho %d train_acc%.3f val_loss%.3f"%(epoch,total_correct/train_len,epoch_loss/train_len)) val_acc,val_loss = self.test() print("epcho %d val_acc%.3f val_loss%.3f"%(epoch,val_acc,val_loss)) self.saveModel(epoch,val_acc,val_loss) def saveModel(self,epoch,val_acc,val_loss): save_flag = False if self.best_acc is None: self.best_acc = val_acc self.best_loss = val_loss save_flag = True else: if val_acc>=self.best_acc and val_loss