DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程


陶醉等于白猫
陶醉等于白猫 2022-09-19 15:07:33 50157
分类专栏: 资讯

DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程

目录

输出结果

设计思路

核心代码

更多输出


相关文章
DL之DNN优化技术:采用三种激活函数(sigmoid、relu、tanh)构建5层神经网络,权重初始值(He初始化和Xavier初始化)影响隐藏层的激活值分布的直方图可视化
DL之DNN优化技术:自定义MultiLayerNet【5*100+ReLU】对MNIST数据集训练进而比较三种权重初始值(Xavier初始化、He初始化)性能差异
DL之DNN优化技术:利用MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程
DL之DNN优化技术:DNN中参数初始化【Lecun参数初始化、He参数初始化和Xavier参数初始化】的简介、使用方法详细攻略
DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程全部代码

输出结果

更多输出详见最后

设计思路

核心代码

  1. (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
  2. x_train = x_train[:1000]
  3. t_train = t_train[:1000]
  4. max_epochs = 20
  5. train_size = x_train.shape[0]
  6. batch_size = 100
  7. learning_rate = 0.01
  8. bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
  9. weight_init_std=weight_init_std, use_batchnorm=True)
  10. network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
  11. weight_init_std=weight_init_std)
  12. optimizer = SGD(lr=learning_rate)
  13. train_acc_list = []
  14. bn_train_acc_list = []
  15. iter_per_epoch = max(train_size / batch_size, 1)
  16. for i in range(1000000000):
  17. 定义x_batch、t_batch
  18. batch_mask = np.random.choice(train_size, batch_size)
  19. x_batch = x_train[batch_mask]
  20. t_batch = t_train[batch_mask]
  21. for _network in (bn_network, network):
  22. grads = _network.gradient(x_batch, t_batch)
  23. optimizer.update(_network.params, grads)
  24. if i % iter_per_epoch == 0:
  25. train_acc = network.accuracy(x_train, t_train)
  26. bn_train_acc = bn_network.accuracy(x_train, t_train)
  27. train_acc_list.append(train_acc)
  28. bn_train_acc_list.append(bn_train_acc)
  29. print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))
  30. epoch_cnt += 1
  31. if epoch_cnt >= max_epochs:
  32. break
  33. return train_acc_list, bn_train_acc_list

更多输出

  1. ============== 1/16 ==============
  2. epoch:0 | 0.093 - 0.085
  3. epoch:1 | 0.117 - 0.08
  4. epoch:2 | 0.117 - 0.081
  5. epoch:3 | 0.117 - 0.1
  6. epoch:4 | 0.117 - 0.125
  7. epoch:5 | 0.117 - 0.143
  8. epoch:6 | 0.117 - 0.163
  9. epoch:7 | 0.117 - 0.191
  10. epoch:8 | 0.117 - 0.213
  11. epoch:9 | 0.117 - 0.236
  12. epoch:10 | 0.117 - 0.258
  13. epoch:11 | 0.117 - 0.268
  14. epoch:12 | 0.117 - 0.28
  15. epoch:13 | 0.117 - 0.297
  16. epoch:14 | 0.117 - 0.31
  17. epoch:15 | 0.117 - 0.322
  18. epoch:16 | 0.117 - 0.335
  19. epoch:17 | 0.117 - 0.36
  20. epoch:18 | 0.116 - 0.378
  21. epoch:19 | 0.117 - 0.391
  22. ============== 2/16 ==============
  23. epoch:0 | 0.087 - 0.099
  24. epoch:1 | 0.097 - 0.108
  25. epoch:2 | 0.097 - 0.151
  26. epoch:3 | 0.097 - 0.185
  27. epoch:4 | 0.097 - 0.216
  28. epoch:5 | 0.097 - 0.226
  29. epoch:6 | 0.097 - 0.243
  30. epoch:7 | 0.097 - 0.281
  31. epoch:8 | 0.097 - 0.306
  32. epoch:9 | 0.097 - 0.323
  33. epoch:10 | 0.097 - 0.344
  34. epoch:11 | 0.097 - 0.364
  35. epoch:12 | 0.097 - 0.38
  36. epoch:13 | 0.097 - 0.394
  37. epoch:14 | 0.097 - 0.402
  38. epoch:15 | 0.097 - 0.415
  39. epoch:16 | 0.097 - 0.441
  40. epoch:17 | 0.097 - 0.454
  41. epoch:18 | 0.097 - 0.464
  42. epoch:19 | 0.097 - 0.48
  43. ============== 3/16 ==============
  44. epoch:0 | 0.104 - 0.108
  45. epoch:1 | 0.364 - 0.111
  46. epoch:2 | 0.499 - 0.121
  47. epoch:3 | 0.587 - 0.153
  48. ……
  49. epoch:17 | 0.116 - 0.62
  50. epoch:18 | 0.116 - 0.615
  51. epoch:19 | 0.116 - 0.652
  52. ============== 16/16 ==============
  53. epoch:0 | 0.092 - 0.092
  54. epoch:1 | 0.094 - 0.288
  55. epoch:2 | 0.116 - 0.373
  56. epoch:3 | 0.116 - 0.407
  57. epoch:4 | 0.116 - 0.416
  58. epoch:5 | 0.116 - 0.418
  59. epoch:6 | 0.116 - 0.488
  60. epoch:7 | 0.117 - 0.493
  61. epoch:8 | 0.117 - 0.502
  62. epoch:9 | 0.117 - 0.517
  63. epoch:10 | 0.117 - 0.52
  64. epoch:11 | 0.117 - 0.507
  65. epoch:12 | 0.117 - 0.524
  66. epoch:13 | 0.117 - 0.521
  67. epoch:14 | 0.117 - 0.523
  68. epoch:15 | 0.117 - 0.522
  69. epoch:16 | 0.117 - 0.522
  70. epoch:17 | 0.116 - 0.523
  71. epoch:18 | 0.116 - 0.481
  72. epoch:19 | 0.116 - 0.509

相关文章
CSDN:2019.04.09起

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3179
赞同 0
评论 0 条
陶醉等于白猫L0
粉丝 0 发表 10 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2941
【软件正版化】软件正版化工作要点  2860
统信UOS试玩黑神话:悟空  2819
信刻光盘安全隔离与信息交换系统  2712
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1246
grub引导程序无法找到指定设备和分区  1213
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  163
点击报名 | 京东2025校招进校行程预告  162
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  160
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!