CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下
目录
- from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
- from keras.callbacks import ReduceLROnPlateau
- from models.cnn import mini_XCEPTION
-
- parameters1、定义参数:每个batch的采样本数、训练轮数、输入shape、部分比例分离用于验证、冗长参数、分类个数、patience、do_random_crop
- batch_size = 32
- num_epochs = 1000
- validation_split = .2
- do_random_crop = False random crop only works for classification since the current implementation does no transform bounding boxes
- patience = 100
- num_classes = 2
- dataset_name = 'imdb'
- input_shape = (64, 64, 1)
-
- if判断,然后指定图像、log、loghdf5各自保存路径
- if input_shape[2] == 1:
- grayscale = True
- images_path = '../datasets/imdb_crop/'
- log_file_path = '../trained_models/gender_models/gender_training.log'
- trained_models_path = '../trained_models/gender_models/gender_mini_XCEPTION'
-
-
- data loader
- data_loader = DataManager(dataset_name) 自定义DataManager函数实现根据数据集name进行加载
- ground_truth_data = data_loader.get_data() 自定义get_data函数根据不同数据集name得到各自的ground truth data,
- train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
- print('Number of training samples:', len(train_keys))
- print('Number of validation samples:', len(val_keys))
-
- 调用ImageDataGenerator函数实现实时数据增强生成小批量的图像数据。
- image_generator = ImageGenerator(ground_truth_data, batch_size,
- input_shape[:2],
- train_keys, val_keys, None,
- path_prefix=images_path,
- vertical_flip_probability=0,
- grayscale=grayscale,
- do_random_crop=do_random_crop)
-
- model parameters/compilation2、建立XCEPTION模型并compile编译配置参数,最后输出网络摘要
- model = mini_XCEPTION(input_shape, num_classes)
- model.compile(optimizer='adam',
- loss='categorical_crossentropy',
- metrics=['accuracy'])
- model.summary()
-
- 3、指定要训练的数据集(gender→imdb即男女数据集)
-
- model callbacks
- callbacks4、回调:通过调用CSVLogger、EarlyStopping、ReduceLROnPlateau、ModelCheckpoint等函数得到训练参数存到一个list内
- early_stop = EarlyStopping('val_loss', patience=patience)
- reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1,
- patience=int(patience/2), verbose=1)
- csv_logger = CSVLogger(log_file_path, append=False)
- model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
- model_checkpoint = ModelCheckpoint(model_names,
- monitor='val_loss',
- verbose=1,
- save_best_only=True,
- save_weights_only=False)
- callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]
-
- training model5、调用fit_generator函数训练模型
- model.fit_generator(image_generator.flow(mode='train'),
- steps_per_epoch=int(len(train_keys) / batch_size),
- epochs=num_epochs, verbose=1,
- callbacks=callbacks,
- validation_data=image_generator.flow('val'),
- validation_steps=int(len(val_keys) / batch_size))
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
加入交流群
请使用微信扫一扫!