CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下


老俊
老俊 2022-09-20 09:56:25 53077
分类专栏: 资讯

CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下

目录

图示过程

核心代码


图示过程

核心代码

  1. from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
  2. from keras.callbacks import ReduceLROnPlateau
  3. from models.cnn import mini_XCEPTION
  4. parameters1、定义参数:每个batch的采样本数、训练轮数、输入shape、部分比例分离用于验证、冗长参数、分类个数、patience、do_random_crop
  5. batch_size = 32
  6. num_epochs = 1000
  7. validation_split = .2
  8. do_random_crop = False random crop only works for classification since the current implementation does no transform bounding boxes
  9. patience = 100
  10. num_classes = 2
  11. dataset_name = 'imdb'
  12. input_shape = (64, 64, 1)
  13. if判断,然后指定图像、log、loghdf5各自保存路径
  14. if input_shape[2] == 1:
  15. grayscale = True
  16. images_path = '../datasets/imdb_crop/'
  17. log_file_path = '../trained_models/gender_models/gender_training.log'
  18. trained_models_path = '../trained_models/gender_models/gender_mini_XCEPTION'
  19. data loader
  20. data_loader = DataManager(dataset_name) 自定义DataManager函数实现根据数据集name进行加载
  21. ground_truth_data = data_loader.get_data() 自定义get_data函数根据不同数据集name得到各自的ground truth data,
  22. train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
  23. print('Number of training samples:', len(train_keys))
  24. print('Number of validation samples:', len(val_keys))
  25. 调用ImageDataGenerator函数实现实时数据增强生成小批量的图像数据。
  26. image_generator = ImageGenerator(ground_truth_data, batch_size,
  27. input_shape[:2],
  28. train_keys, val_keys, None,
  29. path_prefix=images_path,
  30. vertical_flip_probability=0,
  31. grayscale=grayscale,
  32. do_random_crop=do_random_crop)
  33. model parameters/compilation2、建立XCEPTION模型并compile编译配置参数,最后输出网络摘要
  34. model = mini_XCEPTION(input_shape, num_classes)
  35. model.compile(optimizer='adam',
  36. loss='categorical_crossentropy',
  37. metrics=['accuracy'])
  38. model.summary()
  39. 3、指定要训练的数据集(gender→imdb即男女数据集)
  40. model callbacks
  41. callbacks4、回调:通过调用CSVLogger、EarlyStopping、ReduceLROnPlateau、ModelCheckpoint等函数得到训练参数存到一个list内
  42. early_stop = EarlyStopping('val_loss', patience=patience)
  43. reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1,
  44. patience=int(patience/2), verbose=1)
  45. csv_logger = CSVLogger(log_file_path, append=False)
  46. model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
  47. model_checkpoint = ModelCheckpoint(model_names,
  48. monitor='val_loss',
  49. verbose=1,
  50. save_best_only=True,
  51. save_weights_only=False)
  52. callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]
  53. training model5、调用fit_generator函数训练模型
  54. model.fit_generator(image_generator.flow(mode='train'),
  55. steps_per_epoch=int(len(train_keys) / batch_size),
  56. epochs=num_epochs, verbose=1,
  57. callbacks=callbacks,
  58. validation_data=image_generator.flow('val'),
  59. validation_steps=int(len(val_keys) / batch_size))

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3981
赞同 0
评论 0 条
老俊L0
粉丝 0 发表 10 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2935
【软件正版化】软件正版化工作要点  2854
统信UOS试玩黑神话:悟空  2811
信刻光盘安全隔离与信息交换系统  2702
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1235
grub引导程序无法找到指定设备和分区  1205
点击报名 | 京东2025校招进校行程预告  162
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  160
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  156
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  154
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!