TF之CNN:基于CIFAR-10数据集训练、检测CNN(2+2)模型(TensorBoard可视化)


皮皮
皮皮 2022-09-19 16:39:27 51077
分类专栏: 资讯

TF之CNN:基于CIFAR-10数据集训练、检测CNN(2+2)模型(TensorBoard可视化)

目录

1、基于CIFAR-10数据集训练CNN(2+2)模型代码

2、检测CNN(2+2)模型

3、TensorBoard查看损失的变化曲线


1、基于CIFAR-10数据集训练CNN(2+2)模型代码

  1. from datetime import datetime
  2. import time
  3. import tensorflow as tf
  4. import cifar10
  5. FLAGS = tf.app.flags.FLAGS
  6. tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
  7. """Directory where to write event logs """
  8. """and checkpoint.""") 写入事件日志和检查点的目录
  9. tf.app.flags.DEFINE_integer('max_steps', 1000000,
  10. """Number of batches to run.""") 要运行的批次数
  11. tf.app.flags.DEFINE_boolean('log_device_placement', False,
  12. """Whether to log device placement.""") 是否记录设备放置
  13. tf.app.flags.DEFINE_integer('log_frequency', 10,
  14. """How often to log results to the console.""") 将结果记录到控制台的频率
  15. def train():
  16. """Train CIFAR-10 for a number of steps."""
  17. with tf.Graph().as_default():
  18. global_step = tf.train.get_or_create_global_step() tf.contrib.framework.get_or_create_global_step()
  19. Get images and labels for CIFAR-10.
  20. images, labels = cifar10.distorted_inputs()
  21. Build a Graph that computes the logits predictions from the
  22. inference model.
  23. logits = cifar10.inference(images)
  24. Calculate loss.
  25. loss = cifar10.loss(logits, labels)
  26. Build a Graph that trains the model with one batch of examples and
  27. updates the model parameters.
  28. train_op = cifar10.train(loss, global_step)
  29. class _LoggerHook(tf.train.SessionRunHook):
  30. """Logs loss and runtime."""
  31. def begin(self):
  32. self._step = -1
  33. self._start_time = time.time()
  34. def before_run(self, run_context):
  35. self._step += 1
  36. return tf.train.SessionRunArgs(loss) Asks for loss value.
  37. def after_run(self, run_context, run_values):
  38. if self._step % FLAGS.log_frequency == 0:
  39. current_time = time.time()
  40. duration = current_time - self._start_time
  41. self._start_time = current_time
  42. loss_value = run_values.results
  43. examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
  44. sec_per_batch = float(duration / FLAGS.log_frequency)
  45. format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
  46. 'sec/batch)')
  47. print(format_str % (datetime.now(), self._step, loss_value,
  48. examples_per_sec, sec_per_batch))
  49. with tf.train.MonitoredTrainingSession(
  50. checkpoint_dir=FLAGS.train_dir, FLAGS.train_dir,写入事件日志和检查点的目录
  51. hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), FLAGS.max_steps,要运行的批次数
  52. tf.train.NanTensorHook(loss),
  53. _LoggerHook()],
  54. config=tf.ConfigProto(
  55. log_device_placement=FLAGS.log_device_placement)) as mon_sess: Whether to log device placement
  56. while not mon_sess.should_stop():
  57. mon_sess.run(train_op)
  58. def main(argv=None): pylint: disable=unused-argument
  59. cifar10.maybe_download_and_extract()
  60. if tf.gfile.Exists(FLAGS.train_dir):
  61. tf.gfile.DeleteRecursively(FLAGS.train_dir)
  62. tf.gfile.MakeDirs(FLAGS.train_dir)
  63. train()
  64. if __name__ == '__main__':
  65. FLAGS.train_dir='cifarlO_train/'
  66. FLAGS.max_steps='1000000'
  67. FLAGS.log_device_placement='False'
  68. FLAGS.log_frequency='10'
  69. tf.app.run()

