DL/DNN: Training a custom MultiLayerNet (5×100, ReLU) on the MNIST dataset with four optimizers (SGD/Momentum/AdaGrad/Adam) and comparing their performance


鸵鸟不安 · 2022-09-19 15:10:34


Contents

Output Results

Design Approach

Core Code


Output Results

Mini-batch loss for each optimizer, printed every 100 iterations. In this run AdaGrad reaches the lowest loss throughout, with Adam and Momentum close behind, while plain SGD converges the slowest.

===========iteration:0===========
SGD:2.289282108880558
Momentum:2.2858501933777964
AdaGrad:2.135969407893337
Adam:2.2214629551644443
===========iteration:100===========
SGD:1.549948593098733
Momentum:0.2630614409487161
AdaGrad:0.1280980906681204
Adam:0.21268580798960957
===========iteration:200===========
SGD:0.7668413651485669
Momentum:0.19974263379725932
AdaGrad:0.0688320187945635
Adam:0.12737004371824456
===========iteration:300===========
SGD:0.46630711328743457
Momentum:0.17680542175883507
AdaGrad:0.0580940990397764
Adam:0.12930303058268838
===========iteration:400===========
SGD:0.34526365067568743
Momentum:0.08914404106297127
AdaGrad:0.038093353912494965
Adam:0.06415424083978832
===========iteration:500===========
SGD:0.3588584559967853
Momentum:0.1299949652623088
AdaGrad:0.040978421988412894
Adam:0.058780880102566074
===========iteration:600===========
SGD:0.38273120367667224
Momentum:0.14074766142608885
AdaGrad:0.08641723451090685
Adam:0.11339321858037713
===========iteration:700===========
SGD:0.381094901742027
Momentum:0.1566582072807326
AdaGrad:0.08844650332208387
Adam:0.10485802139218811
===========iteration:800===========
SGD:0.25722603754213674
Momentum:0.07897119725740888
AdaGrad:0.04960128385990466
Adam:0.0835996553542796
===========iteration:900===========
SGD:0.33273148769731326
Momentum:0.19612162874621766
AdaGrad:0.03441995281224886
Adam:0.12248261979926914
===========iteration:1000===========
SGD:0.26394416793465253
Momentum:0.10157776537129978
AdaGrad:0.04761303979039287
Adam:0.046994040537976525
===========iteration:1100===========
SGD:0.23894569840123672
Momentum:0.09093030644899333
AdaGrad:0.07018006635107976
Adam:0.07879622117292093
===========iteration:1200===========
SGD:0.24382935069334477
Momentum:0.08324889705863456
AdaGrad:0.04484659272127939
Adam:0.0719509559060747
===========iteration:1300===========
SGD:0.21307958354960485
Momentum:0.07030166296163001
AdaGrad:0.022552468995955182
Adam:0.049860815437560935
===========iteration:1400===========
SGD:0.3110486414209358
Momentum:0.13117004626934742
AdaGrad:0.07351569965620054
Adam:0.09723751626189574
===========iteration:1500===========
SGD:0.2087589466947655
Momentum:0.09088929766254576
AdaGrad:0.027825434320282873
Adam:0.06352715244823183
===========iteration:1600===========
SGD:0.12783635178644553
Momentum:0.053366262737818
AdaGrad:0.012093087503155344
Adam:0.021385013278486315
===========iteration:1700===========
SGD:0.21476134194349975
Momentum:0.08453161462373757
AdaGrad:0.054955557126319256
Adam:0.035257261368372185
===========iteration:1800===========
SGD:0.3415964018415049
Momentum:0.13866704706781385
AdaGrad:0.04585298765046911
Adam:0.06437669858445684
===========iteration:1900===========
SGD:0.13530674587479818
Momentum:0.03958142222010819
AdaGrad:0.019096102635470277
Adam:0.02185864115092371
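
To visualize the comparison, the per-iteration losses recorded in train_loss (built in the Core Code section below) can be plotted. A minimal sketch, assuming matplotlib is available; the smooth() moving-average helper is illustrative, added only to make the curves readable:

import numpy as np
import matplotlib.pyplot as plt

def smooth(xs, window=10):
    # moving average, purely to de-noise the curves for display
    kernel = np.ones(window) / window
    return np.convolve(xs, kernel, mode='valid')

markers = {'SGD': 'o', 'Momentum': 'x', 'AdaGrad': 's', 'Adam': 'D'}
for key in ('SGD', 'Momentum', 'AdaGrad', 'Adam'):
    ys = smooth(np.array(train_loss[key]))
    plt.plot(np.arange(len(ys)), ys, marker=markers[key], markevery=100, label=key)
plt.xlabel('iterations')
plt.ylabel('loss')
plt.ylim(0, 1)
plt.legend()
plt.show()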

Design Approach

Build four identical five-layer fully connected networks (ReLU activations), draw the same mini-batches from MNIST for all of them, update each network with a different optimizer (SGD, Momentum, AdaGrad, Adam), and record the loss at every iteration so the four methods can be compared on equal footing.

Core Code
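
For reference, these are the update rules the four classes below implement, written in their standard form; Adam appears in the simplified incremental form used by the code, with the bias correction folded into the step size:

% SGD
W \leftarrow W - \eta \, \frac{\partial L}{\partial W}

% Momentum
v \leftarrow \alpha v - \eta \, \frac{\partial L}{\partial W}, \qquad W \leftarrow W + v

% AdaGrad
h \leftarrow h + \frac{\partial L}{\partial W} \odot \frac{\partial L}{\partial W}, \qquad W \leftarrow W - \eta \, \frac{1}{\sqrt{h} + \varepsilon} \, \frac{\partial L}{\partial W}

% Adam (simplified; lr_t carries the bias correction)
m \leftarrow m + (1 - \beta_1)\Bigl(\frac{\partial L}{\partial W} - m\Bigr), \qquad v \leftarrow v + (1 - \beta_2)\Bigl(\bigl(\tfrac{\partial L}{\partial W}\bigr)^{2} - v\Bigr), \qquad W \leftarrow W - \eta \, \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}} \cdot \frac{m}{\sqrt{v} + \varepsilon}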

T1. SGD

# NOTE: the constructors were elided ("……") in the original post; the
# __init__ bodies below follow the standard reference implementation of
# these optimizers, with its usual default hyperparameters.
class SGD:
    """Vanilla stochastic gradient descent: W <- W - lr * dL/dW."""
    def __init__(self, lr=0.01):
        self.lr = lr

    def update(self, params, grads):
        for key in params.keys():
            params[key] -= self.lr * grads[key]

T2. Momentum

import numpy as np

class Momentum:
    """SGD with momentum: v <- momentum*v - lr*grad; W <- W + v."""
    def __init__(self, lr=0.01, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.v = None  # velocity, allocated lazily on the first update

    def update(self, params, grads):
        if self.v is None:
            self.v = {}
            for key, val in params.items():
                self.v[key] = np.zeros_like(val)
        for key in params.keys():
            self.v[key] = self.momentum * self.v[key] - self.lr * grads[key]
            params[key] += self.v[key]

T3. AdaGrad

class AdaGrad:
    """Per-parameter step sizes scaled by the accumulated squared gradient."""
    def __init__(self, lr=0.01):
        self.lr = lr
        self.h = None  # running sum of squared gradients

    def update(self, params, grads):
        if self.h is None:
            self.h = {}
            for key, val in params.items():
                self.h[key] = np.zeros_like(val)
        for key in params.keys():
            self.h[key] += grads[key] * grads[key]
            params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)

T4. Adam

class Adam:
    """Adam in simplified incremental form; lr_t folds in the bias correction."""
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.iter = 0
        self.m = None  # first-moment (mean) estimate
        self.v = None  # second-moment (uncentered variance) estimate

    def update(self, params, grads):
        if self.m is None:
            self.m, self.v = {}, {}
            for key, val in params.items():
                self.m[key] = np.zeros_like(val)
                self.v[key] = np.zeros_like(val)
        self.iter += 1
        lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)
        for key in params.keys():
            self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])
            self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])
            params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)

Training loop: one network per optimizer, all fed identical mini-batches.
(x_train, t_train, train_size, batch_size, max_iterations and the custom
MultiLayerNet class are assumed to be defined elsewhere in the project.)

# This dict is referenced below but was missing from the excerpt.
optimizers = {'SGD': SGD(), 'Momentum': Momentum(),
              'AdaGrad': AdaGrad(), 'Adam': Adam()}

networks = {}
train_loss = {}
for key in optimizers.keys():
    # five weight layers: 784 -> 10 -> 10 -> 10 -> 10 -> 10
    networks[key] = MultiLayerNet(input_size=784,
                                  hidden_size_list=[10, 10, 10, 10],
                                  output_size=10)
    train_loss[key] = []

for i in range(max_iterations):
    # every network sees the same randomly drawn mini-batch
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    for key in optimizers.keys():
        grads = networks[key].gradient(x_batch, t_batch)
        optimizers[key].update(networks[key].params, grads)
        loss = networks[key].loss(x_batch, t_batch)
        train_loss[key].append(loss)
    if i % 100 == 0:
        print("===========" + "iteration:" + str(i) + "===========")
        for key in optimizers.keys():
            loss = networks[key].loss(x_batch, t_batch)
            print(key + ":" + str(loss))
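
As a quick sanity check of the four classes, they can be exercised on a toy objective before committing to a full MNIST run. An illustrative snippet (not part of the original post; the parameter name 'w' and the quadratic objective are made up for the demo): minimizing L(w) = 0.5·||w||², whose gradient is simply w, should drive the printed value toward zero for every optimizer.

import numpy as np

for name, opt in [('SGD', SGD()), ('Momentum', Momentum()),
                  ('AdaGrad', AdaGrad()), ('Adam', Adam())]:
    params = {'w': np.full(5, 1.0)}
    for _ in range(100):
        grads = {'w': params['w'].copy()}  # dL/dw = w for this objective
        opt.update(params, grads)
    print(name, float(0.5 * np.sum(params['w'] ** 2)))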

Related Articles
DL/DNN: Training a custom five-layer DNN (5×100, ReLU) on MNIST with four optimizers (SGD/Momentum/AdaGrad/Adam) and comparing their performance
