# torchRotate.py
# NOTE: the import order below is intentionally left as in the original file:
# ``sys.path.insert`` must run before the project import, and the wildcard
# import may rebind names, so anything imported after it must stay after it.
import torch
import sys
import os
sys.path.insert(0, "../")
from torch.utils.data.dataset import Dataset
from PIL import Image
from random import randint, random
from torch import nn, optim
from torch.nn import functional as F
import numpy as np
import traceback
from RotateMatch.rotate import *
import math
from multiprocessing import Process, Queue, RLock
device = "cuda" if torch.cuda.is_available() else "cpu"
from random import random
import time
import pickle
import queue
  19. def save(object_to_save, path):
  20. '''
  21. 保存对象
  22. @Arugs:
  23. object_to_save: 需要保存的对象
  24. @Return:
  25. 保存的路径
  26. '''
  27. with open(path, 'wb') as f:
  28. pickle.dump(object_to_save, f)
  29. def load(path):
  30. '''
  31. 读取对象
  32. @Arugs:
  33. path: 读取的路径
  34. @Return:
  35. 读取的对象
  36. '''
  37. with open(path, 'rb') as f:
  38. object1 = pickle.load(f)
  39. return object1
  40. class RotateDataSet(Dataset):
  41. def __init__(self,train,loadData=True):
  42. self.train = train
  43. self.input_size = (600,360)
  44. if loadData and os.path.exists(self.data_path):
  45. self.data_path = "F:\\Workspace2016\\ContentExtract\\test\\images"
  46. # self.data_path = "/home/python/luojiehua/RotateMatch/images"
  47. self.data_length = 1000
  48. self.data_begin = 8000
  49. self.data_queue = Queue()
  50. self.task_queue = Queue()
  51. self.train_len = int(self.data_length*0.9)
  52. self.test_len = self.data_length-self.train_len
  53. self.list_path = os.listdir(self.data_path)
  54. self.list_path.sort(key=lambda x:x)
  55. if self.train:
  56. self.filter_path = self.list_path[self.data_begin:self.data_begin+self.train_len]
  57. else:
  58. self.filter_path = self.list_path[self.data_begin+self.train_len:self.data_begin+self.data_length]
  59. self.next_flag = False
  60. self.data = []
  61. self.imgs = []
  62. self.count_0 = 0
  63. self.current_index = -1
  64. self.process_count = 0
  65. if not self.train:
  66. self.process_count = 0
  67. def readImgs(self,list_data,list_path):
  68. for filename in list_path:
  69. try:
  70. print(filename)
  71. _filepath = os.path.join(self.data_path,filename)
  72. # im:Image.Image = Image.open(_filepath,"r")
  73. list_data.append(_filepath)
  74. except Exception as e:
  75. traceback.print_exc()
  76. def imageTailor(self,im,totailor=True):
  77. # print("size",im.size,self.input_size)
  78. if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]:
  79. r_image = im.copy()
  80. else:
  81. r_image = im.resize(self.input_size)
  82. r_image = r_image.convert("RGB")
  83. if totailor:
  84. r_image = tailor(r_image,randint(90,250),randint(5,self.input_size[0]-5))
  85. # r_image.show()
  86. # _array = np.array(list(r_image.getdata()))
  87. _array = np.array(r_image)
  88. # print("array shape",_array.shape)
  89. _array = _array.transpose([2,0,1])
  90. return _array/255
  91. def rotateImage_fast(self,im,center_im,degree):
  92. if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]:
  93. r_image = im.copy()
  94. else:
  95. r_image = im.resize(self.input_size)
  96. r_image = imageRotate(r_image,center_im,degree)
  97. r_image = r_image.convert("RGB")
  98. # r_image.show()
  99. _array = np.array(r_image)
  100. print("array shape",_array.shape)
  101. _array = _array.transpose([2,0,1])
  102. return _array/255
  103. def image_rotate(self,im,degree,appendColor):
  104. if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]:
  105. r_image = im.copy()
  106. else:
  107. r_image = im.resize(self.input_size)
  108. center_im = circle_center(r_image,appendColor)#有点问题
  109. r_image = imageRotate(r_image,center_im,degree)
  110. return r_image
  111. def rotateImage(self,im,degree,appendColor):
  112. if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]:
  113. r_image = im.copy()
  114. else:
  115. r_image = im.resize(self.input_size)
  116. center_im = circle_center(r_image,appendColor)#有点问题
  117. r_image = imageRotate(r_image,center_im,degree)
  118. r_image = r_image.convert("RGB")
  119. # r_image.show()
  120. _array = np.array(r_image)
  121. print("array shape",_array.shape)
  122. _array = _array.transpose([2,0,1])
  123. return _array/255
  124. def imageProcess(self,im,degree,circle=True):
  125. if im.size[0]==self.input_size[0] and im.size[1]==self.input_size[1]:
  126. r_image = im.copy()
  127. else:
  128. r_image = im.resize(self.input_size)
  129. if degree==0:
  130. center_im = circle_center(r_image)#有点问题
  131. r_image = imageRotate(r_image,center_im,degree)
  132. pass
  133. else:
  134. if circle:
  135. center_im = circle_center(r_image)
  136. r_image = imageRotate(r_image,center_im,degree)
  137. else:
  138. r_image = tailor(r_image,randint(90,140),randint(20,self.input_size[0]-20))
  139. # background_pin = background.load()
  140. r_image.show()
  141. _array = np.array(list(r_image.getdata()))
  142. _array.resize((*self.input_size,3))
  143. return _array/255
  144. def getPosLabel(self,im,appendColor=True):
  145. if random()>0.5:
  146. degree = randint(1,5)
  147. else:
  148. degree = randint(355,359)
  149. self.next_flag = False
  150. label = 1
  151. _r = random()
  152. if _r>0.7:
  153. im = im.rotate(90)
  154. elif _r>0.5:
  155. im = im.rotate(180)
  156. print(label)
  157. # im = self.imageProcess(im,degree,circle=True)
  158. # im = self.imageTailor(im,False)
  159. im = self.rotateImage(im,degree,appendColor)
  160. return im,np.array(label,dtype='long')
  161. def getNegLabel(self,im,appendColor=True):
  162. _r = random()
  163. degree = randint(5,355)
  164. label = 0
  165. self.next_flag = True
  166. if _r>0.5:
  167. circle = False
  168. print(label)
  169. # im = self.imageProcess(im,degree,circle=False)
  170. # im = self.imageTailor(im,True)
  171. im = self.rotateImage(im,degree,appendColor)
  172. return im,np.array(label,dtype='long')
  173. def getPosLabel_rotate(self,im,appendColor=True):
  174. degree = randint(1,360)
  175. self.next_flag = False
  176. label = 1
  177. _r = random()
  178. if _r>0.7:
  179. im = im.rotate(90)
  180. elif _r>0.5:
  181. im = im.rotate(180)
  182. print(label)
  183. # im = self.imageProcess(im,degree,circle=True)
  184. # im = self.imageTailor(im,False)
  185. im = self.image_rotate(im,degree,appendColor)
  186. im = self.rotateImage(im,randint(358,363)-degree,appendColor)
  187. return im,np.array(label,dtype='long')
  188. def getNegLabel_rotate(self,im,appendColor=True):
  189. _r = random()
  190. degree = randint(5,355)
  191. label = 0
  192. self.next_flag = True
  193. if _r>0.5:
  194. circle = False
  195. print(label)
  196. # im = self.imageProcess(im,degree,circle=False)
  197. # im = self.imageTailor(im,True)
  198. im = self.rotateImage(im,degree,appendColor)
  199. return im,np.array(label,dtype='long')
  200. def generateData(self,list_data,list_imgs):
  201. print("generate data")
  202. for _filepath in list_imgs:
  203. try:
  204. print(_filepath)
  205. im:Image.Image = Image.open(_filepath,"r")
  206. # _pos = self.getPosLabel(im,False)
  207. # _neg = self.getNegLabel(im,False)
  208. # list_data.append((_pos))
  209. # list_data.append((_neg))
  210. _pos = self.getPosLabel_rotate(im,True)
  211. _neg = self.getNegLabel_rotate(im,True)
  212. list_data.append((_pos))
  213. list_data.append((_neg))
  214. except Exception as e:
  215. traceback.print_exc()
  216. # save(list_data,"rotate_train_%d_%d.pk"%(self.data_begin,self.data_length))
  217. def start_prepare(self):
  218. list_process = []
  219. for _ in range(self.process_count):
  220. p = Process(target=self.process_prepare,args=([self.task_queue,self.data_queue]))
  221. list_process.append(p)
  222. for p in list_process:
  223. p.start()
  224. while 1:
  225. for p in list_process:
  226. print(p.is_alive())
  227. time.sleep(5)
  228. print("process done")
  229. while 1:
  230. try:
  231. _data = self.data_queue.get()
  232. self.data.append(_data)
  233. except Exception as e:
  234. pass
  235. def process_prepare(self,task_queue,_queue):
  236. while 1:
  237. try:
  238. print(task_queue.qsize())
  239. if task_queue.qsize()==0:
  240. return
  241. _filepath = task_queue.get(block=True,timeout=1)
  242. im_src:Image.Image = Image.open(_filepath,"r")
  243. if random.random()>0.5:
  244. degree = 0
  245. self.next_flag = False
  246. label = 1
  247. circle = True
  248. else:
  249. degree = randint(10,350)
  250. label = 0
  251. self.next_flag = True
  252. if random.random()>0.5:
  253. circle = True
  254. else:
  255. circle = False
  256. im = self.imageProcess(im_src,degree,circle=True)
  257. im = im/255
  258. _queue.put((im,np.array(label,dtype='long')))
  259. del im_src
  260. except Exception as e:
  261. break
  262. pass
  263. def __getitem__(self, index):
  264. # if index==0:
  265. # self.count_0+=1
  266. # self.count_0%=5
  267. # if self.count_0==0:
  268. # self.data = []
  269. # self.generateData(self.data,self.filter_path)
  270. self.prepareData()
  271. if self.process_count>0:
  272. return self.data_queue.get()
  273. return self.data[index]
  274. def prepareData(self):
  275. if len(self.imgs)==0:
  276. self.readImgs(self.imgs,self.filter_path)
  277. if self.process_count==0:
  278. self.generateData(self.data,self.imgs)
  279. else:
  280. for _im in self.imgs:
  281. self.task_queue.put(_im)
  282. self.start_prepare()
  283. pass
  284. def __len__(self):
  285. self.prepareData()
  286. return len(self.data)
  287. class CNNNet(nn.Module):
  288. def __init__(self,*args,**kwargs):
  289. super(CNNNet,self).__init__(*args,**kwargs)
  290. self.con1 = nn.Conv2d(3,5,5,bias=False)
  291. self.maxpool_3 = nn.MaxPool2d(3)
  292. self.con2 = nn.Conv2d(5,10,5,bias=False)
  293. self.con3 = nn.Conv2d(10,15,5,bias=False)
  294. self.fc1 = nn.Linear(3300,2)
  295. self.softmax = torch.nn.Softmax(dim=1)
  296. def forward(self,x):
  297. batch_size = x.size(0)
  298. x = x.float()
  299. x = x.to(device)
  300. x = self.maxpool_3(F.relu(self.con1(x)))
  301. x = self.maxpool_3(F.relu(self.con2(x)))
  302. x = self.maxpool_3(F.relu(self.con3(x)))
  303. x = x.view(batch_size,-1)
  304. out = self.fc1(x)
  305. out = self.softmax(out)
  306. return out
  307. class TorchTrainer():
  308. def __init__(self,epochs,net,optimizer,loss_fn,train_loader,test_loader,model_ckpt):
  309. self.epochs = epochs
  310. self.net = net
  311. self.optimizer = optimizer
  312. self.loss_fn = loss_fn
  313. self.train_loader = train_loader
  314. self.test_loader = test_loader
  315. self.model_ckpt = model_ckpt
  316. self.best_acc = None
  317. self.best_loss = None
  318. def test(self):
  319. ## 准确度测试
  320. total_correct = 0
  321. total_loss = 0
  322. total_num = len(self.test_loader.dataset)
  323. for x,y in self.test_loader:
  324. # x = x.view(x.size(0),28*28)
  325. out = self.net(x)
  326. if device=="cuda":
  327. y = y.to(device)
  328. else:
  329. y = y.to(torch.long)
  330. # out = torch.nn.Softmax()(out)
  331. pred = out.to(device).argmax(dim = 1)
  332. # print("=1",y)
  333. # print("=2",out)
  334. correct = pred.eq(y).sum().float().item() # .float之后还是tensor类型,要拿到数据需要使用item()
  335. total_correct += correct
  336. y = y.to(torch.long)
  337. loss = self.loss_fn(out, y)
  338. _loss = loss.item()
  339. total_loss += _loss
  340. acc = total_correct/total_num
  341. avg_loss = total_loss/total_num
  342. return acc,avg_loss
  343. def train(self):
  344. import sys
  345. train_loss = []
  346. print("start training")
  347. for epoch in range(self.epochs):
  348. epoch_loss = 0
  349. train_len = len(self.train_loader.dataset)
  350. print("train len:%d"%(train_len))
  351. batch_len = math.floor(len(self.train_loader.dataset)/self.train_loader.batch_size)
  352. total_correct = 0
  353. for batch_idx, (x, y) in enumerate(self.train_loader):
  354. # x:[b,1,28,28],y:[512]
  355. # [b,1,28,28] => [b,784]
  356. # x = x.view(x.size(0), 28 * 28)
  357. # =>[b,10]
  358. # out = net(x)
  359. # 清零梯度
  360. self.optimizer.zero_grad()
  361. out = self.net(x)
  362. # loss = mse(out,y_onehot)
  363. if device=="cuda":
  364. y = y.to(device)
  365. else:
  366. y = y.to(torch.long)
  367. loss = self.loss_fn(out, y)
  368. pred = out.argmax(dim=1)
  369. correct = pred.eq(y).sum().float().item()
  370. total_correct += correct
  371. # 计算梯度
  372. loss.backward()
  373. # w' = w -lr*grad
  374. # 更新梯度,得到新的[w1,b1,w2,b2,w3,b3]
  375. self.optimizer.step()
  376. _loss = loss.item()
  377. epoch_loss += _loss
  378. train_loss.append(_loss)
  379. if batch_idx % 2 == 0:
  380. print("epoch:%d batch_idx:%d loss:%.3f"%(epoch, batch_idx, loss.item()))
  381. print("epcho %d train_acc%.3f val_loss%.3f"%(epoch,total_correct/train_len,epoch_loss/train_len))
  382. val_acc,val_loss = self.test()
  383. print("epcho %d val_acc%.3f val_loss%.3f"%(epoch,val_acc,val_loss))
  384. self.saveModel(epoch,val_acc,val_loss)
  385. def saveModel(self,epoch,val_acc,val_loss):
  386. save_flag = False
  387. if self.best_acc is None:
  388. self.best_acc = val_acc
  389. self.best_loss = val_loss
  390. save_flag = True
  391. else:
  392. if val_acc>=self.best_acc and val_loss<self.best_loss:
  393. self.best_acc = val_acc
  394. self.best_loss = val_loss
  395. save_flag = True
  396. if save_flag:
  397. torch.save(self.net,"%s_epoch%d_acc%.3f_loss%.3f.pt"%(self.model_ckpt,epoch,self.best_acc,self.best_loss))
  398. def trainTorch():
  399. cnnnet = torch.load("model/rotate_epoch3_acc0.970_loss0.017.pt",map_location=torch.device(device))
  400. _d = cnnnet.state_dict()
  401. cnnnet = CNNNet()
  402. cnnnet.load_state_dict(_d,True)
  403. cnnnet = cnnnet.to(device)
  404. optimizer = optim.SGD(cnnnet.parameters(),lr=0.01)
  405. train_loader = torch.utils.data.DataLoader(RotateDataSet(train=True),batch_size=20)
  406. test_loader = torch.utils.data.DataLoader(RotateDataSet(train=False),batch_size=20)
  407. loss_fn = nn.CrossEntropyLoss()
  408. trainer = TorchTrainer(100,cnnnet,optimizer,loss_fn,train_loader,test_loader,"model/rotate")
  409. trainer.train()
  410. import keras
  411. from keras.layers import *
  412. def getKerasModel():
  413. input = keras.models.Input(shape=(600,360,3))
  414. conv1 = Conv2D(5,(3,3),activation="relu")(input)
  415. maxpool1 = MaxPool2D((2,2))(conv1)
  416. conv2 = Conv2D(10,(3,3),activation="relu")(maxpool1)
  417. maxpool2 = MaxPool2D((5,5))(conv2)
  418. fla = Flatten()(maxpool2)
  419. out = Dense(2,activation="softmax")(fla)
  420. model = keras.models.Model(inputs=[input],outputs=[out])
  421. model.summary()
  422. model.compile("adam",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
  423. return model
  424. def trainKeras():
  425. train_loader = torch.utils.data.DataLoader(RotateDataSet(train=True),batch_size=60)
  426. test_loader = torch.utils.data.DataLoader(RotateDataSet(train=False),batch_size=512)
  427. list_x = []
  428. list_y = []
  429. for batch_idx, (x, y) in enumerate(train_loader):
  430. print(type(x))
  431. list_x.extend(x.numpy())
  432. list_y.extend(y.numpy())
  433. val_x = []
  434. val_y = []
  435. for batch_index,(x,y) in enumerate(test_loader):
  436. val_x.extend(x.numpy())
  437. val_y.extend(y.numpy())
  438. model = getKerasModel()
  439. model.fit(np.array(list_x),np.array(list_y),epochs=100,validation_data=(np.array(val_x),np.array(val_y)),callbacks=[keras.callbacks.ModelCheckpoint(filepath="keras")])
  440. if __name__ == '__main__':
  441. trainTorch()
  442. # trainKeras()