控制台输出结果

Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
2018-09-21 11:15:53.399945: step 0, loss = 4.67 (0.7 examples/sec; 177.888 sec/batch)
2018-09-21 11:17:13.770461: step 10, loss = 4.62 (15.9 examples/sec; 8.037 sec/batch)
2018-09-21 11:19:10.122213: step 20, loss = 4.36 (11.0 examples/sec; 11.635 sec/batch)
2018-09-21 11:21:01.145664: step 30, loss = 4.34 (11.5 examples/sec; 11.102 sec/batch)
2018-09-21 11:22:55.463296: step 40, loss = 4.37 (11.2 examples/sec; 11.432 sec/batch)
2018-09-21 11:24:43.938444: step 50, loss = 4.45 (11.8 examples/sec; 10.848 sec/batch)
2018-09-21 11:26:36.091383: step 60, loss = 4.29 (11.4 examples/sec; 11.215 sec/batch)
2018-09-21 11:28:27.229967: step 70, loss = 4.12 (11.5 examples/sec; 11.114 sec/batch)
2018-09-21 11:30:24.759522: step 80, loss = 4.04 (10.9 examples/sec; 11.753 sec/batch)
2018-09-21 11:32:04.392507: step 90, loss = 4.14 (12.8 examples/sec; 9.963 sec/batch)
2018-09-21 11:33:50.161788: step 100, loss = 4.08 (12.1 examples/sec; 10.577 sec/batch)
2018-09-21 11:35:27.867156: step 110, loss = 4.05 (13.1 examples/sec; 9.771 sec/batch)
2018-09-21 11:36:59.189017: step 120, loss = 3.99 (14.0 examples/sec; 9.132 sec/batch)
2018-09-21 11:38:44.246431: step 130, loss = 3.93 (12.2 examples/sec; 10.506 sec/batch)
2018-09-21 11:40:27.267226: step 140, loss = 4.12 (12.4 examples/sec; 10.302 sec/batch)
2018-09-21 11:42:20.492360: step 150, loss = 3.94 (11.3 examples/sec; 11.323 sec/batch)
2018-09-21 11:44:05.324174: step 160, loss = 3.93 (12.2 examples/sec; 10.483 sec/batch)
2018-09-21 11:45:45.123575: step 170, loss = 3.80 (12.8 examples/sec; 9.980 sec/batch)
2018-09-21 11:47:31.441841: step 180, loss = 3.95 (12.0 examples/sec; 10.632 sec/batch)
2018-09-21 11:49:19.129222: step 190, loss = 3.90 (11.9 examples/sec; 10.769 sec/batch)
2018-09-21 11:50:58.325049: step 200, loss = 4.15 (12.9 examples/sec; 9.920 sec/batch)
2018-09-21 11:52:34.784594: step 210, loss = 3.92 (13.3 examples/sec; 9.646 sec/batch)
2018-09-21 11:54:32.453522: step 220, loss = 3.81 (10.9 examples/sec; 11.767 sec/batch)
2018-09-21 11:56:33.002429: step 230, loss = 3.87 (10.6 examples/sec; 12.055 sec/batch)
2018-09-21 11:58:19.417427: step 240, loss = 3.67 (12.0 examples/sec; 10.641 sec/batch)

2、检测CNN(2+2)模型

       检测模型在CIFAR-10 测试数据集上的准确性,实际上到6万步左右时, 模型就有了85.99%的准确率,到10万步时的准确率为86.38%,到15万步后的准确率基本稳定在86.66%左右。

3、TensorBoard查看损失的变化曲线

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3652
赞同 0
评论 0 条
皮皮L1
粉丝 0 发表 9 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2947
【软件正版化】软件正版化工作要点  2867
统信UOS试玩黑神话:悟空  2828
信刻光盘安全隔离与信息交换系统  2723
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1256
grub引导程序无法找到指定设备和分区  1221
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  164
点击报名 | 京东2025校招进校行程预告  163
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  162
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
如何玩转信创开放社区—从小白进阶到专家 15
信创开放社区邀请他人注册的具体步骤如下 15
方德桌面操作系统 14
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
我有15积分有什么用? 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!