# MNIST digit classification with a simple MLP and a small CNN (PyTorch)
import torch
import torchvision                    # vision datasets
from matplotlib import pyplot as plt
from torch import nn                  # common network modules
from torch import optim               # optimization toolkit
from torch.nn import functional as F
## Load the data
batch_size = 512

# Shared preprocessing pipeline (the original duplicated it for both loaders):
# convert to tensor, then standardize with the usual MNIST mean/std.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=transform),
    batch_size=batch_size, shuffle=True)

# shuffle=False: shuffling the test set has no effect on aggregate accuracy
# and a fixed order makes evaluation reproducible.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=transform),
    batch_size=batch_size, shuffle=False)

# Sanity check: inspect one batch's shapes and value range.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())

relu = nn.ReLU()  # using torch.sigmoid as activation only reached ~60% accuracy
# Build the MLP network
class Net(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 MNIST images.

    Bug fixed: the original ``forward`` had the fc2 call commented out and fed
    the 256-dim activation of fc1 straight into ``fc3 = Linear(256, 10)``,
    leaving fc2 (256->64) as dead weight. fc3 now consumes fc2's 64-dim output.
    """

    def __init__(self):
        super(Net, self).__init__()
        # x @ w + b; the widths 256 and 64 are hand-picked by feel.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        # dim=1 so each sample's 10 class probabilities sum to 1
        # (nn.Softmax() without dim is deprecated and only warns).
        # NOTE(review): if this output is ever fed to nn.CrossEntropyLoss
        # (which applies log-softmax internally), remove this softmax.
        self.activate_softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x: [b, 784] — caller must flatten [b, 1, 28, 28] beforehand
        # h1 = relu(x @ w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 @ w2 + b2)
        x = F.relu(self.fc2(x))
        # logits = h2 @ w3 + b3, then softmax to probabilities
        x = self.fc3(x)
        x = self.activate_softmax(x)
        return x
class CNNNet(nn.Module):
    """Small CNN for MNIST: two 5x5 bias-free convolutions, each followed by
    2x2 max pooling, then one fully connected layer producing 10 logits."""

    def __init__(self):
        super(CNNNet, self).__init__()
        # 1x28x28 --conv5x5--> 10x24x24 --pool--> 10x12x12
        self.con1 = nn.Conv2d(1, 10, 5, bias=False)
        self.maxpool1 = nn.MaxPool2d(2)
        # 10x12x12 --conv5x5--> 20x8x8 --pool--> 20x4x4 = 320 features
        self.con2 = nn.Conv2d(10, 20, 5, bias=False)
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        n = x.size(0)
        h = F.relu(self.con1(x))
        h = self.maxpool1(h)
        h = F.relu(self.con2(h))
        h = self.maxpool1(h)
        # flatten [n, 20, 4, 4] -> [n, 320] for the classifier head
        flat = h.view(n, -1)
        return self.fc(flat)
# Hand-written one-hot encoder (written before discovering F.one_hot).
def one_hot(label, depth=10):
    """Return a [N, depth] float tensor with a 1 at each label's class index.

    Args:
        label: N class indices in [0, depth) — a 1-D integer tensor, or any
               sequence torch.as_tensor accepts (the original legacy
               ``torch.LongTensor(label)`` constructor failed on plain lists
               combined with ``label.size(0)``).
        depth: number of classes (default 10 for MNIST digits).
    """
    idx = torch.as_tensor(label).long().view(-1, 1)
    out = torch.zeros(idx.size(0), depth)
    out.scatter_(dim=1, index=idx, value=1)
    return out
## Train the model
cnnnet = CNNNet()
# SGD over the CNN's parameters [w1, w2, w_fc]; lr is the learning rate.
optimizer = optim.SGD(cnnnet.parameters(), lr=0.01, momentum=0.9)
train_loss = []
# CrossEntropyLoss consumes raw logits and integer class labels directly, so
# no one-hot encoding of y is needed (the original computed an unused
# ``y_onehot`` every batch). Renamed from the misleading ``mes_loss``.
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28], y: [b]
        # clear gradients accumulated from the previous step
        optimizer.zero_grad()
        out = cnnnet(x)  # [b, 10] logits
        loss = criterion(out, y)
        # compute gradients
        loss.backward()
        # w' = w - lr * grad (with momentum); updates all parameters
        optimizer.step()
        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())
# plot_curve(train_loss)
# At this point the trained parameters are in cnnnet.
## Accuracy test
total_correct = 0
# no_grad: evaluation needs no autograd graphs — the original built one per
# batch, wasting memory for no benefit.
with torch.no_grad():
    for x, y in test_loader:
        out = cnnnet(x)  # [b, 10] logits
        # out: [b, 10] => pred: [b], the highest-scoring class per sample
        pred = out.argmax(dim=1)
        # .eq yields a bool tensor; .sum().item() extracts the Python count
        total_correct += pred.eq(y).sum().item()
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('准确率acc:', acc)