DL, CNN: Training and evaluating a custom SimpleConvNet (3 layers, im2col-optimized) on the MNIST dataset for multi-class handwritten-digit recognition


雪碧 · 2022-09-19

Contents

Output

Design approach

Core code

More output


Output

Design approach
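As the code below shows, the network is a single convolution block followed by a two-layer classifier: Conv1 -> Relu1 -> Pool1 (2x2 max pooling) -> Affine1 -> Relu2 -> Affine2 -> SoftmaxWithLoss, trained with backpropagation on MNIST.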

Core code

import pickle
import numpy as np
from collections import OrderedDict
# The helpers referenced below (im2col, col2im) and the Relu, Affine, and
# SoftmaxWithLoss layers are not defined in this excerpt; they are assumed
# to be available from the post's companion utility code.


class Convolution:
    def __init__(self, W, b, stride=1, pad=0):
        ...  # attribute assignments elided in the original (……)

    def forward(self, x):
        FN, C, FH, FW = self.W.shape  # filter count, channels, filter height/width
        N, C, H, W = x.shape
        out_h = 1 + int((H + 2*self.pad - FH) / self.stride)
        out_w = 1 + int((W + 2*self.pad - FW) / self.stride)

        col = im2col(x, FH, FW, self.stride, self.pad)  # unroll receptive fields into rows
        col_W = self.W.reshape(FN, -1).T                # unroll filters into columns
        out = np.dot(col, col_W) + self.b               # convolution as one matrix product
        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)

        self.x = x
        self.col = col
        self.col_W = col_W
        return out

    def backward(self, dout):
        FN, C, FH, FW = self.W.shape
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        self.db = np.sum(dout, axis=0)
        self.dW = np.dot(self.col.T, dout)
        self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

        dcol = np.dot(dout, self.col_W.T)
        # bug fix: the original returned dx without ever computing it; the
        # unrolled gradient must be folded back to image shape with col2im
        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
        return dx
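im2col is the optimization named in the title but is not defined in the excerpt: it unrolls every receptive field of a (N, C, H, W) input into one row of a 2-D matrix, so the whole convolution collapses into the single np.dot call in forward. Below is a minimal sketch of one common NumPy implementation, consistent with how it is called above; treat it as an assumption, since the original relies on external helper code.

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    # input_data: (N, C, H, W); returns (N*out_h*out_w, C*filter_h*filter_w)
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h) // stride + 1
    out_w = (W + 2*pad - filter_w) // stride + 1

    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    # copy each in-filter offset (y, x) for all windows at once via strided slicing
    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    return col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)

col2im, used in backward, is the inverse mapping: it scatters the unrolled columns back into (N, C, H, W) shape, summing values where receptive fields overlap, which is exactly what the gradient of an overlapping sliding window requires.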
class Pooling:
    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad
        self.x = None        # cached input for backward
        self.arg_max = None  # cached position of each window's max
    # forward/backward elided in the original (……)
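    # Sketch (not in the original): the elided forward/backward implement max
    # pooling on the same im2col/col2im machinery as Convolution; the cached
    # arg_max stores each window's winner. Treat this as an assumption.
    def forward(self, x):
        N, C, H, W = x.shape
        out_h = int(1 + (H - self.pool_h) / self.stride)
        out_w = int(1 + (W - self.pool_w) / self.stride)

        # unroll each pooling window into a row, then take the row-wise max
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)
        self.arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1).reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)
        self.x = x
        return out

    def backward(self, dout):
        # max pooling routes the gradient only to each window's max position
        dout = dout.transpose(0, 2, 3, 1)
        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dcol = dmax.reshape(dout.shape[0] * dout.shape[1] * dout.shape[2], -1)
        return col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)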
class SimpleConvNet:
    def __init__(self, input_dim=(1, 28, 28),
                 conv_param={'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                 hidden_size=100, output_size=10, weight_init_std=0.01):
        filter_num = conv_param['filter_num']
        filter_size = conv_param['filter_size']
        filter_pad = conv_param['pad']
        filter_stride = conv_param['stride']
        input_size = input_dim[1]
        conv_output_size = (input_size - filter_size + 2*filter_pad) / filter_stride + 1
        pool_output_size = int(filter_num * (conv_output_size/2) * (conv_output_size/2))

        # weight initialization: small Gaussian weights, zero biases
        self.params = {}
        self.params['W1'] = weight_init_std * \
            np.random.randn(filter_num, input_dim[0], filter_size, filter_size)
        self.params['b1'] = np.zeros(filter_num)
        self.params['W2'] = weight_init_std * \
            np.random.randn(pool_output_size, hidden_size)
        self.params['b2'] = np.zeros(hidden_size)
        self.params['W3'] = weight_init_std * \
            np.random.randn(hidden_size, output_size)
        self.params['b3'] = np.zeros(output_size)

        # layer pipeline: Conv -> ReLU -> 2x2 MaxPool -> Affine -> ReLU -> Affine
        self.layers = OrderedDict()
        self.layers['Conv1'] = Convolution(self.params['W1'], self.params['b1'],
                                           conv_param['stride'], conv_param['pad'])
        self.layers['Relu1'] = Relu()
        self.layers['Pool1'] = Pooling(pool_h=2, pool_w=2, stride=2)
        self.layers['Affine1'] = Affine(self.params['W2'], self.params['b2'])
        self.layers['Relu2'] = Relu()
        self.layers['Affine2'] = Affine(self.params['W3'], self.params['b3'])
        self.last_layer = SoftmaxWithLoss()
    # further methods elided in the original (……)
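    # Sketch (not in the original): the '……' above elides the inference and
    # training methods. Given the layer pipeline built in __init__, they are
    # commonly implemented as follows; treat this as an assumption.
    def predict(self, x):
        # run the input through the layer pipeline (Conv1 ... Affine2)
        for layer in self.layers.values():
            x = layer.forward(x)
        return x

    def loss(self, x, t):
        # cross-entropy loss via the SoftmaxWithLoss output layer
        return self.last_layer.forward(self.predict(x), t)

    def accuracy(self, x, t, batch_size=100):
        # fraction of correct argmax predictions, evaluated in mini-batches
        if t.ndim != 1:
            t = np.argmax(t, axis=1)  # one-hot labels -> class indices
        acc = 0.0
        for i in range(int(x.shape[0] / batch_size)):
            tx = x[i*batch_size:(i+1)*batch_size]
            tt = t[i*batch_size:(i+1)*batch_size]
            y = np.argmax(self.predict(tx), axis=1)
            acc += np.sum(y == tt)
        return acc / x.shape[0]

    def gradient(self, x, t):
        # forward pass to populate the per-layer caches, then backpropagate
        self.loss(x, t)
        dout = self.last_layer.backward(1)
        for layer in reversed(list(self.layers.values())):
            dout = layer.backward(dout)
        # collect the gradients stored by the Convolution/Affine backward passes
        grads = {}
        grads['W1'], grads['b1'] = self.layers['Conv1'].dW, self.layers['Conv1'].db
        grads['W2'], grads['b2'] = self.layers['Affine1'].dW, self.layers['Affine1'].db
        grads['W3'], grads['b3'] = self.layers['Affine2'].dW, self.layers['Affine2'].db
        return grads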
    def save_params(self, file_name="params.pkl"):
        params = {}
        for key, val in self.params.items():
            params[key] = val
        with open(file_name, 'wb') as f:
            pickle.dump(params, f)

    def load_params(self, file_name="params.pkl"):
        with open(file_name, 'rb') as f:
            params = pickle.load(f)
        for key, val in params.items():
            self.params[key] = val
        # re-bind the loaded arrays to the layers that hold references to them
        for i, key in enumerate(['Conv1', 'Affine1', 'Affine2']):
            self.layers[key].W = self.params['W' + str(i+1)]
            self.layers[key].b = self.params['b' + str(i+1)]
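The trace under "More output" below is the kind of log a plain mini-batch SGD loop prints with this network. A minimal sketch of such a loop follows; the load_mnist loader, the use of the accuracy method sketched above, and the hyperparameters (learning rate, batch size, evaluation subset) are assumptions for illustration, not the original post's exact setup.

import numpy as np

# hypothetical loader: x_* shaped (N, 1, 28, 28), t_* as labels
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=False)

network = SimpleConvNet()
learning_rate, batch_size, epochs = 0.1, 100, 3
iters_per_epoch = x_train.shape[0] // batch_size

for epoch in range(1, epochs + 1):
    for it in range(iters_per_epoch):
        # sample a mini-batch, backpropagate, and take one SGD step
        mask = np.random.choice(x_train.shape[0], batch_size)
        x_batch, t_batch = x_train[mask], t_train[mask]
        grads = network.gradient(x_batch, t_batch)
        for key in ('W1', 'b1', 'W2', 'b2', 'W3', 'b3'):
            network.params[key] -= learning_rate * grads[key]
        print('train_loss:' + str(network.loss(x_batch, t_batch)))
    # evaluate on subsets to keep the check cheap
    train_acc = network.accuracy(x_train[:1000], t_train[:1000])
    test_acc = network.accuracy(x_test[:1000], t_test[:1000])
    print('=== epoch:%d, train_acc:%s, test_acc:%s ===' % (epoch, train_acc, test_acc))

network.save_params('params.pkl')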

More output

train_loss:2.29956519109714
=== epoch:1, train_acc:0.216, test_acc:0.218 ===
train_loss:2.2975110344641716
train_loss:2.291654113382576
train_loss:2.2858174689127875
train_loss:2.272262093336837
train_loss:2.267908303517325
train_loss:2.2584119706864336
train_loss:2.2258807222804693
train_loss:2.2111025085252543
train_loss:2.188119055308738
train_loss:2.163215575430596
train_loss:2.1191887076886724
train_loss:2.0542599060672186
train_loss:2.0244523646451915
train_loss:1.9779786923239808
train_loss:1.9248431928319325
train_loss:1.7920653808470397
train_loss:1.726860911000866
train_loss:1.7075144252509131
train_loss:1.6875413868425186
train_loss:1.6347461097804266
train_loss:1.5437112361395253
train_loss:1.4987893515035628
train_loss:1.3856720782969847
train_loss:1.2002110952243676
train_loss:1.2731100379603273
train_loss:1.117132621224333
train_loss:1.0622583460165833
train_loss:1.0960592785565957
train_loss:0.8692067763172185
train_loss:0.8548780420217317
train_loss:0.83872966253374
train_loss:0.7819342397053507
train_loss:0.7589812430284729
train_loss:0.7955332004991336
train_loss:0.8190930469691535
train_loss:0.6297212128196131
train_loss:0.8279837022068413
train_loss:0.6996430264702379
train_loss:0.5256550729087258
train_loss:0.7288553394002595
train_loss:0.7033049908220391
train_loss:0.5679669207218877
train_loss:0.6344174262581003
train_loss:0.7151382401438272
train_loss:0.5814593192354963
train_loss:0.5736217677325146
train_loss:0.5673622947809682
train_loss:0.48303413903204395
train_loss:0.452267909884157
train_loss:0.4009118158839013
=== epoch:2, train_acc:0.818, test_acc:0.806 ===
train_loss:0.5669686001623327
train_loss:0.5358187806595359
train_loss:0.3837535143737321
train_loss:0.544335563142595
train_loss:0.39288485196871803
train_loss:0.49770310644457566
train_loss:0.4610248131112265
train_loss:0.36641463191798196
train_loss:0.4874682221372042
train_loss:0.38796698110644817
train_loss:0.3620230776259665
train_loss:0.4744726274001774
train_loss:0.3086952062454927
train_loss:0.40012397040718645
train_loss:0.3634667070910744
train_loss:0.3204093812396573
train_loss:0.5063082359543781
train_loss:0.5624992123039615
train_loss:0.34281562891324663
train_loss:0.3415065217065326
train_loss:0.4946703009790488
train_loss:0.48942997572068253
train_loss:0.25416776815225534
train_loss:0.3808555005314615
train_loss:0.22793380858862108
train_loss:0.4709915396804245
train_loss:0.25826190862498605
train_loss:0.44862426522901516
train_loss:0.25519522472564815
train_loss:0.5063495442657376
train_loss:0.37233317168099206
train_loss:0.4027673899570495
train_loss:0.4234905061164214
train_loss:0.44590221111177714
train_loss:0.3846538639824134
train_loss:0.3371733857576183
train_loss:0.23612786737321756
train_loss:0.4814543539448962
train_loss:0.38362762929477556
train_loss:0.5105811329813293
train_loss:0.31729857191880056
train_loss:0.43677582454472663
train_loss:0.37362647454980324
train_loss:0.2696715797445873
train_loss:0.26682852302518134
train_loss:0.18763432881504752
train_loss:0.2886557425885745
train_loss:0.23833327847639763
train_loss:0.36315802981646
train_loss:0.21083779781027828
=== epoch:3, train_acc:0.89, test_acc:0.867 ===
train_loss:0.34070333399972674
train_loss:0.3356587138064409
train_loss:0.25919406618960505
train_loss:0.31537349840856743
train_loss:0.2276928810208216
train_loss:0.32171416950979326
train_loss:0.22754919179736025
train_loss:0.37619164258262944
train_loss:0.3221102374023198
train_loss:0.36724681541104537
train_loss:0.3310213819075522
train_loss:0.33583429981768936
train_loss:0.36054827740285833
train_loss:0.3002031789326344
train_loss:0.19480027104864756
train_loss:0.3074748184113467
train_loss:0.31035699050378
train_loss:0.37289594799797554
train_loss:0.38054981033442864
train_loss:0.2150866558286973
train_loss:0.4014488874986493
train_loss:0.2643304660197891
train_loss:0.31806887985854354
train_loss:0.29365139713396693
train_loss:0.33212651106203267
train_loss:0.29544164636048587
train_loss:0.4969991428069569
train_loss:0.3348535409949116
train_loss:0.18914984777413654
train_loss:0.3868380951987871
train_loss:0.26857192970788485
train_loss:0.373151707743815
train_loss:0.3522570704735893
train_loss:0.204823140388568
train_loss:0.3974239710544049
train_loss:0.21753509102652058
train_loss:0.26034229667679715
train_loss:0.26991319118062235
train_loss:0.30959776720795107
train_loss:0.2718109180045845
train_loss:0.2738413103423023
train_loss:0.22209179719364106
train_loss:0.5025051167945939
train_loss:0.23308114849307443
train_loss:0.24989561030033144
train_loss:0.4666621160650158
train_loss:0.3511547384608582
train_loss:0.32856542443039893
train_loss:0.29344954251556093
