DL (RBM): Using an RBM to Improve Handwritten Digit Image Recognition Accuracy
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn import metrics, linear_model
from sklearn.neural_network import BernoulliRBM
from sklearn.datasets import load_digits
from sklearn.pipeline import Pipeline

# Load the 8x8 handwritten digit images (flattened to 64 features) and labels.
digits = load_digits()
X = digits.data
y = digits.target

# Scale pixel values to [0, 1]; BernoulliRBM expects inputs in this range.
X -= X.min()
X /= X.max()
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Hyperparameters are passed to the constructors instead of being set afterwards.
logistic = linear_model.LogisticRegression(C=6000.0)
rbm = BernoulliRBM(n_components=200, learning_rate=0.06, n_iter=20,
                   random_state=0, verbose=True)

# The RBM extracts features; logistic regression classifies them.
classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
classifier.fit(X_train, y_train)

print()
print("Logistic regression using RBM features:\n%s\n" % (
    metrics.classification_report(y_test, classifier.predict(X_test))))
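The title speaks of improving accuracy, but the listing above only reports the RBM pipeline on its own. A minimal sketch of a side-by-side check follows: it fits a plain logistic regression on the scaled pixels and the RBM pipeline on the same split, then returns both test accuracies. The function name `compare_accuracy`, the fixed `random_state` for the split, and `max_iter=1000` are assumptions added for illustration and are not part of the original code. Whether the RBM features actually come out ahead depends on the split and the hyperparameters, so no expected numbers are given here.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline


def compare_accuracy(random_state=0):
    """Return (baseline_accuracy, rbm_accuracy) on one held-out digits split.

    `compare_accuracy` is a hypothetical helper, not part of the original post.
    """
    digits = load_digits()
    X = digits.data
    y = digits.target

    # Same min-max scaling to [0, 1] as in the original listing.
    X = (X - X.min()) / (X.max() - X.min())

    # Assumption: a fixed seed so both models see the identical split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state
    )

    # Baseline: logistic regression directly on the scaled pixels.
    # max_iter=1000 is an assumption to give the solver room to converge.
    baseline = LogisticRegression(max_iter=1000)
    baseline.fit(X_train, y_train)

    # RBM feature extraction followed by logistic regression,
    # using the hyperparameters from the original listing.
    rbm_pipeline = Pipeline(steps=[
        ("rbm", BernoulliRBM(n_components=200, learning_rate=0.06,
                             n_iter=20, random_state=random_state)),
        ("logistic", LogisticRegression(C=6000.0, max_iter=1000)),
    ])
    rbm_pipeline.fit(X_train, y_train)

    # Pipeline.score and LogisticRegression.score both return mean accuracy.
    return baseline.score(X_test, y_test), rbm_pipeline.score(X_test, y_test)


if __name__ == "__main__":
    base_acc, rbm_acc = compare_accuracy()
    print("Raw pixels:   %.3f" % base_acc)
    print("RBM features: %.3f" % rbm_acc)
```

Running the comparison over several values of `random_state` gives a fairer picture than a single split, since the digits dataset is small.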