CV:基于Keras利用CNN主流架构之mini_XCEPTION训练情感分类模型hdf5并保存到指定文件夹下


曾经笑鼠标
曾经笑鼠标 2022-09-20 09:56:43 50257
分类专栏: 资讯

CV:基于Keras利用CNN主流架构之mini_XCEPTION训练情感分类模型hdf5并保存到指定文件夹下

目录

图示过程

核心代码


图示过程

核心代码

  1. def mini_XCEPTION(input_shape, num_classes, l2_regularization=0.01):
  2. regularization = l2(l2_regularization)
  3. base
  4. img_input = Input(input_shape)
  5. x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
  6. use_bias=False)(img_input)
  7. x = BatchNormalization()(x)
  8. x = Activation('relu')(x)
  9. x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
  10. use_bias=False)(x)
  11. x = BatchNormalization()(x)
  12. x = Activation('relu')(x)
  13. module 1
  14. residual = Conv2D(16, (1, 1), strides=(2, 2),
  15. padding='same', use_bias=False)(x)
  16. residual = BatchNormalization()(residual)
  17. x = SeparableConv2D(16, (3, 3), padding='same',
  18. kernel_regularizer=regularization,
  19. use_bias=False)(x)
  20. x = BatchNormalization()(x)
  21. x = Activation('relu')(x)
  22. x = SeparableConv2D(16, (3, 3), padding='same',
  23. kernel_regularizer=regularization,
  24. use_bias=False)(x)
  25. x = BatchNormalization()(x)
  26. x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
  27. x = layers.add([x, residual])
  28. module 2
  29. residual = Conv2D(32, (1, 1), strides=(2, 2),
  30. padding='same', use_bias=False)(x)
  31. residual = BatchNormalization()(residual)
  32. x = SeparableConv2D(32, (3, 3), padding='same',
  33. kernel_regularizer=regularization,
  34. use_bias=False)(x)
  35. x = BatchNormalization()(x)
  36. x = Activation('relu')(x)
  37. x = SeparableConv2D(32, (3, 3), padding='same',
  38. kernel_regularizer=regularization,
  39. use_bias=False)(x)
  40. x = BatchNormalization()(x)
  41. x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
  42. x = layers.add([x, residual])
  43. module 3
  44. residual = Conv2D(64, (1, 1), strides=(2, 2),
  45. padding='same', use_bias=False)(x)
  46. residual = BatchNormalization()(residual)
  47. x = SeparableConv2D(64, (3, 3), padding='same',
  48. kernel_regularizer=regularization,
  49. use_bias=False)(x)
  50. x = BatchNormalization()(x)
  51. x = Activation('relu')(x)
  52. x = SeparableConv2D(64, (3, 3), padding='same',
  53. kernel_regularizer=regularization,
  54. use_bias=False)(x)
  55. x = BatchNormalization()(x)
  56. x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
  57. x = layers.add([x, residual])
  58. module 4
  59. residual = Conv2D(128, (1, 1), strides=(2, 2),
  60. padding='same', use_bias=False)(x)
  61. residual = BatchNormalization()(residual)
  62. x = SeparableConv2D(128, (3, 3), padding='same',
  63. kernel_regularizer=regularization,
  64. use_bias=False)(x)
  65. x = BatchNormalization()(x)
  66. x = Activation('relu')(x)
  67. x = SeparableConv2D(128, (3, 3), padding='same',
  68. kernel_regularizer=regularization,
  69. use_bias=False)(x)
  70. x = BatchNormalization()(x)
  71. x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
  72. x = layers.add([x, residual])
  73. x = Conv2D(num_classes, (3, 3),
  74. kernel_regularizer=regularization,
  75. padding='same')(x)
  76. x = GlobalAveragePooling2D()(x)
  77. output = Activation('softmax',name='predictions')(x)
  78. model = Model(img_input, output)
  79. return model
  1. CV:利用CNN主流架构之一的XCEPTION训练情感分类模型.hdf5并保存到指定文件夹下边
  2. from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
  3. from keras.callbacks import ReduceLROnPlateau
  4. from keras.preprocessing.image import ImageDataGenerator
  5. from models.cnn import mini_XCEPTION
  6. parameters 1、定义参数:每个batch的采样本数、训练轮数、输入shape、部分比例分离用于验证、冗长参数、分类个数、patience、loghdf5保存路径
  7. batch_size = 32 整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。
  8. num_epochs = 10000 整数,训练终止时的epoch值,训练将在达到该epoch值时停止,当没有设置initial_epoch时,它就是训练的总轮数,否则训练的总轮数为epochs - inital_epoch
  9. input_shape = (64, 64, 1)
  10. validation_split = .2 0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。
  11. verbose = 1 日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
  12. num_classes = 7
  13. patience = 50 当monitor不再有改善的时候就会停止训练,这个可以通过patience看出来
  14. base_path = '../trained_models/emotion_models/'
  15. data generator调用ImageDataGenerator函数实现实时数据增强生成小批量的图像数据。
  16. data_generator = ImageDataGenerator(
  17. featurewise_center=False,
  18. featurewise_std_normalization=False,
  19. rotation_range=10,
  20. width_shift_range=0.1,
  21. height_shift_range=0.1,
  22. zoom_range=.1,
  23. horizontal_flip=True)
  24. model parameters/compilation2、建立XCEPTION模型并compile编译配置参数,最后输出网络摘要
  25. model = mini_XCEPTION(input_shape, num_classes) mini_XCEPTION函数(XCEPTION是属于CNN下目前最新的一种模型)实现输入形状、分类个数两个参数建立模型
  26. model.compile(optimizer='adam', loss='categorical_crossentropy', model.compile函数(属于keras库)用来配置训练模型参数,可以指定你设想的随机梯度下降中的网络的损失函数、优化方式等参数
  27. metrics=['accuracy'])
  28. model.summary() Prints a string summary of the network.
  29. 3、指定要训练的数据集(emotion→fer2013即喜怒哀乐数据集)
  30. datasets = ['fer2013']
  31. 4、for循环实现callbacks、loading dataset
  32. for dataset_name in datasets:
  33. print('Training dataset:', dataset_name)
  34. callbacks回调:通过调用CSVLogger、EarlyStopping、ReduceLROnPlateau、ModelCheckpoint等函数得到训练参数存到一个list内
  35. log_file_path = base_path + dataset_name + '_emotion_training.log'
  36. csv_logger = CSVLogger(log_file_path, append=False) Callback that streams epoch results to a csv file.
  37. early_stop = EarlyStopping('val_loss', patience=patience) Stop training when a monitored quantity has stopped improving.
  38. reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, Reduce learning rate when a metric has stopped improving.
  39. patience=int(patience/4), verbose=1)
  40. trained_models_path = base_path + dataset_name + '_mini_XCEPTION'
  41. model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
  42. model_checkpoint = ModelCheckpoint(model_names, 'val_loss', verbose=1, Save the model after every epoch
  43. save_best_only=True)
  44. callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]
  45. loading dataset加载数据集:通过调用DataManager、
  46. data_loader = DataManager(dataset_name, image_size=input_shape[:2]) 自定义DataManager函数实现根据数据集name进行加载
  47. faces, emotions = data_loader.get_data() 自定义get_data函数根据不同数据集name得到各自的ground truth data,
  48. faces = preprocess_input(faces) 自定义preprocess_input函数:处理输入的数据,先转为float32类型然后/ 255.0
  49. num_samples, num_classes = emotions.shape shape函数读取矩阵的长度
  50. train_data, val_data = split_data(faces, emotions, validation_split) 自定义split_data对数据整理各取所得train_data、 val_data
  51. train_faces, train_emotions = train_data
  52. training model调用fit_generator函数训练模型
  53. model.fit_generator(data_generator.flow(train_faces, train_emotions, flow函数返回Numpy Array Iterator迭代
  54. batch_size),
  55. steps_per_epoch=len(train_faces) / batch_size,
  56. epochs=num_epochs, verbose=1, callbacks=callbacks,
  57. validation_data=val_data) fit_generator函数Fits the model on data generated batch-by-batch by a Python generator

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3982
赞同 0
评论 0 条
曾经笑鼠标L0
粉丝 0 发表 8 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2935
【软件正版化】软件正版化工作要点  2854
统信UOS试玩黑神话:悟空  2811
信刻光盘安全隔离与信息交换系统  2702
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1235
grub引导程序无法找到指定设备和分区  1205
点击报名 | 京东2025校招进校行程预告  162
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  160
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  156
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  154
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!