DL框架之AutoKeras框架:深度学习框架AutoKeras框架的简介、特点、安装、使用方法详细攻略


bro
Bro 2022-09-20 10:54:49 50864
分类专栏: 资讯

DL框架之AutoKeras框架:深度学习框架AutoKeras框架的简介、特点、安装、使用方法详细攻略

Paper:《Efficient Neural Architecture Search via Parameter Sharing》

目录

AutoKeras框架的简介

AutoKeras框架的特点

AutoKeras的安装

AutoKeras框架的使用方法


AutoKeras框架的简介

        AutoKeras是一个开源的,基于 Keras 的新型 AutoML 库。AutoKeras 是一个用于自动化机器学习的开源软件库,提供自动搜索深度学习模型的架构和超参数的功能。
(1)、Keras 是一个用 Python 编写的高级神经网络 API,能够在 TensorFlow、CNTK 或 Theano 之上运行。它的意义在于可以实现快速实验。而能够以最小的延迟把想法变成结果是顺利进行研究的关键。

       AutoKeras比AutoML伟大的地方就是开源,哈哈,开源就等同于免费!这是我非常喜欢的一点!!!简而言之,AutoML是给有钱的公司玩的,像我们这样做学术研究的, AutoKeras简直妙不可言!!

官方网站:https://autokeras.com/
项目github:https://github.com/jhfjhfj1/autokeras
TensorFlow版本:https://github.com/melodyguan/enas
PyTorch 版本:https://github.com/carpedm20/ENAS-pytorch

AutoKeras框架的特点

1、AutoKeras 基于非常易于使用的深度学习数据库 Keras,使用 ENAS 的方法。ENAS 是 NAS 的最新版本,因此让 AutoKeras 具有高效、安装简单、参数可调、易修改等特点。

AutoKeras的安装

1、安装AutoKeras

pip install autokeras


2、测试

  1. import autokeras as ak
  2. clf = ak.ImageClassifier()
  3. clf.fit(x_train, y_train)
  4. results = clf.predict(x_test)
  5. 导出模型
  6. from autokeras import ImageClassifier
  7. clf = ImageClassifier(verbose=True, augment=False)
  8. clf.load_searcher().load_best_model().produce_keras_model().save('my_model.h5')
  9. 可视化模型
  10. from keras.models import load_model
  11. model = load_model('my_model.h5') See 'How to export keras models?' to generate this file before loading it.
  12. from keras.utils import plot_model
  13. plot_model(model, to_file='my_model.png')

AutoKeras框架的使用方法

1、举个栗子

  1. from keras.datasets import mnist
  2. from autokeras.image_supervised import ImageClassifier
  3. if __name__ == '__main__':
  4. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  5. x_train = x_train.reshape(x_train.shape + (1,))
  6. x_test = x_test.reshape(x_test.shape + (1,))
  7. clf = ImageClassifier(verbose=True)
  8. clf.fit(x_train, y_train, time_limit=12 * 60 * 60)
  9. clf.final_fit(x_train, y_train, x_test, y_test, retrain=True)
  10. y = clf.evaluate(x_test, y_test)
  11. print(y)

2、再举一个栗子

  1. coding:utf-8
  2. import os
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from scipy.misc import imresize
  7. import cv2
  8. from autokeras.image_supervised import ImageClassifier
  9. from sklearn.metrics import accuracy_score
  10. from keras.models import load_model
  11. from keras.utils import plot_model
  12. import time
  13. start = time.time()
  14. def read_img(path,class_num):
  15. imgName_list = os.listdir(path)
  16. n = len(imgName_list)
  17. img_index,img_colummns,img_rgbSize = plt.imread(path+'/'+imgName_list[0]).shape
  18. img_index, img_colummns = [28,38] 这个设置很重要。如果你的电脑很好的话可以忽略设置。要不然内存不足的。
  19. print(img_index,img_colummns)
  20. data = np.zeros([n,img_index,img_colummns,1])
  21. label = np.zeros([n,1])
  22. class_number = 0
  23. for i in range(n):
  24. imgPath = path+'/'+imgName_list[i]
  25. data[i,:,:,0] = imresize(cv2.cvtColor(plt.imread(imgPath),cv2.COLOR_BGR2GRAY),[img_index,img_colummns])
  26. if (i)%(class_num) == 0:
  27. class_number = class_number+1
  28. label[i,0] = class_number
  29. return data,label
  30. x_train,y_train = read_img('./data/re/train',80)
  31. x_test,y_test = read_img('./data/re/test',20)
  32. animal = ['bus', 'dinosaur', 'flower', 'horse', 'elephant'] 动物类别对应 labelValue 为 [1,2,3,4,5]
  33. plt.imshow(x_test[0,:,:,0],cmap='gray')
  34. plt.show()
  35. if __name__=='__main__':
  36. 模型构建
  37. model = ImageClassifier(verbose=True)
  38. 搜索网络模型
  39. model.fit(x_train,y_train,time_limit=1*60)
  40. 验证最优模型
  41. model.final_fit(x_train,y_train,x_test,y_test,retrain=True)
  42. 给出评估结果
  43. score = model.evaluate(x_test,y_test)
  44. 识别结果
  45. y_predict = model.predict(x_test)
  46. 精确度
  47. accuracy = accuracy_score(y_test,y_predict)
  48. 打印出score与accuracy
  49. print('score:',score,' accuracy:',accuracy)
  50. model_dir = r'./modelStructure/imgModel.h5'
  51. model_img = r'./modelStructure/imgModel_ST.png'
  52. 保存可视化模型
  53. model.load_searcher().load_best_model().produce_keras_model().save(model_dir)
  54. 加载模型
  55. automodel = load_model(model_dir)
  56. 输出模型 structure 图
  57. plot_model(automodel, to_file=model_img)
  58. end = time.time()
  59. print(end-start)

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=4143
赞同 0
评论 0 条
BroL0
粉丝 0 发表 15 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2941
【软件正版化】软件正版化工作要点  2860
统信UOS试玩黑神话:悟空  2819
信刻光盘安全隔离与信息交换系统  2712
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1246
grub引导程序无法找到指定设备和分区  1213
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  163
点击报名 | 京东2025校招进校行程预告  162
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  160
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!