PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN


老鼠清脆
老鼠清脆 2022-09-20 10:01:39 50979
分类专栏: 资讯

PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

目录

训练过程

代码设计


训练过程

代码设计

  1. PyTorch:利用PyTorch实现最经典的LeNet卷积神经网络对手写数字进行识别CNN——Jason niu
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. class LeNet(nn.Module):
  6. def __init__(self):
  7. super(LeNet,self).__init__()
  8. Conv1 和 Conv2:卷积层,每个层输出在卷积核(小尺寸的权重张量)和同样尺寸输入区域之间的点积;
  9. self.conv1 = nn.Conv2d(1,10,kernel_size=5)
  10. self.conv2 = nn.Conv2d(10,20,kernel_size=5)
  11. self.conv2_drop = nn.Dropout2d()
  12. self.fc1 = nn.Linear(320,50)
  13. self.fc2 = nn.Linear(50,10)
  14. def forward(self,x):
  15. x = F.relu(F.max_pool2d(self.conv1(x),2)) 使用 max 运算执行特定区域的下采样(通常 2x2 像素);
  16. x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))
  17. x = x.view(-1, 320)
  18. x = F.relu(self.fc1(x)) 修正线性单元函数,使用逐元素的激活函数 max(0,x);
  19. x = F.dropout(x, training=self.training) Dropout2D随机将输入张量的所有通道设为零。当特征图具备强相关时,dropout2D 提升特征图之间的独立性;
  20. x = self.fc2(x)
  21. return F.log_softmax(x, dim=1) 将 Log(Softmax(x)) 函数应用到 n 维输入张量,以使输出在 0 到 1 之间。
  22. 创建 LeNet 类后,创建对象并移至 GPU
  23. model = LeNet()
  24. criterion = nn.CrossEntropyLoss()
  25. optimizer = optim.SGD(model.parameters(),lr = 0.005, momentum = 0.9) 要训练该模型,我们需要使用带动量的 SGD,学习率为 0.01,momentum 为 0.5。
  26. import os
  27. from torch.autograd import Variable
  28. import torch.nn.functional as F
  29. cuda_gpu = torch.cuda.is_available()
  30. def train(model, epoch, criterion, optimizer, data_loader):
  31. model.train()
  32. for batch_idx, (data, target) in enumerate(data_loader):
  33. if cuda_gpu:
  34. data, target = data.cuda(), target.cuda()
  35. model.cuda()
  36. data, target = Variable(data), Variable(target)
  37. output = model(data)
  38. optimizer.zero_grad()
  39. loss = criterion(output, target)
  40. loss.backward()
  41. optimizer.step()
  42. if (batch_idx+1) % 400 == 0:
  43. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  44. epoch, (batch_idx+1) * len(data), len(data_loader.dataset),
  45. 100. * (batch_idx+1) / len(data_loader), loss.data[0]))
  46. from torchvision import datasets, transforms
  47. batch_num_size = 64
  48. train_loader = torch.utils.data.DataLoader(
  49. datasets.MNIST('data',train=True, download=True, transform=transforms.Compose([
  50. transforms.ToTensor(),
  51. transforms.Normalize((0.1307,), (0.3081,))
  52. ])),
  53. batch_size=batch_num_size, shuffle=True)
  54. test_loader = torch.utils.data.DataLoader(
  55. datasets.MNIST('data',train=False, transform=transforms.Compose([
  56. transforms.ToTensor(),
  57. transforms.Normalize((0.1307,), (0.3081,))
  58. ])),
  59. batch_size=batch_num_size, shuffle=True)
  60. def test(model, epoch, criterion, data_loader):
  61. model.eval()
  62. test_loss = 0
  63. correct = 0
  64. for data, target in data_loader:
  65. if cuda_gpu:
  66. data, target = data.cuda(), target.cuda()
  67. model.cuda()
  68. data, target = Variable(data), Variable(target)
  69. output = model(data)
  70. test_loss += criterion(output, target).data[0]
  71. pred = output.data.max(1)[1] get the index of the max log-probability
  72. correct += pred.eq(target.data).cpu().sum()
  73. test_loss /= len(data_loader) loss function already averages over batch size
  74. acc = correct / len(data_loader.dataset)
  75. print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  76. test_loss, correct, len(data_loader.dataset), 100. * acc))
  77. return (acc, test_loss)
  78. epochs = 5 仅仅需要 5 个 epoch(一个 epoch 意味着你使用整个训练数据集来更新训练模型的权重),就可以训练出一个相当准确的 LeNet 模型。
  79. 这段代码检查可以确定文件中是否已有预训练好的模型。有则加载;无则训练一个并保存至磁盘。
  80. if (os.path.isfile('pretrained/MNIST_net.t7')):
  81. print ('Loading model')
  82. model.load_state_dict(torch.load('pretrained/MNIST_net.t7', map_location=lambda storage, loc: storage))
  83. acc, loss = test(model, 1, criterion, test_loader)
  84. else:
  85. print ('Training model') 打印出该模型的信息。打印函数显示所有层(如 Dropout 被实现为一个单独的层)及其名称和参数。
  86. for epoch in range(1, epochs + 1):
  87. train(model, epoch, criterion, optimizer, train_loader)
  88. acc, loss = test(model, 1, criterion, test_loader)
  89. torch.save(model.state_dict(), 'pretrained/MNIST_net.t7')
  90. print (type(t.cpu().data))以使用 .cpu() 方法将张量移至 CPU(或确保它在那里)。
  91. 或当 GPU 可用时(torch.cuda. 可用),使用 .cuda() 方法将张量移至 GPU。你可以看到张量是否在 GPU 上,其类型为 torch.cuda.FloatTensor。
  92. 如果张量在 CPU 上,则其类型为 torch.FloatTensor。
  93. if torch.cuda.is_available():
  94. print ("Cuda is available")
  95. print (type(t.cuda().data))
  96. else:
  97. print ("Cuda is NOT available")
  98. if torch.cuda.is_available():
  99. try:
  100. print(t.data.numpy())
  101. except RuntimeError as e:
  102. "you can't transform a GPU tensor to a numpy nd array, you have to copy your weight tendor to cpu and then get the numpy array"
  103. print(type(t.cpu().data.numpy()))
  104. print(t.cpu().data.numpy().shape)
  105. print(t.cpu().data.numpy())
  106. data = model.conv1.weight.cpu().data.numpy()
  107. print (data.shape)
  108. print (data[:, 0].shape)
  109. kernel_num = data.shape[0]
  110. fig, axes = plt.subplots(ncols=kernel_num, figsize=(2*kernel_num, 2))
  111. for col in range(kernel_num):
  112. axes[col].imshow(data[col, 0, :, :], cmap=plt.cm.gray)
  113. plt.show()

相关文章
LeNet-5 is our latest convolutional network designed for handwritten and machine-printed character recognition.

文章知识点与官方知识档案匹配,可进一步学习相关知识
Python入门技能树人工智能深度学习123871 人正在系统学习中

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=4006
赞同 0
评论 0 条
老鼠清脆L0
粉丝 0 发表 8 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2941
【软件正版化】软件正版化工作要点  2860
统信UOS试玩黑神话:悟空  2819
信刻光盘安全隔离与信息交换系统  2712
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1246
grub引导程序无法找到指定设备和分区  1213
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  163
点击报名 | 京东2025校招进校行程预告  162
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  160
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!