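DL: Mask R-CNN: Object Detection and Object Image Segmentation (Instance Segmentation) with a Mask R-CNN-like Algorithm (RetinaNet + mask head) Using the Pretrained Weights File resnet50_coco_v0.2.0.h5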


Author: 重庆小强, 2022-09-19 14:02:21
Category: News

Contents

Output

Design approach

Core code

1. retinanet.py

2. resnet.py


Output

To be updated ……

Design approach

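Reference article: "DL: Mask R-CNN: A Detailed Guide to the Mask R-CNN Algorithm, with an Introduction (Paper Overview), Architecture Walkthrough, Case Studies, and Collected Figures".
Building on the ResNet backbone, the model adds ROI_Align (the RoiAlign layer), mask_submodel, and masks (ConcatenateBoxes, which concatenates the boxes with the mask output so the loss can be computed).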
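The concatenation step behind the masks output can be illustrated with plain NumPy. This is only a sketch: it assumes that ConcatenateBoxes flattens each ROI's mask prediction and appends it to the four box coordinates along the last axis. The function name `concatenate_boxes` and the small tensor sizes are made up for the illustration and are not part of the library.

```python
import numpy as np


def concatenate_boxes(boxes, masks):
    """Flatten per-ROI mask outputs and append them to the box coordinates.

    boxes: (batch, num_rois, 4)
    masks: (batch, num_rois, height, width, num_classes)
    returns: (batch, num_rois, 4 + height * width * num_classes)
    """
    batch, num_rois = boxes.shape[:2]
    # Collapse the spatial and class axes of every ROI into one vector.
    flat_masks = masks.reshape(batch, num_rois, -1)
    # Boxes come first, so a loss function can slice them off again.
    return np.concatenate([boxes, flat_masks], axis=2)


# Toy sizes (assumed): 1 image, 3 ROIs, 28x28 masks, 2 classes.
boxes = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
masks = np.zeros((1, 3, 28, 28, 2), dtype=np.float32)
merged = concatenate_boxes(boxes, masks)
print(merged.shape)
```

Packing both tensors into one output lets a single loss function receive the predicted boxes together with the predicted masks for each ROI.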

Core code

To be updated ……

1. retinanet.py

Inside the default_mask_model function, the number of classes num_classes, the pyramid feature size pyramid_feature_size=256, and the following defaults are defined:
    mask_feature_size=256,
    roi_size=(14, 14),
    mask_size=(28, 28),
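To see what these defaults imply, the sketch below traces the output shape and parameter count of each layer in the mask head, using the layer names from the listing. It is a back-of-the-envelope calculation in plain Python, not library code. It assumes that the Upsample layer resizes each ROI feature map to mask_size and has no trainable weights. The helper names `conv2d_params` and `trace_mask_head` and the choice of 80 classes are made up for the illustration.

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a 2D convolution."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def trace_mask_head(num_classes, pyramid_feature_size=256, mask_feature_size=256,
                    roi_size=(14, 14), mask_size=(28, 28)):
    """Return (layer name, per-ROI output shape, parameter count) for each layer."""
    layers = []
    height, width = roi_size
    channels = pyramid_feature_size

    # Four 3x3 convolutions with 'same' padding keep the 14x14 spatial size.
    for i in range(4):
        params = conv2d_params(3, channels, mask_feature_size)
        channels = mask_feature_size
        layers.append(('roi_mask_{}'.format(i), (height, width, channels), params))

    # Upsampling to mask_size (assumed to be a weight-free resize).
    height, width = mask_size
    layers.append(('roi_mask_upsample', (height, width, channels), 0))

    # One more 3x3 convolution on the upsampled features.
    params = conv2d_params(3, channels, mask_feature_size)
    channels = mask_feature_size
    layers.append(('roi_mask_features', (height, width, channels), params))

    # 1x1 sigmoid convolution: one mask channel per class.
    params = conv2d_params(1, channels, num_classes)
    layers.append(('roi_mask', (height, width, num_classes), params))
    return layers


layers = trace_mask_head(num_classes=80)
for layer_name, shape, params in layers:
    print(layer_name, shape, params)
total_params = sum(params for _, _, params in layers)
print('total', total_params)
```

Each ROI therefore enters the head as a 14x14x256 tensor and leaves it as a 28x28 mask with one channel per class.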

"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras
import keras.backend
import keras.models
import keras_retinanet.layers
import keras_retinanet.models.retinanet
import keras_retinanet.utils.anchors  # imported explicitly because AnchorParameters is used in retinanet_mask
import keras_retinanet.backend.tensorflow_backend as backend

from ..layers.roi import RoiAlign
from ..layers.upsample import Upsample
from ..layers.misc import Shape, ConcatenateBoxes, Cast

def default_mask_model(
    num_classes,
    pyramid_feature_size=256,
    mask_feature_size=256,
    roi_size=(14, 14),
    mask_size=(28, 28),
    name='mask_submodel',
    mask_dtype=keras.backend.floatx(),
    retinanet_dtype=keras.backend.floatx()
):
    options = {
        'kernel_size'        : 3,
        'strides'            : 1,
        'padding'            : 'same',
        'kernel_initializer' : keras.initializers.normal(mean=0.0, stddev=0.01, seed=None),
        'bias_initializer'   : 'zeros',
        'activation'         : 'relu',
    }

    inputs  = keras.layers.Input(shape=(None, roi_size[0], roi_size[1], pyramid_feature_size))
    outputs = inputs

    # casting to the desired data type, which may be different than
    # the one used for the underlying keras-retinanet model
    if mask_dtype != retinanet_dtype:
        outputs = keras.layers.TimeDistributed(
            Cast(dtype=mask_dtype),
            name='cast_masks')(outputs)

    for i in range(4):
        outputs = keras.layers.TimeDistributed(keras.layers.Conv2D(
            filters=mask_feature_size,
            **options
        ), name='roi_mask_{}'.format(i))(outputs)

    # perform upsampling + conv instead of deconv as in the paper
    # https://distill.pub/2016/deconv-checkerboard/
    outputs = keras.layers.TimeDistributed(
        Upsample(mask_size),
        name='roi_mask_upsample')(outputs)
    outputs = keras.layers.TimeDistributed(keras.layers.Conv2D(
        filters=mask_feature_size,
        **options
    ), name='roi_mask_features')(outputs)

    outputs = keras.layers.TimeDistributed(keras.layers.Conv2D(
        filters=num_classes,
        kernel_size=1,
        activation='sigmoid'
    ), name='roi_mask')(outputs)

    # casting back to the underlying keras-retinanet model data type
    if mask_dtype != retinanet_dtype:
        outputs = keras.layers.TimeDistributed(
            Cast(dtype=retinanet_dtype),
            name='recast_masks')(outputs)

    return keras.models.Model(inputs=inputs, outputs=outputs, name=name)

def default_roi_submodels(num_classes, mask_dtype=keras.backend.floatx(), retinanet_dtype=keras.backend.floatx()):
    return [
        ('masks', default_mask_model(num_classes, mask_dtype=mask_dtype, retinanet_dtype=retinanet_dtype)),
    ]

def retinanet_mask(
    inputs,
    num_classes,
    retinanet_model=None,
    anchor_params=None,
    nms=True,
    class_specific_filter=True,
    name='retinanet-mask',
    roi_submodels=None,
    mask_dtype=keras.backend.floatx(),
    modifier=None,
    **kwargs
):
    """ Construct a RetinaNet mask model on top of a retinanet bbox model.

    This model uses the retinanet bbox model and appends a few layers to compute masks.

    Arguments
        inputs                : List of keras.layers.Input. The first input is the image, the second input the blob of masks.
        num_classes           : Number of classes to classify.
        retinanet_model       : keras_retinanet.models.retinanet model, returning regression and classification values.
        anchor_params         : Struct containing anchor parameters. If None, default values are used.
        nms                   : Use NMS.
        class_specific_filter : Use class specific filtering.
        roi_submodels         : Submodels for processing ROIs.
        mask_dtype            : Data type of the masks, can be different from the main one.
        modifier              : Modifier for the underlying retinanet model, such as freeze.
        name                  : Name of the model.
        **kwargs              : Additional kwargs to pass to the retinanet bbox model.

    Returns
        Model with inputs as input and as output the output of each submodel for each pyramid level and the detections.

        The order is as defined in submodels.
        ```
        [
            regression, classification, other[0], other[1], ..., boxes_masks, boxes, scores, labels, masks, other[0], other[1], ...
        ]
        ```
    """
    if anchor_params is None:
        anchor_params = keras_retinanet.utils.anchors.AnchorParameters.default

    if roi_submodels is None:
        retinanet_dtype = keras.backend.floatx()
        keras.backend.set_floatx(mask_dtype)
        roi_submodels = default_roi_submodels(num_classes, mask_dtype, retinanet_dtype)
        keras.backend.set_floatx(retinanet_dtype)

    image = inputs
    image_shape = Shape()(image)

    if retinanet_model is None:
        retinanet_model = keras_retinanet.models.retinanet.retinanet(
            inputs=image,
            num_classes=num_classes,
            num_anchors=anchor_params.num_anchors(),
            **kwargs
        )

    if modifier:
        retinanet_model = modifier(retinanet_model)

    # parse outputs
    regression     = retinanet_model.outputs[0]
    classification = retinanet_model.outputs[1]
    other          = retinanet_model.outputs[2:]
    features       = [retinanet_model.get_layer(name).output for name in ['P3', 'P4', 'P5', 'P6', 'P7']]

    # build boxes
    anchors = keras_retinanet.models.retinanet.__build_anchors(anchor_params, features)
    boxes   = keras_retinanet.layers.RegressBoxes(name='boxes')([anchors, regression])
    boxes   = keras_retinanet.layers.ClipBoxes(name='clipped_boxes')([image, boxes])

    # filter detections (apply NMS / score threshold / select top-k)
    detections = keras_retinanet.layers.FilterDetections(
        nms                   = nms,
        class_specific_filter = class_specific_filter,
        max_detections        = 100,
        name                  = 'filtered_detections'
    )([boxes, classification] + other)

    # split up in known outputs and "other"
    boxes  = detections[0]
    scores = detections[1]

    # get the region of interest features
    rois = RoiAlign()([image_shape, boxes, scores] + features)

    # execute maskrcnn submodels
    maskrcnn_outputs = [submodel(rois) for _, submodel in roi_submodels]

    # concatenate boxes for loss computation
    trainable_outputs = [ConcatenateBoxes(name=name)([boxes, output]) for (name, _), output in zip(roi_submodels, maskrcnn_outputs)]

    # reconstruct the new output
    outputs = [regression, classification] + other + trainable_outputs + detections + maskrcnn_outputs

    return keras.models.Model(inputs=inputs, outputs=outputs, name=name)
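According to the docstring, when there are no extra `other` outputs the model's output list ends with boxes, scores, labels, masks. The sketch below shows one way the last four arrays for a single image might be post-processed with NumPy. It runs on synthetic data and rests on several assumptions: padded detections carry a label of -1, the mask tensor holds one sigmoid channel per class, and both thresholds (0.5) are arbitrary. The helper `select_instance_masks` is hypothetical and not part of keras-retinanet or keras-maskrcnn.

```python
import numpy as np


def select_instance_masks(boxes, scores, labels, masks,
                          score_threshold=0.5, mask_threshold=0.5):
    """Keep confident detections and pick each one's mask channel by its label.

    boxes:  (num_detections, 4)
    scores: (num_detections,)
    labels: (num_detections,) integer class ids, assumed -1 for padding
    masks:  (num_detections, mask_h, mask_w, num_classes) sigmoid outputs
    """
    results = []
    for box, score, label, mask in zip(boxes, scores, labels, masks):
        # Skip padding entries and low-confidence detections.
        if label < 0 or score < score_threshold:
            continue
        # Only the channel of the predicted class is meaningful for this ROI.
        binary_mask = mask[:, :, label] > mask_threshold
        results.append((box, float(score), int(label), binary_mask))
    return results


# Synthetic stand-ins for the model outputs of one image (assumed values).
rng = np.random.default_rng(0)
boxes = np.array([[10, 10, 50, 60], [0, 0, 20, 20], [-1, -1, -1, -1]], dtype=np.float32)
scores = np.array([0.9, 0.3, -1.0], dtype=np.float32)
labels = np.array([1, 0, -1])
masks = rng.random((3, 28, 28, 2)).astype(np.float32)

kept = select_instance_masks(boxes, scores, labels, masks)
print(len(kept), kept[0][2], kept[0][3].shape)
```

The resulting binary mask is still in the 28x28 ROI grid. Resizing it to the box size and pasting it into the image is left out of this sketch.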

2. resnet.py

The resnet_maskrcnn function builds the model on a ResNet backbone. The code lets you choose resnet50, resnet101, or resnet152 as the backbone.

"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings

import keras
import keras_resnet
import keras_resnet.models
import keras_retinanet.models.resnet

from ..models import retinanet, Backbone


class ResNetBackbone(Backbone, keras_retinanet.models.resnet.ResNetBackbone):
    def maskrcnn(self, *args, **kwargs):
        """ Returns a maskrcnn model using the correct backbone.
        """
        return resnet_maskrcnn(*args, backbone=self.backbone, **kwargs)

def resnet_maskrcnn(num_classes, backbone='resnet50', inputs=None, modifier=None, mask_dtype=keras.backend.floatx(), **kwargs):
    # choose default input
    if inputs is None:
        inputs = keras.layers.Input(shape=(None, None, 3), name='image')

    # create the resnet backbone
    if backbone == 'resnet50':
        resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True)
    elif backbone == 'resnet101':
        resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True)
    elif backbone == 'resnet152':
        resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True)

    # invoke modifier if given
    if modifier:
        resnet = modifier(resnet)