TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)


冰淇淋虚心
冰淇淋虚心 2022-09-20 09:58:37 51278
分类专栏: 资讯

TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

目录

输出结果

设计思路

代码设计


输出结果

  1. 0 accuracy 0.125
  2. 20 accuracy 0.6484375
  3. 40 accuracy 0.78125
  4. 60 accuracy 0.9296875
  5. 80 accuracy 0.8671875
  6. 100 accuracy 0.90625
  7. 120 accuracy 0.8671875
  8. 140 accuracy 0.8671875
  9. 160 accuracy 0.8671875
  10. 180 accuracy 0.921875
  11. 200 accuracy 0.890625
  12. 220 accuracy 0.953125
  13. 240 accuracy 0.921875
  14. 260 accuracy 0.9296875
  15. 280 accuracy 0.9140625
  16. 300 accuracy 0.921875
  17. 320 accuracy 0.9609375
  18. 340 accuracy 0.953125
  19. 360 accuracy 0.984375
  20. 380 accuracy 0.921875
  21. 400 accuracy 0.9453125
  22. 420 accuracy 0.921875
  23. 440 accuracy 0.9296875
  24. 460 accuracy 0.96875
  25. 480 accuracy 0.984375
  26. 500 accuracy 0.96875
  27. 520 accuracy 0.953125
  28. 540 accuracy 0.96875
  29. 560 accuracy 0.953125
  30. 580 accuracy 0.9921875
  31. 600 accuracy 0.984375
  32. 620 accuracy 0.953125
  33. 640 accuracy 0.953125
  34. 660 accuracy 0.9921875
  35. 680 accuracy 0.96875
  36. 700 accuracy 0.9765625
  37. 720 accuracy 0.96875
  38. 740 accuracy 0.9921875
  39. 760 accuracy 0.984375
  40. 780 accuracy 0.953125

设计思路

代码设计

  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  4. lr=0.001
  5. training_iters=100000
  6. batch_size=128
  7. n_inputs=28
  8. n_steps=28
  9. n_hidden_units=128
  10. n_classes=10
  11. x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])
  12. y=tf.placeholder(tf.float32, [None,n_classes])
  13. weights ={
  14. 'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),
  15. 'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),
  16. }
  17. biases ={
  18. 'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),
  19. 'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),
  20. }
  21. def RNN(X,weights,biases):
  22. X=tf.reshape(X,[-1,n_inputs])
  23. X_in=tf.matmul(X,weights['in'])+biases['in']
  24. X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])
  25. lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)
  26. __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)
  27. outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)
  28. outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))
  29. results=tf.matmul(outputs[-1],weights['out'])+biases['out']
  30. return results
  31. pred =RNN(x,weights,biases)
  32. cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
  33. train_op=tf.train.AdamOptimizer(lr).minimize(cost)
  34. correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
  35. accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))
  36. <br>
  37. with tf.Session() as sess:
  38. sess.run(init)
  39. step=0
  40. while step*batch_size < training_iters:
  41. batch_xs,batch_ys=mnist.train.next_batch(batch_size)
  42. batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])
  43. sess.run([train_op],feed_dict={
  44. x:batch_xs,
  45. y:batch_ys,})
  46. if step%20==0:
  47. print(sess.run(accuracy,feed_dict={
  48. x:batch_xs,
  49. y:batch_ys,}))
  50. step+=1

相关文章
TF之LSTM:利用LSTM算法对mnist手写数字图片数据集训练、评估(偶尔100%准确度)

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3993
赞同 0
评论 0 条
冰淇淋虚心L0
粉丝 0 发表 8 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2944
【软件正版化】软件正版化工作要点  2863
统信UOS试玩黑神话:悟空  2823
信刻光盘安全隔离与信息交换系统  2717
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1251
grub引导程序无法找到指定设备和分区  1217
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  163
点击报名 | 京东2025校招进校行程预告  162
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  160
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!