DL之LiR&DNN&CNN：利用LiR、DNN、CNN算法对MNIST手写数字图片（csv格式）识别数据集实现十分类预测
目录
数据集:Dataset之MNIST:MNIST(手写数字图片识别+csv文件)数据集简介、下载、使用方法之详细攻略
# Train three scikit-flow models on the same MNIST split and keep each
# model's test-set predictions for later comparison.
# NOTE(review): skflow, X_train, y_train, X_test and conv_model are not
# defined in this excerpt -- they are assumed to come from earlier code.

# 1) Linear classifier (the "LiR" baseline) over the 10 digit classes.
classifier = skflow.TensorFlowLinearClassifier(
    n_classes=10, learning_rate=0.01)
classifier.fit(X_train, y_train)
linear_y_predict = classifier.predict(X_test)

# 2) Fully connected DNN with hidden layers of 200, 50 and 10 units.
classifier = skflow.TensorFlowDNNClassifier(
    hidden_units=[200, 50, 10], n_classes=10, learning_rate=0.01)
classifier.fit(X_train, y_train)
dnn_y_predict = classifier.predict(X_test)

# 3) CNN built by the user-supplied conv_model function; trained for
#    20000 steps with a smaller learning rate than the models above.
classifier = skflow.TensorFlowEstimator(
    model_fn=conv_model, n_classes=10, steps=20000,
    learning_rate=0.001)
classifier.fit(X_train, y_train)
# Bind the CNN predictions like the other two models so all three can be
# evaluated the same way.
cnn_y_predict = classifier.predict(X_test)
class TensorFlowDNNClassifier(TensorFlowEstimator, ClassifierMixin):
    """TensorFlow DNN Classifier model.

    Parameters:
        hidden_units: List of hidden units per layer.
        n_classes: Number of classes in the target.
        tf_master: TensorFlow master. Empty string is default for local.
        batch_size: Mini batch size.
        steps: Number of steps to run over data.
        optimizer: Optimizer name (or class), for example "SGD", "Adam",
            "Adagrad".
        learning_rate: If this is a constant float value, no decay function
            is used. Instead, a customized decay function can be passed that
            accepts global_step as parameter and returns a Tensor, e.g. an
            exponential decay function::

                def exp_decay(global_step):
                    return tf.train.exponential_decay(
                        learning_rate=0.1, global_step=global_step,
                        decay_steps=2, decay_rate=0.001)

        class_weight: None or list of n_classes floats. Weight associated
            with classes for loss computation. If not given, all classes are
            supposed to have weight one.
        tf_random_seed: Random seed for TensorFlow initializers. Setting this
            value allows consistency between reruns. (default: 42)
        continue_training: When continue_training is True, once initialized
            the model will be continually trained on every call of fit.
        num_cores: Number of cores to be used. (default: 4)
        verbose: Verbosity level; passed through unchanged to
            TensorFlowEstimator. (default: 1)
        early_stopping_rounds: Activates early stopping if this is not None.
            Loss needs to decrease at least every <early_stopping_rounds>
            round(s) to continue training. (default: None)
        max_to_keep: The maximum number of recent checkpoint files to keep.
            As new files are created, older files are deleted. If None or 0,
            all checkpoint files are kept. Defaults to 5 (that is, the 5 most
            recent checkpoint files are kept).
        keep_checkpoint_every_n_hours: Number of hours between each
            checkpoint to be saved. The default value of 10,000 hours
            effectively disables the feature.
    """

    def __init__(self, hidden_units, n_classes, tf_master="", batch_size=32,
                 steps=200, optimizer="SGD", learning_rate=0.1,
                 class_weight=None, tf_random_seed=42,
                 continue_training=False, num_cores=4, verbose=1,
                 early_stopping_rounds=None, max_to_keep=5,
                 keep_checkpoint_every_n_hours=10000):
        # Stored before the base-class constructor runs so that _model_fn
        # (handed to the base class below) can read it.
        self.hidden_units = hidden_units
        super(TensorFlowDNNClassifier, self).__init__(
            model_fn=self._model_fn, n_classes=n_classes,
            tf_master=tf_master, batch_size=batch_size, steps=steps,
            optimizer=optimizer, learning_rate=learning_rate,
            class_weight=class_weight, tf_random_seed=tf_random_seed,
            continue_training=continue_training,
            # Forward the caller's value instead of a hard-coded constant,
            # otherwise the num_cores argument would be silently ignored.
            num_cores=num_cores,
            verbose=verbose, early_stopping_rounds=early_stopping_rounds,
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    def _model_fn(self, X, y):
        """Build the model graph for inputs X and targets y.

        Delegates to models.get_dnn_model with this classifier's
        hidden_units and models.logistic_regression as the output model,
        then applies the resulting model function to (X, y).
        """
        return models.get_dnn_model(self.hidden_units,
                                    models.logistic_regression)(X, y)

    @property
    def weights_(self):
        """Returns weights of the DNN weight layers.

        One value per hidden layer (tensor 'dnn/layer<i>/Linear/Matrix:0'),
        followed by the final 'logistic_regression/weights:0' tensor.
        """
        weights = [
            self.get_tensor_value('dnn/layer%d/Linear/Matrix:0' % layer)
            for layer in range(len(self.hidden_units))
        ]
        weights.append(self.get_tensor_value('logistic_regression/weights:0'))
        return weights

    @property
    def bias_(self):
        """Returns bias of the DNN's bias layers.

        One value per hidden layer (tensor 'dnn/layer<i>/Linear/Bias:0'),
        followed by the final 'logistic_regression/bias:0' tensor.
        """
        biases = [
            self.get_tensor_value('dnn/layer%d/Linear/Bias:0' % layer)
            for layer in range(len(self.hidden_units))
        ]
        biases.append(self.get_tensor_value('logistic_regression/bias:0'))
        return biases
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
加入交流群
请使用微信扫一扫!