DL — LiR & DNN & CNN: Using the LiR, DNN, and CNN Algorithms to Implement 10-Class Prediction on the MNIST Handwritten Digit (csv) Dataset


尊敬人生 · 2022-09-19 15:21:32

Contents

Output Results

Design Approach

Core Code

Output Results

Dataset: MNIST handwritten digit images in csv format. For an overview, download links, and usage instructions, see the companion guide "Dataset之MNIST:MNIST(手写数字图片识别+csv文件)数据集简介、下载、使用方法之详细攻略".
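The code below assumes X_train, y_train, and X_test are already loaded. Here is a minimal loading sketch, assuming the Kaggle-style csv layout (train.csv holds a label column followed by 784 pixel columns; test.csv holds the pixel columns only); the file paths are placeholders:

import pandas as pd

# Assumed layout: a 'label' column plus 784 pixel columns in train.csv;
# the 784 pixel columns only in test.csv. Paths below are placeholders.
train = pd.read_csv('./Datasets/MNIST/train.csv')
test = pd.read_csv('./Datasets/MNIST/test.csv')

y_train = train['label']
X_train = train.drop('label', axis=1)
X_test = test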

Design Approach

Core Code

import skflow  # legacy TensorFlow Learn API

# 1) LiR: linear (softmax) classifier on the raw pixel features
classifier = skflow.TensorFlowLinearClassifier(
    n_classes=10, learning_rate=0.01)
classifier.fit(X_train, y_train)
linear_y_predict = classifier.predict(X_test)

# 2) DNN: three fully connected hidden layers (200 -> 50 -> 10 units)
classifier = skflow.TensorFlowDNNClassifier(
    hidden_units=[200, 50, 10], n_classes=10, learning_rate=0.01)
classifier.fit(X_train, y_train)
dnn_y_predict = classifier.predict(X_test)

# 3) CNN: custom convolutional model supplied via model_fn
classifier = skflow.TensorFlowEstimator(
    model_fn=conv_model, n_classes=10, steps=20000,
    learning_rate=0.001)
classifier.fit(X_train, y_train)
cnn_y_predict = classifier.predict(X_test)
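The CNN branch passes a user-defined conv_model to TensorFlowEstimator, but the excerpt does not show that function. Below is a minimal sketch modeled on the conv_model from skflow's own MNIST example: two convolution-plus-max-pooling blocks, a 1024-unit dense layer with dropout, and a softmax output. The skflow.ops.conv2d, skflow.ops.dnn, and skflow.models.logistic_regression calls are the legacy skflow API; treat the exact signatures as assumptions if your version differs.

import tensorflow as tf
import skflow

def max_pool_2x2(tensor_in):
    # 2x2 max pooling with stride 2 halves each spatial dimension
    return tf.nn.max_pool(tensor_in, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

def conv_model(X, y):
    # Reshape flat 784-pixel rows into 28x28 single-channel images
    X = tf.reshape(X, [-1, 28, 28, 1])
    # Block 1: 32 5x5 filters + ReLU, then 2x2 max pooling (28x28 -> 14x14)
    with tf.variable_scope('conv_layer1'):
        h_conv1 = skflow.ops.conv2d(X, n_filters=32, filter_shape=[5, 5],
                                    bias=True, activation=tf.nn.relu)
        h_pool1 = max_pool_2x2(h_conv1)
    # Block 2: 64 5x5 filters + ReLU, then 2x2 max pooling (14x14 -> 7x7)
    with tf.variable_scope('conv_layer2'):
        h_conv2 = skflow.ops.conv2d(h_pool1, n_filters=64, filter_shape=[5, 5],
                                    bias=True, activation=tf.nn.relu)
        h_pool2 = max_pool_2x2(h_conv2)
    # Flatten the 7x7x64 feature maps and add a dense layer with dropout
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = skflow.ops.dnn(h_pool2_flat, [1024],
                           activation=tf.nn.relu, keep_prob=0.5)
    # Softmax output over the 10 digit classes
    return skflow.models.logistic_regression(h_fc1, y)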
For reference, the TensorFlowDNNClassifier implementation from the skflow source:

class TensorFlowDNNClassifier(TensorFlowEstimator, ClassifierMixin):
    """TensorFlow DNN Classifier model.

    Parameters:
        hidden_units: List of hidden units per layer.
        n_classes: Number of classes in the target.
        tf_master: TensorFlow master. Empty string is default for local.
        batch_size: Mini batch size.
        steps: Number of steps to run over data.
        optimizer: Optimizer name (or class), for example "SGD", "Adam",
            "Adagrad".
        learning_rate: If this is a constant float value, no decay function
            is used. Instead, a customized decay function can be passed that
            accepts global_step as a parameter and returns a Tensor,
            e.g. an exponential decay function:
            def exp_decay(global_step):
                return tf.train.exponential_decay(
                    learning_rate=0.1, global_step=global_step,
                    decay_steps=2, decay_rate=0.001)
        class_weight: None or list of n_classes floats. Weight associated
            with classes for loss computation. If not given, all classes are
            supposed to have weight one.
        tf_random_seed: Random seed for TensorFlow initializers. Setting this
            value allows consistency between reruns.
        continue_training: When continue_training is True, an initialized
            model will be continually trained on every call of fit.
        num_cores: Number of cores to be used. (default: 4)
        early_stopping_rounds: Activates early stopping if this is not None.
            Loss needs to decrease at least every <early_stopping_rounds>
            round(s) to continue training. (default: None)
        max_to_keep: The maximum number of recent checkpoint files to keep.
            As new files are created, older files are deleted. If None or 0,
            all checkpoint files are kept. Defaults to 5 (that is, the 5 most
            recent checkpoint files are kept).
        keep_checkpoint_every_n_hours: Number of hours between each
            checkpoint to be saved. The default value of 10,000 hours
            effectively disables the feature.
    """

    def __init__(self, hidden_units, n_classes, tf_master="", batch_size=32,
                 steps=200, optimizer="SGD", learning_rate=0.1,
                 class_weight=None, tf_random_seed=42,
                 continue_training=False, num_cores=4,
                 verbose=1, early_stopping_rounds=None,
                 max_to_keep=5, keep_checkpoint_every_n_hours=10000):
        self.hidden_units = hidden_units
        super(TensorFlowDNNClassifier, self).__init__(
            model_fn=self._model_fn, n_classes=n_classes, tf_master=tf_master,
            batch_size=batch_size, steps=steps, optimizer=optimizer,
            learning_rate=learning_rate, class_weight=class_weight,
            tf_random_seed=tf_random_seed,
            continue_training=continue_training, num_cores=num_cores,
            verbose=verbose, early_stopping_rounds=early_stopping_rounds,
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    def _model_fn(self, X, y):
        return models.get_dnn_model(self.hidden_units,
                                    models.logistic_regression)(X, y)

    @property
    def weights_(self):
        """Returns weights of the DNN weight layers."""
        weights = []
        for layer in range(len(self.hidden_units)):
            weights.append(self.get_tensor_value(
                'dnn/layer%d/Linear/Matrix:0' % layer))
        weights.append(self.get_tensor_value('logistic_regression/weights:0'))
        return weights

    @property
    def bias_(self):
        """Returns bias of the DNN's bias layers."""
        biases = []
        for layer in range(len(self.hidden_units)):
            biases.append(self.get_tensor_value(
                'dnn/layer%d/Linear/Bias:0' % layer))
        biases.append(self.get_tensor_value('logistic_regression/bias:0'))
        return biases
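Given the weights_ and bias_ properties above, a fitted TensorFlowDNNClassifier exposes its per-layer parameters as numpy arrays. A short usage sketch (the classifier variable refers to the DNN model fitted earlier):

# For hidden_units=[200, 50, 10], weights_ returns three hidden-layer
# matrices plus the final softmax (logistic_regression) matrix.
weights = classifier.weights_
biases = classifier.bias_
for i, w in enumerate(weights):
    print('layer %d weight shape: %s' % (i, w.shape))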
