TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类


晴空万里
晴空万里 2022-09-19 13:13:24 64826
分类专栏: 资讯

TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

目录

设计思路

实现代码


设计思路

更新……

实现代码

  1. -*- coding:utf-8 -*-
  2. import tensorflow as tf
  3. import numpy as np
  4. from tensorflow.contrib import rnn
  5. from tensorflow.examples.tutorials.mnist import input_data
  6. 根据电脑情况设置 GPU
  7. config = tf.ConfigProto()
  8. config.gpu_options.allow_growth = True
  9. sess = tf.Session(config=config)
  10. 1、定义数据集
  11. mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  12. print(mnist.train.images.shape)
  13. 2、定义模型超参数
  14. lr = 1e-3
  15. batch_size = 128
  16. batch_size = tf.placeholder(tf.int32) 采用占位符的方式,因为在训练和测试的时候要用不同的batch_size。注意类型必须为 tf.int32
  17. input_size = 28 每个时刻的输入特征是28维的,就是每个时刻输入一行,一行有 28 个像素
  18. timestep_size = 28 时序持续长度为28,即每做一次预测,需要先输入28行
  19. hidden_size = 256 每个隐含层的节点数
  20. layer_num = 2 LSTM layer 的层数
  21. class_num = 10 最后输出分类类别数量,如果是回归预测的话应该是 1
  22. _X = tf.placeholder(tf.float32, [None, 784])
  23. y = tf.placeholder(tf.float32, [None, class_num])
  24. keep_prob = tf.placeholder(tf.float32)
  25. 3、LSTM模型的搭建、训练、测试
  26. 3.1、LSTM模型的搭建
  27. X = tf.reshape(_X, [-1, 28, 28]) RNN 的输入shape = (batch_size, timestep_size, input_size),把784个点的字符信息还原成 28 * 28 的图片
  28. lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True) 定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度
  29. lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) 添加 dropout layer, 一般只设置 output_keep_prob
  30. mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True) 调用 MultiRNNCell来实现多层 LSTM
  31. init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32) 用全零来初始化state
  32. 3.2、LSTM模型的运行:构建好的网络运行起来
  33. T1、调用 dynamic_rnn()法
  34. ** 当 time_major==False 时, outputs.shape = [batch_size, timestep_size, hidden_size],所以,可以取 h_state = outputs[:, -1, :] 作为最后输出
  35. ** state.shape = [layer_num, 2, batch_size, hidden_size],或者,可以取 h_state = state[-1][1] 作为最后输出,最后输出维度是 [batch_size, hidden_size]
  36. outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
  37. h_state = outputs[:, -1, :] 或者 h_state = state[-1][1]
  38. T2、自定义LSTM迭代按时间步展开计算:为了更好的理解 LSTM 工作原理把T1的函数自己来实现
  39. (1)、可以采用RNNCell的 __call__()函数,来实现LSTM按时间步迭代。
  40. outputs = list()
  41. state = init_state
  42. with tf.variable_scope('RNN'):
  43. for timestep in range(timestep_size):
  44. if timestep > 0:
  45. tf.get_variable_scope().reuse_variables()
  46. (cell_output, state) = mlstm_cell(X[:, timestep, :], state) 这里的state保存了每一层 LSTM 的状态
  47. outputs.append(cell_output)
  48. h_state = outputs[-1]
  49. 3.3、LSTM模型的训练
  50. 定义 softmax 的连接权重矩阵和偏置:上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor,我们要分类的话,还需要接一个 softmax 层
  51. out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')
  52. out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')
  53. W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
  54. bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)
  55. y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)
  56. 定义损失和评估函数
  57. cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))
  58. train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)
  59. correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))
  60. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
  61. sess.run(tf.global_variables_initializer())
  62. for i in range(2000):
  63. _batch_size = 128
  64. batch = mnist.train.next_batch(_batch_size)
  65. if (i+1)%200 == 0:
  66. train_accuracy = sess.run(accuracy, feed_dict={
  67. _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})
  68. 已经迭代完成的 epoch 数: mnist.train.epochs_completed
  69. print("Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy))
  70. sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})
  71. 计算测试数据的准确率
  72. print("test accuracy %g"% sess.run(accuracy, feed_dict={_X: mnist.test.images, y: mnist.test.labels,
  73. keep_prob: 1.0, batch_size:mnist.test.images.shape[0]}))

参考文章https://www.cnblogs.com/mfryf/p/7903958.html

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=2595
赞同 0
评论 0 条
晴空万里L0
粉丝 0 发表 11 + 关注 私信
上周热门
Kingbase用户权限管理  2008
信刻全自动光盘摆渡系统  1738
信刻国产化智能光盘柜管理系统  1407
银河麒麟添加网络打印机时,出现“client-error-not-possible”错误提示  1002
银河麒麟打印带有图像的文档时出错  906
银河麒麟添加打印机时,出现“server-error-internal-error”  698
麒麟系统也能完整体验微信啦!  645
统信桌面专业版【如何查询系统安装时间】  616
统信操作系统各版本介绍  607
统信桌面专业版【全盘安装UOS系统】介绍  582
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

添加我为好友,拉您入交流群!

请使用微信扫一扫!