# torchHandwrite.py (4.0 KB)
# MNIST handwritten-digit classification with PyTorch: a fully connected network
# (Net) and a small convolutional network (CNNNet).
import torch
import torchvision  # vision datasets
from matplotlib import pyplot as plt
from torch import nn  # common network layers
from torch import optim  # optimizers
from torch.nn import functional as F
from torch.utils.data import DataLoader
  7. ## 加载数据
  8. batch_size=512
  9. train_loader = torch.utils.data.DataLoader(
  10. torchvision.datasets.MNIST('mnist_data',train=True,download=True,
  11. transform=torchvision.transforms.Compose([
  12. torchvision.transforms.ToTensor(),
  13. torchvision.transforms.Normalize((0.1307,), (0.3081,)) # 做一个标准化
  14. ])),
  15. batch_size=batch_size,shuffle=True)
  16. test_loader = torch.utils.data.DataLoader(
  17. torchvision.datasets.MNIST('mnist_data/',train=False,download=True,
  18. transform=torchvision.transforms.Compose([
  19. torchvision.transforms.ToTensor(),
  20. torchvision.transforms.Normalize((0.1307,), (0.3081,))
  21. ])),
  22. batch_size=batch_size,shuffle=True)
  23. x,y=next(iter(train_loader))
  24. print(x.shape,y.shape,x.min(),x.max())
  25. relu = nn.ReLU() # 如果使用torch.sigmoid作为激活函数的话正确率只有60%
  26. # 创建网络
  27. class Net(nn.Module):
  28. def __init__(self):
  29. super(Net,self).__init__()
  30. # xw+b 这里的256,64使我们人根据自己的感觉指定的
  31. self.fc1 = nn.Linear(28*28,256)
  32. self.fc2 = nn.Linear(256,64)
  33. self.fc3 = nn.Linear(256,10)
  34. self.activate_softmax = nn.Softmax()
  35. def forward(self,x):
  36. # 因为找不到relu函数,就换成了激活函数
  37. # x:[b,1,28,28]
  38. # h1 = relu(xw1+b1)
  39. x = relu(self.fc1(x))
  40. # h2 = relu(h1w2+b2)
  41. # x = relu(self.fc2(x))
  42. # h3 = h2*w3+b3
  43. x = self.fc3(x)
  44. x = self.activate_softmax(x)
  45. return x
  46. class CNNNet(nn.Module):
  47. def __init__(self):
  48. super(CNNNet,self).__init__()
  49. self.con1 = nn.Conv2d(1,10,5,bias=False)
  50. self.maxpool1 = nn.MaxPool2d(2)
  51. self.con2 = nn.Conv2d(10,20,5,bias=False)
  52. self.fc = nn.Linear(320,10)
  53. def forward(self,x):
  54. batch_size = x.size(0)
  55. x = F.relu(self.con1(x))
  56. x = F.relu(self.con2(self.maxpool1(x)))
  57. x = self.maxpool1(x)
  58. x = x.view(batch_size,-1)
  59. out = self.fc(x)
  60. return out
  61. # 因为找不到自带的one_hot函数,就手写了一个
  62. def one_hot(label, depth=10):
  63. out = torch.zeros(label.size(0), depth)
  64. idx = torch.LongTensor(label).view(-1, 1)
  65. out.scatter_(dim=1, index=idx, value=1)
  66. return out
  67. ## 训练模型
  68. net = Net()
  69. cnnnet = CNNNet()
  70. # 返回[w1,b1,w2,b2,w3,b3] 对象,lr是学习过程
  71. optimizer = optim.SGD(cnnnet.parameters(), lr=0.01, momentum=0.9)
  72. train_loss = []
  73. mes_loss = nn.CrossEntropyLoss()
  74. for epoch in range(5):
  75. for batch_idx, (x, y) in enumerate(train_loader):
  76. # x:[b,1,28,28],y:[512]
  77. # [b,1,28,28] => [b,784]
  78. # x = x.view(x.size(0), 28 * 28)
  79. # =>[b,10]
  80. # out = net(x)
  81. # 清零梯度
  82. optimizer.zero_grad()
  83. out = cnnnet(x)
  84. # [b,10]
  85. y_onehot = one_hot(y)
  86. # loss = mse(out,y_onehot)
  87. loss = mes_loss(out, y)
  88. # 计算梯度
  89. loss.backward()
  90. # w' = w -lr*grad
  91. # 更新梯度,得到新的[w1,b1,w2,b2,w3,b3]
  92. optimizer.step()
  93. train_loss.append(loss.item())
  94. if batch_idx % 10 == 0:
  95. print(epoch, batch_idx, loss.item())
  96. # plot_curve(train_loss)
  97. # 到现在得到了[w1,b1,w2,b2,w3,b3]
  98. ## 准确度测试
  99. total_correct = 0
  100. for x,y in test_loader:
  101. # x = x.view(x.size(0),28*28)
  102. out = cnnnet(x)
  103. # out : [b,10] => pred: [b]
  104. pred = out.argmax(dim = 1)
  105. correct = pred.eq(y).sum().float().item() # .float之后还是tensor类型,要拿到数据需要使用item()
  106. total_correct += correct
  107. total_num = len(test_loader.dataset)
  108. acc = total_correct/total_num
  109. print('准确率acc:',acc)