Add at new repo again
@@ -0,0 +1,135 @@

# PointRend: Image Segmentation as Rendering

Alexander Kirillov, Yuxin Wu, Kaiming He, Ross Girshick

[[`arXiv`](https://arxiv.org/abs/1912.08193)] [[`BibTeX`](#CitingPointRend)]

<div align="center">
  <img src="https://alexander-kirillov.github.io/images/kirillov2019pointrend.jpg"/>
</div><br/>

In this repository, we release code for PointRend in Detectron2. PointRend can be flexibly applied to both instance and semantic segmentation tasks by building on top of existing state-of-the-art models.

## Installation

Install Detectron2 following [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). You are ready to go!

## Quick start and visualization

This [Colab Notebook](https://colab.research.google.com/drive/1isGPL5h5_cKoPPhVL9XhMokRtHDvmMVL) tutorial contains examples of PointRend usage and visualizations of its point sampling stages.
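
For a quick programmatic test, the minimal sketch below loads one of the released instance segmentation configs with Detectron2's `DefaultPredictor`. The config path and checkpoint URL come from this repository and the model zoo table below; the working directory and input image path are placeholders you should adapt.

```python
# Minimal inference sketch; assumes it is run from projects/PointRend with Detectron2 installed.
import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from point_rend import add_pointrend_config

cfg = get_cfg()
add_pointrend_config(cfg)  # register PointRend-specific config keys
cfg.merge_from_file("configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
cfg.MODEL.WEIGHTS = (
    "https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/"
    "pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl"
)
predictor = DefaultPredictor(cfg)
outputs = predictor(cv2.imread("input.jpg"))      # placeholder image path
print(outputs["instances"].pred_masks.shape)      # (num_instances, H, W) boolean masks
```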

## Training

To train a model with 8 GPUs, run:

```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --num-gpus 8
```

## Evaluation

Model evaluation can be done similarly:

```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint
```

# Pretrained Models

## Instance Segmentation

#### COCO

<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="bottom">Mask<br/>head</th>
|
||||
<th valign="bottom">Backbone</th>
|
||||
<th valign="bottom">lr<br/>sched</th>
|
||||
<th valign="bottom">Output<br/>resolution</th>
|
||||
<th valign="bottom">mask<br/>AP</th>
|
||||
<th valign="bottom">mask<br/>AP*</th>
|
||||
<th valign="bottom">model id</th>
|
||||
<th valign="bottom">download</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml">PointRend</a></td>
|
||||
<td align="center">R50-FPN</td>
|
||||
<td align="center">1×</td>
|
||||
<td align="center">224×224</td>
|
||||
<td align="center">36.2</td>
|
||||
<td align="center">39.7</td>
|
||||
<td align="center">164254221</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/model_final_88c6f8.pkl">model</a> | <a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/metrics.json">metrics</a></td>
|
||||
</tr>
|
||||
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml">PointRend</a></td>
|
||||
<td align="center">R50-FPN</td>
|
||||
<td align="center">3×</td>
|
||||
<td align="center">224×224</td>
|
||||
<td align="center">38.3</td>
|
||||
<td align="center">41.6</td>
|
||||
<td align="center">164955410</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl">model</a> | <a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/metrics.json">metrics</a></td>
|
||||
</tr>
|
||||
</tbody></table>

AP* is COCO mask AP evaluated against the higher-quality LVIS annotations; see the paper for details. Run `python detectron2/datasets/prepare_cocofied_lvis.py` to prepare GT files for AP* evaluation. Since LVIS annotations are not exhaustive, `lvis-api` rather than `cocoapi` should be used to evaluate AP*.

#### Cityscapes

The Cityscapes model is trained with ImageNet pretraining.

<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="bottom">Mask<br/>head</th>
|
||||
<th valign="bottom">Backbone</th>
|
||||
<th valign="bottom">lr<br/>sched</th>
|
||||
<th valign="bottom">Output<br/>resolution</th>
|
||||
<th valign="bottom">mask<br/>AP</th>
|
||||
<th valign="bottom">model id</th>
|
||||
<th valign="bottom">download</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml">PointRend</a></td>
|
||||
<td align="center">R50-FPN</td>
|
||||
<td align="center">1×</td>
|
||||
<td align="center">224×224</td>
|
||||
<td align="center">35.9</td>
|
||||
<td align="center">164255101</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/model_final_318a02.pkl">model</a> | <a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/metrics.json">metrics</a></td>
|
||||
</tr>
|
||||
</tbody></table>

## Semantic Segmentation

#### Cityscapes

The Cityscapes model is trained with ImageNet pretraining.

<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="bottom">Method</th>
|
||||
<th valign="bottom">Backbone</th>
|
||||
<th valign="bottom">Output<br/>resolution</th>
|
||||
<th valign="bottom">mIoU</th>
|
||||
<th valign="bottom">model id</th>
|
||||
<th valign="bottom">download</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left"><a href="configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml">SemanticFPN + PointRend</a></td>
|
||||
<td align="center">R101-FPN</td>
|
||||
<td align="center">1024×2048</td>
|
||||
<td align="center">78.6</td>
|
||||
<td align="center">186480235</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/model_final_5f3665.pkl">model</a> | <a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/metrics.json">metrics</a></td>
|
||||
</tr>
|
||||
</tbody></table>

## <a name="CitingPointRend"></a>Citing PointRend

If you use PointRend, please use the following BibTeX entry.

```BibTeX
@InProceedings{kirillov2019pointrend,
  title={{PointRend}: Image Segmentation as Rendering},
  author={Alexander Kirillov and Yuxin Wu and Kaiming He and Ross Girshick},
  journal={ArXiv:1912.08193},
  year={2019}
}
```
@@ -0,0 +1,21 @@
|
||||
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
|
||||
MODEL:
|
||||
ROI_HEADS:
|
||||
NAME: "PointRendROIHeads"
|
||||
IN_FEATURES: ["p2", "p3", "p4", "p5"]
|
||||
ROI_BOX_HEAD:
|
||||
TRAIN_ON_PRED_BOXES: True
|
||||
ROI_MASK_HEAD:
|
||||
NAME: "CoarseMaskHead"
|
||||
FC_DIM: 1024
|
||||
NUM_FC: 2
|
||||
OUTPUT_SIDE_RESOLUTION: 7
|
||||
IN_FEATURES: ["p2"]
|
||||
POINT_HEAD_ON: True
|
||||
POINT_HEAD:
|
||||
FC_DIM: 256
|
||||
NUM_FC: 3
|
||||
IN_FEATURES: ["p2"]
|
||||
INPUT:
|
||||
# PointRend for instance segmentation does not work with "polygon" mask_format.
|
||||
MASK_FORMAT: "bitmask"
|
@@ -0,0 +1,23 @@
|
||||
_BASE_: Base-PointRend-RCNN-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
|
||||
MASK_ON: true
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
ROI_HEADS:
|
||||
NUM_CLASSES: 8
|
||||
POINT_HEAD:
|
||||
NUM_CLASSES: 8
|
||||
DATASETS:
|
||||
TEST: ("cityscapes_fine_instance_seg_val",)
|
||||
TRAIN: ("cityscapes_fine_instance_seg_train",)
|
||||
SOLVER:
|
||||
BASE_LR: 0.01
|
||||
IMS_PER_BATCH: 8
|
||||
MAX_ITER: 24000
|
||||
STEPS: (18000,)
|
||||
INPUT:
|
||||
MAX_SIZE_TEST: 2048
|
||||
MAX_SIZE_TRAIN: 2048
|
||||
MIN_SIZE_TEST: 1024
|
||||
MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024)
|
@@ -0,0 +1,9 @@
|
||||
_BASE_: Base-PointRend-RCNN-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
|
||||
MASK_ON: true
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
# To add COCO AP evaluation against the higher-quality LVIS annotations.
|
||||
# DATASETS:
|
||||
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
|
@@ -0,0 +1,13 @@
|
||||
_BASE_: Base-PointRend-RCNN-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
|
||||
MASK_ON: true
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
SOLVER:
|
||||
STEPS: (210000, 250000)
|
||||
MAX_ITER: 270000
|
||||
# To add COCO AP evaluation against the higher-quality LVIS annotations.
|
||||
# DATASETS:
|
||||
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
|
||||
|
@@ -0,0 +1,20 @@
|
||||
_BASE_: Base-PointRend-RCNN-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
|
||||
MASK_ON: true
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
ROI_HEADS:
|
||||
NUM_CLASSES: 1
|
||||
POINT_HEAD:
|
||||
NUM_CLASSES: 1
|
||||
SOLVER:
|
||||
STEPS: (210000, 250000)
|
||||
MAX_ITER: 270000
|
||||
IMS_PER_BATCH: 1
|
||||
# To add COCO AP evaluation against the higher-quality LVIS annotations.
|
||||
# DATASETS:
|
||||
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
|
||||
DATASETS:
|
||||
TRAIN: ("CIHP_train",)
|
||||
TEST: ("CIHP_val",)
|
@@ -0,0 +1,28 @@
|
||||
_BASE_: Base-PointRend-RCNN-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: "./X-101-32x8d.pkl"
|
||||
PIXEL_STD: [57.375, 57.120, 58.395]
|
||||
MASK_ON: true
|
||||
RESNETS:
|
||||
STRIDE_IN_1X1: False # this is a C2 model
|
||||
NUM_GROUPS: 32
|
||||
WIDTH_PER_GROUP: 8
|
||||
DEPTH: 101
|
||||
ROI_HEADS:
|
||||
NUM_CLASSES: 1
|
||||
POINT_HEAD:
|
||||
NUM_CLASSES: 1
|
||||
SOLVER:
|
||||
STEPS: (210000, 250000)
|
||||
MAX_ITER: 270000
|
||||
IMS_PER_BATCH: 1
|
||||
# To add COCO AP evaluation against the higher-quality LVIS annotations.
|
||||
# DATASETS:
|
||||
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
|
||||
INPUT:
|
||||
MIN_SIZE_TRAIN: (640, 864)
|
||||
MIN_SIZE_TRAIN_SAMPLING: "range"
|
||||
MAX_SIZE_TRAIN: 1440
|
||||
DATASETS:
|
||||
TRAIN: ("CIHP_train",)
|
||||
TEST: ("CIHP_val",)
|
@@ -0,0 +1,19 @@
|
||||
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
|
||||
MODEL:
|
||||
META_ARCHITECTURE: "SemanticSegmentor"
|
||||
BACKBONE:
|
||||
FREEZE_AT: 0
|
||||
SEM_SEG_HEAD:
|
||||
NAME: "PointRendSemSegHead"
|
||||
POINT_HEAD:
|
||||
NUM_CLASSES: 54
|
||||
FC_DIM: 256
|
||||
NUM_FC: 3
|
||||
IN_FEATURES: ["p2"]
|
||||
TRAIN_NUM_POINTS: 1024
|
||||
SUBDIVISION_STEPS: 2
|
||||
SUBDIVISION_NUM_POINTS: 8192
|
||||
COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
|
||||
DATASETS:
|
||||
TRAIN: ("coco_2017_train_panoptic_stuffonly",)
|
||||
TEST: ("coco_2017_val_panoptic_stuffonly",)
|
@@ -0,0 +1,33 @@
|
||||
_BASE_: Base-PointRend-Semantic-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
|
||||
RESNETS:
|
||||
DEPTH: 101
|
||||
SEM_SEG_HEAD:
|
||||
NUM_CLASSES: 19
|
||||
POINT_HEAD:
|
||||
NUM_CLASSES: 19
|
||||
TRAIN_NUM_POINTS: 2048
|
||||
SUBDIVISION_NUM_POINTS: 8192
|
||||
DATASETS:
|
||||
TRAIN: ("cityscapes_fine_sem_seg_train",)
|
||||
TEST: ("cityscapes_fine_sem_seg_val",)
|
||||
SOLVER:
|
||||
BASE_LR: 0.01
|
||||
STEPS: (40000, 55000)
|
||||
MAX_ITER: 65000
|
||||
IMS_PER_BATCH: 32
|
||||
INPUT:
|
||||
MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
|
||||
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
||||
MIN_SIZE_TEST: 1024
|
||||
MAX_SIZE_TRAIN: 4096
|
||||
MAX_SIZE_TEST: 2048
|
||||
CROP:
|
||||
ENABLED: True
|
||||
TYPE: "absolute"
|
||||
SIZE: (512, 1024)
|
||||
SINGLE_CATEGORY_MAX_AREA: 0.75
|
||||
COLOR_AUG_SSD: True
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 16
|
@@ -0,0 +1,5 @@
|
||||
_BASE_: Base-PointRend-Semantic-FPN.yaml
|
||||
MODEL:
|
||||
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
|
||||
RESNETS:
|
||||
DEPTH: 50
|
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
PointRend Training Script.
|
||||
|
||||
This script is a simplified version of the training script in detectron2/tools.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
import detectron2.utils.comm as comm
|
||||
from detectron2.checkpoint import DetectionCheckpointer
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.data import MetadataCatalog, build_detection_train_loader
|
||||
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
||||
from detectron2.evaluation import (
|
||||
CityscapesInstanceEvaluator,
|
||||
CityscapesSemSegEvaluator,
|
||||
COCOEvaluator,
|
||||
DatasetEvaluators,
|
||||
LVISEvaluator,
|
||||
SemSegEvaluator,
|
||||
verify_results,
|
||||
)
|
||||
|
||||
from point_rend import SemSegDatasetMapper, add_pointrend_config
|
||||
|
||||
# Restrict training to a single GPU (device index 4) on this machine; remove or adjust as needed.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
|
||||
# Register Custom Dataset
|
||||
from detectron2.data.datasets import register_coco_instances
|
||||
register_coco_instances("CIHP_train", {}, "/data03/v_xuyunqiu/multi_parsing/data/msrcnn_finetune_annotations/CIHP_train.json", "/data03/v_xuyunqiu/data/instance-level_human_parsing/Training/Images")
|
||||
register_coco_instances("CIHP_val", {}, "/data03/v_xuyunqiu/multi_parsing/data/msrcnn_finetune_annotations/CIHP_val.json", "/data03/v_xuyunqiu/data/instance-level_human_parsing/Validation/Images")
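# NOTE: the CIHP annotation and image paths above are machine-specific; point them at your own
# local copies of the dataset before training or evaluating on CIHP.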
|
||||
|
||||
|
||||
class Trainer(DefaultTrainer):
|
||||
"""
|
||||
We use the "DefaultTrainer" which contains pre-defined default logic for
standard training workflows. It may not work for you, especially if you
|
||||
are working on a new research project. In that case you can use the cleaner
|
||||
"SimpleTrainer", or write your own training loop.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
|
||||
"""
|
||||
Create evaluator(s) for a given dataset.
|
||||
This uses the special metadata "evaluator_type" associated with each builtin dataset.
|
||||
For your own dataset, you can simply create an evaluator manually in your
|
||||
script and do not have to worry about the hacky if-else logic here.
|
||||
"""
|
||||
if output_folder is None:
|
||||
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
||||
evaluator_list = []
|
||||
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
||||
if evaluator_type == "lvis":
|
||||
return LVISEvaluator(dataset_name, cfg, True, output_folder)
|
||||
if evaluator_type == "coco":
|
||||
return COCOEvaluator(dataset_name, cfg, True, output_folder)
|
||||
if evaluator_type == "sem_seg":
|
||||
return SemSegEvaluator(
|
||||
dataset_name,
|
||||
distributed=True,
|
||||
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
|
||||
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
|
||||
output_dir=output_folder,
|
||||
)
|
||||
if evaluator_type == "cityscapes_instance":
|
||||
assert (
|
||||
torch.cuda.device_count() >= comm.get_rank()
|
||||
), "CityscapesEvaluator currently does not work with multiple machines."
|
||||
return CityscapesInstanceEvaluator(dataset_name)
|
||||
if evaluator_type == "cityscapes_sem_seg":
|
||||
assert (
|
||||
torch.cuda.device_count() >= comm.get_rank()
|
||||
), "CityscapesEvaluator currently does not work with multiple machines."
|
||||
return CityscapesSemSegEvaluator(dataset_name)
|
||||
if len(evaluator_list) == 0:
|
||||
raise NotImplementedError(
|
||||
"no Evaluator for the dataset {} with the type {}".format(
|
||||
dataset_name, evaluator_type
|
||||
)
|
||||
)
|
||||
if len(evaluator_list) == 1:
|
||||
return evaluator_list[0]
|
||||
return DatasetEvaluators(evaluator_list)
|
||||
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
|
||||
mapper = SemSegDatasetMapper(cfg, True)
|
||||
else:
|
||||
mapper = None
|
||||
return build_detection_train_loader(cfg, mapper=mapper)
|
||||
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
Create configs and perform basic setups.
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_pointrend_config(cfg)
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
||||
|
||||
def main(args):
|
||||
cfg = setup(args)
|
||||
|
||||
if args.eval_only:
|
||||
model = Trainer.build_model(cfg)
|
||||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
res = Trainer.test(cfg, model)
|
||||
if comm.is_main_process():
|
||||
verify_results(cfg, res)
|
||||
return res
|
||||
|
||||
trainer = Trainer(cfg)
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = default_argument_parser().parse_args()
|
||||
print("Command Line Args:", args)
|
||||
launch(
|
||||
main,
|
||||
args.num_gpus,
|
||||
num_machines=args.num_machines,
|
||||
machine_rank=args.machine_rank,
|
||||
dist_url=args.dist_url,
|
||||
args=(args,),
|
||||
)
|
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .config import add_pointrend_config
|
||||
from .coarse_mask_head import CoarseMaskHead
|
||||
from .roi_heads import PointRendROIHeads
|
||||
from .dataset_mapper import SemSegDatasetMapper
|
||||
from .semantic_seg import PointRendSemSegHead
|
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import fvcore.nn.weight_init as weight_init
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.layers import Conv2d, ShapeSpec
|
||||
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
|
||||
|
||||
|
||||
@ROI_MASK_HEAD_REGISTRY.register()
|
||||
class CoarseMaskHead(nn.Module):
|
||||
"""
|
||||
A mask head with fully connected layers. Given pooled features it first reduces channels and
|
||||
spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
|
||||
to the standard box head.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape: ShapeSpec):
|
||||
"""
|
||||
The following attributes are parsed from config:
|
||||
conv_dim: the output dimension of the conv layers
|
||||
fc_dim: the feature dimension of the FC layers
|
||||
num_fc: the number of FC layers
|
||||
output_side_resolution: side resolution of the output square mask prediction
|
||||
"""
|
||||
super(CoarseMaskHead, self).__init__()
|
||||
|
||||
# fmt: off
|
||||
self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
|
||||
conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
|
||||
self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
|
||||
num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
|
||||
self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
|
||||
self.input_channels = input_shape.channels
|
||||
self.input_h = input_shape.height
|
||||
self.input_w = input_shape.width
|
||||
# fmt: on
|
||||
|
||||
self.conv_layers = []
|
||||
if self.input_channels > conv_dim:
|
||||
self.reduce_channel_dim_conv = Conv2d(
|
||||
self.input_channels,
|
||||
conv_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True,
|
||||
activation=F.relu,
|
||||
)
|
||||
self.conv_layers.append(self.reduce_channel_dim_conv)
|
||||
|
||||
self.reduce_spatial_dim_conv = Conv2d(
|
||||
conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
|
||||
)
|
||||
self.conv_layers.append(self.reduce_spatial_dim_conv)
|
||||
|
||||
input_dim = conv_dim * self.input_h * self.input_w
|
||||
input_dim //= 4
|
||||
|
||||
self.fcs = []
|
||||
for k in range(num_fc):
|
||||
fc = nn.Linear(input_dim, self.fc_dim)
|
||||
self.add_module("coarse_mask_fc{}".format(k + 1), fc)
|
||||
self.fcs.append(fc)
|
||||
input_dim = self.fc_dim
|
||||
|
||||
output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution
|
||||
|
||||
self.prediction = nn.Linear(self.fc_dim, output_dim)
|
||||
# use normal distribution initialization for mask prediction layer
|
||||
nn.init.normal_(self.prediction.weight, std=0.001)
|
||||
nn.init.constant_(self.prediction.bias, 0)
|
||||
|
||||
for layer in self.conv_layers:
|
||||
weight_init.c2_msra_fill(layer)
|
||||
for layer in self.fcs:
|
||||
weight_init.c2_xavier_fill(layer)
|
||||
|
||||
def forward(self, x):
|
||||
# unlike BaseMaskRCNNHead, this head only outputs intermediate
|
||||
# features, because the features will be used later by PointHead.
|
||||
N = x.shape[0]
|
||||
x = x.view(N, self.input_channels, self.input_h, self.input_w)
|
||||
for layer in self.conv_layers:
|
||||
x = layer(x)
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
for layer in self.fcs:
|
||||
x = F.relu(layer(x))
|
||||
return self.prediction(x).view(
|
||||
N, self.num_classes, self.output_side_resolution, self.output_side_resolution
|
||||
)
|
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import numpy as np
|
||||
import random
|
||||
import cv2
|
||||
from fvcore.transforms.transform import Transform
|
||||
|
||||
|
||||
class ColorAugSSDTransform(Transform):
|
||||
"""
|
||||
A color related data augmentation used in Single Shot Multibox Detector (SSD).
|
||||
|
||||
Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
|
||||
Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
|
||||
SSD: Single Shot MultiBox Detector. ECCV 2016.
|
||||
|
||||
Implementation based on:
|
||||
|
||||
https://github.com/weiliu89/caffe/blob
|
||||
/4817bf8b4200b35ada8ed0dc378dceaf38c539e4
|
||||
/src/caffe/util/im_transforms.cpp
|
||||
|
||||
https://github.com/chainer/chainercv/blob
|
||||
/7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv
|
||||
/links/model/ssd/transforms.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_format,
|
||||
brightness_delta=32,
|
||||
contrast_low=0.5,
|
||||
contrast_high=1.5,
|
||||
saturation_low=0.5,
|
||||
saturation_high=1.5,
|
||||
hue_delta=18,
|
||||
):
|
||||
super().__init__()
|
||||
assert img_format in ["BGR", "RGB"]
|
||||
self.is_rgb = img_format == "RGB"
|
||||
del img_format
|
||||
self._set_attributes(locals())
|
||||
|
||||
def apply_coords(self, coords):
|
||||
return coords
|
||||
|
||||
def apply_segmentation(self, segmentation):
|
||||
return segmentation
|
||||
|
||||
def apply_image(self, img, interp=None):
|
||||
if self.is_rgb:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
img = self.brightness(img)
|
||||
if random.randrange(2):
|
||||
img = self.contrast(img)
|
||||
img = self.saturation(img)
|
||||
img = self.hue(img)
|
||||
else:
|
||||
img = self.saturation(img)
|
||||
img = self.hue(img)
|
||||
img = self.contrast(img)
|
||||
if self.is_rgb:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
return img
|
||||
|
||||
def convert(self, img, alpha=1, beta=0):
|
||||
img = img.astype(np.float32) * alpha + beta
|
||||
img = np.clip(img, 0, 255)
|
||||
return img.astype(np.uint8)
|
||||
|
||||
def brightness(self, img):
|
||||
if random.randrange(2):
|
||||
return self.convert(
|
||||
img, beta=random.uniform(-self.brightness_delta, self.brightness_delta)
|
||||
)
|
||||
return img
|
||||
|
||||
def contrast(self, img):
|
||||
if random.randrange(2):
|
||||
return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high))
|
||||
return img
|
||||
|
||||
def saturation(self, img):
|
||||
if random.randrange(2):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
img[:, :, 1] = self.convert(
|
||||
img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high)
|
||||
)
|
||||
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
|
||||
return img
|
||||
|
||||
def hue(self, img):
|
||||
if random.randrange(2):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
img[:, :, 0] = (
|
||||
img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)
|
||||
) % 180
|
||||
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
|
||||
return img
|
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from detectron2.config import CfgNode as CN
|
||||
|
||||
|
||||
def add_pointrend_config(cfg):
|
||||
"""
|
||||
Add config for PointRend.
|
||||
"""
|
||||
# We retry random cropping until no single category in semantic segmentation GT occupies more
|
||||
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
|
||||
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
|
||||
# Color augmentation from the SSD paper for semantic segmentation models during training.
|
||||
cfg.INPUT.COLOR_AUG_SSD = False
|
||||
|
||||
# Names of the input feature maps to be used by a coarse mask head.
|
||||
cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
|
||||
cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024
|
||||
cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2
|
||||
# The side size of a coarse mask head prediction.
|
||||
cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7
|
||||
# True if point head is used.
|
||||
cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False
|
||||
|
||||
cfg.MODEL.POINT_HEAD = CN()
|
||||
cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead"
|
||||
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80
|
||||
# Names of the input feature maps to be used by a mask point head.
|
||||
cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",)
|
||||
# Number of points sampled during training for a mask point head.
|
||||
cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14
|
||||
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
|
||||
# original paper.
|
||||
cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3
|
||||
# Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
|
||||
# the original paper.
|
||||
cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75
|
||||
# Number of subdivision steps during inference.
|
||||
cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5
|
||||
# Maximum number of points selected at each subdivision step (N).
|
||||
cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28
|
||||
cfg.MODEL.POINT_HEAD.FC_DIM = 256
|
||||
cfg.MODEL.POINT_HEAD.NUM_FC = 3
|
||||
cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
|
||||
# If True, then coarse prediction features are used as input to each layer in PointRend's MLP.
|
||||
cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
|
||||
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"
|
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import copy
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
from fvcore.common.file_io import PathManager
|
||||
from fvcore.transforms.transform import CropTransform
|
||||
from PIL import Image
|
||||
|
||||
from detectron2.data import detection_utils as utils
|
||||
from detectron2.data import transforms as T
|
||||
|
||||
from .color_augmentation import ColorAugSSDTransform
|
||||
|
||||
"""
|
||||
This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models.
|
||||
Unlike the default DatasetMapper, this mapper uses cropping as the last transformation.
|
||||
"""
|
||||
|
||||
__all__ = ["SemSegDatasetMapper"]
|
||||
|
||||
|
||||
class SemSegDatasetMapper:
|
||||
"""
|
||||
A callable which takes a dataset dict in Detectron2 Dataset format,
|
||||
and maps it into a format used by semantic segmentation models.
|
||||
|
||||
The callable currently does the following:
|
||||
|
||||
1. Reads the image from "file_name"
2. Applies geometric transforms to the image and annotation
3. Finds and applies suitable cropping to the image and annotation
4. Converts the image and annotation to Tensors
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, is_train=True):
|
||||
if cfg.INPUT.CROP.ENABLED and is_train:
|
||||
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
|
||||
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
|
||||
else:
|
||||
self.crop_gen = None
|
||||
|
||||
self.tfm_gens = utils.build_transform_gen(cfg, is_train)
|
||||
|
||||
if cfg.INPUT.COLOR_AUG_SSD:
|
||||
self.tfm_gens.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
|
||||
logging.getLogger(__name__).info(
"Color augmentation used in training: " + str(self.tfm_gens[-1])
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
self.img_format = cfg.INPUT.FORMAT
|
||||
self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
|
||||
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
|
||||
# fmt: on
|
||||
|
||||
self.is_train = is_train
|
||||
|
||||
def __call__(self, dataset_dict):
|
||||
"""
|
||||
Args:
|
||||
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
||||
|
||||
Returns:
|
||||
dict: a format that builtin models in detectron2 accept
|
||||
"""
|
||||
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
||||
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
||||
utils.check_image_size(dataset_dict, image)
|
||||
assert "sem_seg_file_name" in dataset_dict
|
||||
|
||||
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
||||
if self.is_train:
|
||||
with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
|
||||
sem_seg_gt = Image.open(f)
|
||||
sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
|
||||
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
|
||||
if self.crop_gen:
|
||||
image, sem_seg_gt = crop_transform(
|
||||
image,
|
||||
sem_seg_gt,
|
||||
self.crop_gen,
|
||||
self.single_category_max_area,
|
||||
self.ignore_value,
|
||||
)
|
||||
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
|
||||
|
||||
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
||||
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
||||
# Therefore it's important to use torch.Tensor.
|
||||
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
||||
|
||||
if not self.is_train:
|
||||
dataset_dict.pop("sem_seg_file_name", None)
|
||||
return dataset_dict
|
||||
|
||||
return dataset_dict
|
||||
|
||||
|
||||
def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
|
||||
"""
|
||||
Find a cropping window such that no single category occupies more than
|
||||
`single_category_max_area` in `sem_seg`. The function retries random cropping 10 times max.
|
||||
"""
|
||||
if single_category_max_area >= 1.0:
|
||||
crop_tfm = crop_gen.get_transform(image)
|
||||
sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
|
||||
else:
|
||||
h, w = sem_seg.shape
|
||||
crop_size = crop_gen.get_crop_size((h, w))
|
||||
for _ in range(10):
|
||||
y0 = np.random.randint(h - crop_size[0] + 1)
|
||||
x0 = np.random.randint(w - crop_size[1] + 1)
|
||||
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
|
||||
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
|
||||
cnt = cnt[labels != ignore_value]
|
||||
if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
|
||||
break
|
||||
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
|
||||
image = crop_tfm.apply_image(image)
|
||||
return image, sem_seg_temp
|
@@ -0,0 +1,216 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.layers import cat
|
||||
from detectron2.structures import Boxes
|
||||
|
||||
|
||||
"""
|
||||
Shape shorthand in this module:
|
||||
|
||||
N: minibatch dimension size, i.e. the number of RoIs for instance segmentation or the
number of images for semantic segmentation.
|
||||
R: number of ROIs, combined over all images, in the minibatch
|
||||
P: number of points
|
||||
"""
|
||||
|
||||
|
||||
def point_sample(input, point_coords, **kwargs):
|
||||
"""
|
||||
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
|
||||
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
|
||||
[0, 1] x [0, 1] square.
|
||||
|
||||
Args:
|
||||
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
|
||||
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
|
||||
[0, 1] x [0, 1] normalized point coordinates.
|
||||
|
||||
Returns:
|
||||
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
|
||||
features for points in `point_coords`. The features are obtained via bilinear
|
||||
interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
|
||||
"""
|
||||
add_dim = False
|
||||
if point_coords.dim() == 3:
|
||||
add_dim = True
|
||||
point_coords = point_coords.unsqueeze(2)
|
||||
output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
|
||||
if add_dim:
|
||||
output = output.squeeze(3)
|
||||
return output
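
# Example usage (sketch, not part of the original code):
#   feats = torch.rand(2, 256, 32, 32)                        # (N, C, H, W) feature map
#   pts = torch.rand(2, 100, 2)                               # (N, P, 2) points in [0, 1] x [0, 1]
#   sampled = point_sample(feats, pts, align_corners=False)   # -> (N, C, P) sampled features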
|
||||
|
||||
|
||||
def generate_regular_grid_point_coords(R, side_size, device):
|
||||
"""
|
||||
Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.
|
||||
|
||||
Args:
|
||||
R (int): The number of grids to sample, one for each region.
|
||||
side_size (int): The side size of the regular grid.
|
||||
device (torch.device): Desired device of returned tensor.
|
||||
|
||||
Returns:
|
||||
(Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
|
||||
for the regular grids.
|
||||
"""
|
||||
aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
|
||||
r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)
|
||||
return r.view(1, -1, 2).expand(R, -1, -1)
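
# Example (sketch): generate_regular_grid_point_coords(3, 2, "cpu") returns a (3, 4, 2) tensor
# whose four points per region are the cell centers of a 2x2 grid in [0, 1] x [0, 1]:
# (0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75).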
|
||||
|
||||
|
||||
def get_uncertain_point_coords_with_randomness(
|
||||
coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
|
||||
):
|
||||
"""
|
||||
Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
|
||||
are calculated for each point using 'uncertainty_func' function that takes point's logit
|
||||
prediction as input.
|
||||
See PointRend paper for details.
|
||||
|
||||
Args:
|
||||
coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
|
||||
class-specific or class-agnostic prediction.
|
||||
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
|
||||
contains logit predictions for P points and returns their uncertainties as a Tensor of
|
||||
shape (N, 1, P).
|
||||
num_points (int): The number of points P to sample.
|
||||
oversample_ratio (int): Oversampling parameter.
|
||||
importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
|
||||
|
||||
Returns:
|
||||
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
|
||||
sampled points.
|
||||
"""
|
||||
assert oversample_ratio >= 1
|
||||
assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
|
||||
num_boxes = coarse_logits.shape[0]
|
||||
num_sampled = int(num_points * oversample_ratio)
|
||||
point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
|
||||
point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
|
||||
# It is crucial to calculate uncertainty based on the sampled prediction value for the points.
|
||||
# Calculating uncertainties of the coarse predictions first and sampling them for points leads
|
||||
# to incorrect results.
|
||||
# To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
|
||||
# two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
|
||||
# However, if we calculate uncertainties for the coarse predictions first,
|
||||
# both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
|
||||
point_uncertainties = uncertainty_func(point_logits)
|
||||
num_uncertain_points = int(importance_sample_ratio * num_points)
|
||||
num_random_points = num_points - num_uncertain_points
|
||||
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
|
||||
shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
|
||||
idx += shift[:, None]
|
||||
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
|
||||
num_boxes, num_uncertain_points, 2
|
||||
)
|
||||
if num_random_points > 0:
|
||||
point_coords = cat(
|
||||
[
|
||||
point_coords,
|
||||
torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return point_coords
|
||||
|
||||
|
||||
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
|
||||
"""
|
||||
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
||||
|
||||
Args:
|
||||
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
|
||||
values for a set of points on a regular H x W grid.
|
||||
num_points (int): The number of points P to select.
|
||||
|
||||
Returns:
|
||||
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
||||
[0, H x W) of the most uncertain points.
|
||||
point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
|
||||
coordinates of the most uncertain points from the H x W grid.
|
||||
"""
|
||||
R, _, H, W = uncertainty_map.shape
|
||||
h_step = 1.0 / float(H)
|
||||
w_step = 1.0 / float(W)
|
||||
|
||||
num_points = min(H * W, num_points)
|
||||
point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]
|
||||
point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)
|
||||
point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
||||
point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
||||
return point_indices, point_coords
|
||||
|
||||
|
||||
def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
|
||||
"""
|
||||
Get features from feature maps in `features_list` that correspond to specific point coordinates
|
||||
inside each bounding box from `boxes`.
|
||||
|
||||
Args:
|
||||
features_list (list[Tensor]): A list of feature map tensors to get features from.
|
||||
feature_scales (list[float]): A list of scales for tensors in `features_list`.
|
||||
boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all
|
||||
together.
|
||||
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
|
||||
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
|
||||
|
||||
Returns:
|
||||
point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled
|
||||
from all features maps in feature_list for P sampled points for all R boxes in `boxes`.
|
||||
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level
|
||||
coordinates of P points.
|
||||
"""
|
||||
cat_boxes = Boxes.cat(boxes)
|
||||
num_boxes = [len(b) for b in boxes]
|
||||
|
||||
point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
|
||||
split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)
|
||||
|
||||
point_features = []
|
||||
for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):
|
||||
point_features_per_image = []
|
||||
for idx_feature, feature_map in enumerate(features_list):
|
||||
h, w = feature_map.shape[-2:]
|
||||
scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature]
|
||||
point_coords_scaled = point_coords_wrt_image_per_image / scale
|
||||
point_features_per_image.append(
|
||||
point_sample(
|
||||
feature_map[idx_img].unsqueeze(0),
|
||||
point_coords_scaled.unsqueeze(0),
|
||||
align_corners=False,
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose(1, 0)
|
||||
)
|
||||
point_features.append(cat(point_features_per_image, dim=1))
|
||||
|
||||
return cat(point_features, dim=0), point_coords_wrt_image
|
||||
|
||||
|
||||
def get_point_coords_wrt_image(boxes_coords, point_coords):
|
||||
"""
|
||||
Convert box-normalized [0, 1] x [0, 1] point coordinates to image-level coordinates.
|
||||
|
||||
Args:
|
||||
boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box coordinates.
|
||||
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
|
||||
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
|
||||
|
||||
Returns:
|
||||
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains
|
||||
image-level coordinates of P sampled points.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
point_coords_wrt_image = point_coords.clone()
|
||||
point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (
|
||||
boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
|
||||
)
|
||||
point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (
|
||||
boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
|
||||
)
|
||||
point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]
|
||||
point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]
|
||||
return point_coords_wrt_image
|
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import fvcore.nn.weight_init as weight_init
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.layers import ShapeSpec, cat
|
||||
from detectron2.structures import BitMasks
|
||||
from detectron2.utils.events import get_event_storage
|
||||
from detectron2.utils.registry import Registry
|
||||
|
||||
from .point_features import point_sample
|
||||
|
||||
POINT_HEAD_REGISTRY = Registry("POINT_HEAD")
|
||||
POINT_HEAD_REGISTRY.__doc__ = """
|
||||
Registry for point heads, which makes prediction for a given set of per-point features.
|
||||
|
||||
The registered object will be called with `obj(cfg, input_shape)`.
|
||||
"""
|
||||
|
||||
|
||||
def roi_mask_point_loss(mask_logits, instances, points_coord):
|
||||
"""
|
||||
Compute the point-based loss for instance segmentation mask predictions.
|
||||
|
||||
Args:
|
||||
mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or
|
||||
class-agnostic, where R is the total number of predicted masks in all images, C is the
|
||||
number of foreground classes, and P is the number of points sampled for each mask.
|
||||
The values are logits.
|
||||
instances (list[Instances]): A list of N Instances, where N is the number of images
|
||||
in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, i_th
|
||||
element of the list contains R_i objects and R_1 + ... + R_N is equal to R.
|
||||
The ground-truth labels (class, box, mask, ...) associated with each instance are stored
|
||||
in fields.
|
||||
points_coord (Tensor): A tensor of shape (R, P, 2), where R is the total number of
|
||||
predicted masks and P is the number of points for each mask. The coordinates are in
|
||||
the image pixel coordinate space, i.e. [0, H] x [0, W].
|
||||
Returns:
|
||||
point_loss (Tensor): A scalar tensor containing the loss.
|
||||
"""
|
||||
assert len(instances) == 0 or isinstance(
|
||||
instances[0].gt_masks, BitMasks
|
||||
), "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'."
|
||||
with torch.no_grad():
|
||||
cls_agnostic_mask = mask_logits.size(1) == 1
|
||||
total_num_masks = mask_logits.size(0)
|
||||
|
||||
gt_classes = []
|
||||
gt_mask_logits = []
|
||||
idx = 0
|
||||
for instances_per_image in instances:
|
||||
if not cls_agnostic_mask:
|
||||
gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
|
||||
gt_classes.append(gt_classes_per_image)
|
||||
|
||||
gt_bit_masks = instances_per_image.gt_masks.tensor
|
||||
h, w = instances_per_image.gt_masks.image_size
|
||||
scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)
|
||||
points_coord_grid_sample_format = (
|
||||
points_coord[idx : idx + len(instances_per_image)] / scale
|
||||
)
|
||||
idx += len(instances_per_image)
|
||||
gt_mask_logits.append(
|
||||
point_sample(
|
||||
gt_bit_masks.to(torch.float32).unsqueeze(1),
|
||||
points_coord_grid_sample_format,
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
)
|
||||
gt_mask_logits = cat(gt_mask_logits)
|
||||
|
||||
# torch.mean (in binary_cross_entropy_with_logits) doesn't
|
||||
# accept empty tensors, so handle it separately
|
||||
if gt_mask_logits.numel() == 0:
|
||||
return mask_logits.sum() * 0
|
||||
|
||||
if cls_agnostic_mask:
|
||||
mask_logits = mask_logits[:, 0]
|
||||
else:
|
||||
indices = torch.arange(total_num_masks)
|
||||
gt_classes = cat(gt_classes, dim=0)
|
||||
mask_logits = mask_logits[indices, gt_classes]
|
||||
|
||||
# Log the training accuracy (using gt classes and 0.0 threshold for the logits)
|
||||
mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
|
||||
mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel()
|
||||
get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy)
|
||||
|
||||
point_loss = F.binary_cross_entropy_with_logits(
|
||||
mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean"
|
||||
)
|
||||
return point_loss
|
||||
|
||||
|
||||
@POINT_HEAD_REGISTRY.register()
|
||||
class StandardPointHead(nn.Module):
|
||||
"""
|
||||
A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head
|
||||
takes both fine-grained and coarse prediction features as its input.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape: ShapeSpec):
|
||||
"""
|
||||
The following attributes are parsed from config:
|
||||
fc_dim: the output dimension of each FC layers
|
||||
num_fc: the number of FC layers
|
||||
coarse_pred_each_layer: if True, coarse prediction features are concatenated to each
|
||||
layer's input
|
||||
"""
|
||||
super(StandardPointHead, self).__init__()
|
||||
# fmt: off
|
||||
num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES
|
||||
fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM
|
||||
num_fc = cfg.MODEL.POINT_HEAD.NUM_FC
|
||||
cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK
|
||||
self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER
|
||||
input_channels = input_shape.channels
|
||||
# fmt: on
|
||||
|
||||
fc_dim_in = input_channels + num_classes
|
||||
self.fc_layers = []
|
||||
for k in range(num_fc):
|
||||
fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.add_module("fc{}".format(k + 1), fc)
|
||||
self.fc_layers.append(fc)
|
||||
fc_dim_in = fc_dim
|
||||
fc_dim_in += num_classes if self.coarse_pred_each_layer else 0
|
||||
|
||||
num_mask_classes = 1 if cls_agnostic_mask else num_classes
|
||||
self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
for layer in self.fc_layers:
|
||||
weight_init.c2_msra_fill(layer)
|
||||
# use normal distribution initialization for mask prediction layer
|
||||
nn.init.normal_(self.predictor.weight, std=0.001)
|
||||
if self.predictor.bias is not None:
|
||||
nn.init.constant_(self.predictor.bias, 0)
|
||||
|
||||
def forward(self, fine_grained_features, coarse_features):
|
||||
x = torch.cat((fine_grained_features, coarse_features), dim=1)
|
||||
for layer in self.fc_layers:
|
||||
x = F.relu(layer(x))
|
||||
if self.coarse_pred_each_layer:
|
||||
x = cat((x, coarse_features), dim=1)
|
||||
return self.predictor(x)
|
||||
|
||||
|
||||
def build_point_head(cfg, input_channels):
|
||||
"""
|
||||
Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`.
|
||||
"""
|
||||
head_name = cfg.MODEL.POINT_HEAD.NAME
|
||||
return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels)
|
@@ -0,0 +1,227 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from detectron2.layers import ShapeSpec, cat, interpolate
|
||||
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
|
||||
from detectron2.modeling.roi_heads.mask_head import (
|
||||
build_mask_head,
|
||||
mask_rcnn_inference,
|
||||
mask_rcnn_loss,
|
||||
)
|
||||
from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals
|
||||
|
||||
from .point_features import (
|
||||
generate_regular_grid_point_coords,
|
||||
get_uncertain_point_coords_on_grid,
|
||||
get_uncertain_point_coords_with_randomness,
|
||||
point_sample,
|
||||
point_sample_fine_grained_features,
|
||||
)
|
||||
from .point_head import build_point_head, roi_mask_point_loss
|
||||
|
||||
|
||||
def calculate_uncertainty(logits, classes):
|
||||
"""
|
||||
We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
|
||||
foreground class in `classes`.
|
||||
|
||||
Args:
|
||||
logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
|
||||
class-agnostic, where R is the total number of predicted masks in all images and C is
|
||||
the number of foreground classes. The values are logits.
|
||||
classes (list): A list of length R that contains either predicted or ground truth class
for each predicted mask.
|
||||
|
||||
Returns:
|
||||
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
|
||||
the most uncertain locations having the highest uncertainty score.
|
||||
"""
|
||||
if logits.shape[1] == 1:
|
||||
gt_class_logits = logits.clone()
|
||||
else:
|
||||
gt_class_logits = logits[
|
||||
torch.arange(logits.shape[0], device=logits.device), classes
|
||||
].unsqueeze(1)
|
||||
return -(torch.abs(gt_class_logits))
|
||||
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class PointRendROIHeads(StandardROIHeads):
|
||||
"""
|
||||
The RoI heads class for PointRend instance segmentation models.
|
||||
|
||||
In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact.
|
||||
To avoid namespace conflict with other heads we use names starting from `mask_` for all
|
||||
variables that correspond to the mask head in the class's namespace.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
# TODO use explicit args style
|
||||
super().__init__(cfg, input_shape)
|
||||
self._init_mask_head(cfg, input_shape)
|
||||
|
||||
def _init_mask_head(self, cfg, input_shape):
|
||||
# fmt: off
|
||||
self.mask_on = cfg.MODEL.MASK_ON
|
||||
if not self.mask_on:
|
||||
return
|
||||
self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
|
||||
self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
|
||||
self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
|
||||
# fmt: on
|
||||
|
||||
in_channels = np.sum([input_shape[f].channels for f in self.mask_coarse_in_features])
|
||||
self.mask_coarse_head = build_mask_head(
|
||||
cfg,
|
||||
ShapeSpec(
|
||||
channels=in_channels,
|
||||
width=self.mask_coarse_side_size,
|
||||
height=self.mask_coarse_side_size,
|
||||
),
|
||||
)
|
||||
self._init_point_head(cfg, input_shape)
|
||||
|
||||
def _init_point_head(self, cfg, input_shape):
|
||||
# fmt: off
|
||||
self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
|
||||
if not self.mask_point_on:
|
||||
return
|
||||
assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
|
||||
self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
|
||||
self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
|
||||
self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
|
||||
self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
|
||||
# the next two parameters are used in the adaptive subdivision inference procedure
|
||||
self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
|
||||
self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
|
||||
# fmt: on
|
||||
|
||||
in_channels = np.sum([input_shape[f].channels for f in self.mask_point_in_features])
|
||||
self.mask_point_head = build_point_head(
|
||||
cfg, ShapeSpec(channels=in_channels, width=1, height=1)
|
||||
)
|
||||
|
||||
def _forward_mask(self, features, instances):
|
||||
"""
|
||||
Forward logic of the mask prediction branch.
|
||||
|
||||
Args:
|
||||
features (dict[str, Tensor]): #level input features for mask prediction
|
||||
instances (list[Instances]): the per-image instances to train/predict masks.
|
||||
In training, they can be the proposals.
|
||||
In inference, they can be the predicted boxes.
|
||||
|
||||
Returns:
|
||||
In training, a dict of losses.
|
||||
In inference, update `instances` with new fields "pred_masks" and return it.
|
||||
"""
|
||||
if not self.mask_on:
|
||||
return {} if self.training else instances
|
||||
|
||||
if self.training:
|
||||
proposals, _ = select_foreground_proposals(instances, self.num_classes)
|
||||
proposal_boxes = [x.proposal_boxes for x in proposals]
|
||||
mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes)
|
||||
|
||||
losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)}
|
||||
losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals))
|
||||
return losses
|
||||
else:
|
||||
pred_boxes = [x.pred_boxes for x in instances]
|
||||
mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes)
|
||||
|
||||
mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances)
|
||||
mask_rcnn_inference(mask_logits, instances)
|
||||
return instances
|
||||
|
||||
def _forward_mask_coarse(self, features, boxes):
|
||||
"""
|
||||
Forward logic of the coarse mask head.
|
||||
"""
|
||||
point_coords = generate_regular_grid_point_coords(
|
||||
np.sum(len(x) for x in boxes), self.mask_coarse_side_size, boxes[0].device
|
||||
)
|
||||
mask_coarse_features_list = [features[k] for k in self.mask_coarse_in_features]
|
||||
features_scales = [self._feature_scales[k] for k in self.mask_coarse_in_features]
|
||||
# For regular grids of points, this function is equivalent to `len(features_list)' calls
|
||||
# of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results.
|
||||
mask_features, _ = point_sample_fine_grained_features(
|
||||
mask_coarse_features_list, features_scales, boxes, point_coords
|
||||
)
|
||||
return self.mask_coarse_head(mask_features)
|
||||
|
||||
def _forward_mask_point(self, features, mask_coarse_logits, instances):
|
||||
"""
|
||||
Forward logic of the mask point head.
|
||||
"""
|
||||
if not self.mask_point_on:
|
||||
return {} if self.training else mask_coarse_logits
|
||||
|
||||
mask_features_list = [features[k] for k in self.mask_point_in_features]
|
||||
features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]
|
||||
|
||||
if self.training:
|
||||
proposal_boxes = [x.proposal_boxes for x in instances]
|
||||
gt_classes = cat([x.gt_classes for x in instances])
|
||||
with torch.no_grad():
|
||||
point_coords = get_uncertain_point_coords_with_randomness(
|
||||
mask_coarse_logits,
|
||||
lambda logits: calculate_uncertainty(logits, gt_classes),
|
||||
self.mask_point_train_num_points,
|
||||
self.mask_point_oversample_ratio,
|
||||
self.mask_point_importance_sample_ratio,
|
||||
)
|
||||
|
||||
fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features(
|
||||
mask_features_list, features_scales, proposal_boxes, point_coords
|
||||
)
|
||||
coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False)
|
||||
point_logits = self.mask_point_head(fine_grained_features, coarse_features)
|
||||
return {
|
||||
"loss_mask_point": roi_mask_point_loss(
|
||||
point_logits, instances, point_coords_wrt_image
|
||||
)
|
||||
}
|
||||
else:
|
||||
pred_boxes = [x.pred_boxes for x in instances]
|
||||
pred_classes = cat([x.pred_classes for x in instances])
|
||||
# The subdivision code will fail with the empty list of boxes
|
||||
if len(pred_classes) == 0:
|
||||
return mask_coarse_logits
|
||||
|
||||
mask_logits = mask_coarse_logits.clone()
|
||||
for subdivision_step in range(self.mask_point_subdivision_steps):
|
||||
mask_logits = interpolate(
|
||||
mask_logits, scale_factor=2, mode="bilinear", align_corners=False
|
||||
)
|
||||
# If `mask_point_subdivision_num_points` is greater than or equal to the
|
||||
# resolution of the next step, then we can skip this step
|
||||
H, W = mask_logits.shape[-2:]
|
||||
if (
|
||||
self.mask_point_subdivision_num_points >= 4 * H * W
|
||||
and subdivision_step < self.mask_point_subdivision_steps - 1
|
||||
):
|
||||
continue
|
||||
uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
|
||||
point_indices, point_coords = get_uncertain_point_coords_on_grid(
|
||||
uncertainty_map, self.mask_point_subdivision_num_points
|
||||
)
|
||||
fine_grained_features, _ = point_sample_fine_grained_features(
|
||||
mask_features_list, features_scales, pred_boxes, point_coords
|
||||
)
|
||||
coarse_features = point_sample(
|
||||
mask_coarse_logits, point_coords, align_corners=False
|
||||
)
|
||||
point_logits = self.mask_point_head(fine_grained_features, coarse_features)
|
||||
|
||||
# Put mask point predictions at the right places on the upsampled grid.
|
||||
R, C, H, W = mask_logits.shape
|
||||
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
||||
mask_logits = (
|
||||
mask_logits.reshape(R, C, H * W)
|
||||
.scatter_(2, point_indices, point_logits)
|
||||
.view(R, C, H, W)
|
||||
)
|
||||
return mask_logits
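
The inference-time subdivision loop above boils down to a handful of tensor operations: upsample the logits, rank locations by uncertainty, re-predict the most uncertain points, and scatter the refined values back into the flattened grid. The following is a minimal, self-contained sketch of that pattern for a single-class logit map, with a stand-in update instead of the real point head; shapes and the update rule are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def refine_once(logits, num_points):
    """One illustrative subdivision step on a (R, 1, H, W) logit map (single class)."""
    logits = F.interpolate(logits, scale_factor=2, mode="bilinear", align_corners=False)
    R, C, H, W = logits.shape
    uncertainty = -logits.abs()                                # closest to 0 => most uncertain
    point_indices = uncertainty.view(R, -1).topk(num_points, dim=1).indices  # (R, K) flat indices
    # Stand-in for the point head: simply damp the selected logits.
    point_logits = 0.5 * logits.view(R, C, H * W).gather(2, point_indices.unsqueeze(1))
    return (
        logits.view(R, C, H * W)
        .scatter_(2, point_indices.unsqueeze(1).expand(-1, C, -1), point_logits)
        .view(R, C, H, W)
    )

coarse = torch.randn(2, 1, 7, 7)                 # two instances, assumed 7x7 coarse prediction
refined = refine_once(coarse, num_points=28)     # (2, 1, 14, 14)
```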
|
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from detectron2.layers import ShapeSpec, cat
|
||||
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
|
||||
|
||||
from .point_features import (
|
||||
get_uncertain_point_coords_on_grid,
|
||||
get_uncertain_point_coords_with_randomness,
|
||||
point_sample,
|
||||
)
|
||||
from .point_head import build_point_head
|
||||
|
||||
|
||||
def calculate_uncertainty(sem_seg_logits):
|
||||
"""
|
||||
For each location of the prediction `sem_seg_logits` we estimate uncertainty as the
|
||||
difference between the top first and top second predicted logits.
|
||||
|
||||
Args:
|
||||
sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
|
||||
C is the number of foreground classes. The values are logits.
|
||||
|
||||
Returns:
|
||||
scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
|
||||
the most uncertain locations having the highest uncertainty score.
|
||||
"""
|
||||
top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
|
||||
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
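
A toy check of the sign convention (assumed numbers, not from the original code): the closer the top two logits are, the closer the score is to zero, and zero is the largest possible value, so ambiguous locations rank as most uncertain.

```python
import torch

# Two locations, three classes each (shape N=2, C=3, spatial size 1).
sem_seg_logits = torch.tensor([[[4.0], [3.9], [0.1]],    # top-2 nearly tied -> uncertain
                               [[6.0], [1.0], [0.5]]])   # clear winner      -> certain
top2 = torch.topk(sem_seg_logits, k=2, dim=1).values
scores = (top2[:, 1] - top2[:, 0]).unsqueeze(1)          # approx. [[[-0.1]], [[-5.0]]]
```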
|
||||
|
||||
|
||||
@SEM_SEG_HEADS_REGISTRY.register()
|
||||
class PointRendSemSegHead(nn.Module):
|
||||
"""
|
||||
A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
|
||||
and a point head set in `MODEL.POINT_HEAD.NAME`.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
|
||||
super().__init__()
|
||||
|
||||
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
|
||||
|
||||
self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
|
||||
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
|
||||
)(cfg, input_shape)
|
||||
self._init_point_head(cfg, input_shape)
|
||||
|
||||
def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
|
||||
# fmt: off
|
||||
assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
|
||||
feature_channels = {k: v.channels for k, v in input_shape.items()}
|
||||
self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
|
||||
self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
|
||||
self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
|
||||
self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
|
||||
self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
|
||||
self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
|
||||
# fmt: on
|
||||
|
||||
in_channels = np.sum([feature_channels[f] for f in self.in_features])
|
||||
self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)
|
||||
|
||||
if self.training:
|
||||
losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)
|
||||
|
||||
with torch.no_grad():
|
||||
point_coords = get_uncertain_point_coords_with_randomness(
|
||||
coarse_sem_seg_logits,
|
||||
calculate_uncertainty,
|
||||
self.train_num_points,
|
||||
self.oversample_ratio,
|
||||
self.importance_sample_ratio,
|
||||
)
|
||||
coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
|
||||
|
||||
fine_grained_features = cat(
|
||||
[
|
||||
point_sample(features[in_feature], point_coords, align_corners=False)
|
||||
for in_feature in self.in_features
|
||||
]
|
||||
)
|
||||
point_logits = self.point_head(fine_grained_features, coarse_features)
|
||||
point_targets = (
|
||||
point_sample(
|
||||
targets.unsqueeze(1).to(torch.float),
|
||||
point_coords,
|
||||
mode="nearest",
|
||||
align_corners=False,
|
||||
)
|
||||
.squeeze(1)
|
||||
.to(torch.long)
|
||||
)
|
||||
losses["loss_sem_seg_point"] = F.cross_entropy(
|
||||
point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
|
||||
)
|
||||
return None, losses
|
||||
else:
|
||||
sem_seg_logits = coarse_sem_seg_logits.clone()
|
||||
for _ in range(self.subdivision_steps):
|
||||
sem_seg_logits = F.interpolate(
|
||||
sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
|
||||
)
|
||||
uncertainty_map = calculate_uncertainty(sem_seg_logits)
|
||||
point_indices, point_coords = get_uncertain_point_coords_on_grid(
|
||||
uncertainty_map, self.subdivision_num_points
|
||||
)
|
||||
fine_grained_features = cat(
|
||||
[
|
||||
point_sample(features[in_feature], point_coords, align_corners=False)
|
||||
for in_feature in self.in_features
|
||||
]
|
||||
)
|
||||
coarse_features = point_sample(
|
||||
coarse_sem_seg_logits, point_coords, align_corners=False
|
||||
)
|
||||
point_logits = self.point_head(fine_grained_features, coarse_features)
|
||||
|
||||
# Put sem seg point predictions at the right places on the upsampled grid.
|
||||
N, C, H, W = sem_seg_logits.shape
|
||||
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
||||
sem_seg_logits = (
|
||||
sem_seg_logits.reshape(N, C, H * W)
|
||||
.scatter_(2, point_indices, point_logits)
|
||||
.view(N, C, H, W)
|
||||
)
|
||||
return sem_seg_logits, {}
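
During training, `get_uncertain_point_coords_with_randomness` implements a biased sampling scheme: draw `oversample_ratio × N` uniform candidates, keep the `importance_sample_ratio × N` candidates where the coarse prediction is most uncertain, and fill the remaining slots with fresh uniform points. The function below is an assumed, simplified restatement of that idea in plain PyTorch (names and shapes are illustrative), not the project's implementation.

```python
import torch
import torch.nn.functional as F

def sample_training_points(uncertainty_fn, coarse_logits, num_points, oversample_ratio, importance_ratio):
    """Sketch of biased point sampling: oversample, keep the most uncertain, pad with random points."""
    n = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    candidates = torch.rand(n, num_sampled, 2, device=coarse_logits.device)       # points in [0, 1]^2
    # Look up the coarse prediction at every candidate (bilinear interpolation).
    grid = 2.0 * candidates.unsqueeze(2) - 1.0
    values = F.grid_sample(coarse_logits, grid, align_corners=False).squeeze(3)   # (n, C, num_sampled)
    scores = uncertainty_fn(values).squeeze(1)                                     # (n, num_sampled)
    num_uncertain = int(importance_ratio * num_points)
    idx = scores.topk(num_uncertain, dim=1).indices                                # hardest candidates
    hard = candidates.gather(1, idx.unsqueeze(-1).expand(-1, -1, 2))
    rand = torch.rand(n, num_points - num_uncertain, 2, device=coarse_logits.device)
    return torch.cat([hard, rand], dim=1)                                          # (n, num_points, 2)
```

With the `calculate_uncertainty` defined above, a call such as `sample_training_points(calculate_uncertainty, coarse_sem_seg_logits, 2048, 3, 0.75)` (hyperparameter values assumed) yields normalized point coordinates that can be fed to `point_sample`.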
|
@@ -0,0 +1,2 @@
|
||||
python finetune_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_parsing.yaml --num-gpus 1
|
||||
#python finetune_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_parsing.yaml --num-gpus 1
|
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
PointRend Training Script.
|
||||
|
||||
This script is a simplified version of the training script in detectron2/tools.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
import detectron2.utils.comm as comm
|
||||
from detectron2.checkpoint import DetectionCheckpointer
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.data import MetadataCatalog, build_detection_train_loader
|
||||
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
||||
from detectron2.evaluation import (
|
||||
CityscapesInstanceEvaluator,
|
||||
CityscapesSemSegEvaluator,
|
||||
COCOEvaluator,
|
||||
DatasetEvaluators,
|
||||
LVISEvaluator,
|
||||
SemSegEvaluator,
|
||||
verify_results,
|
||||
)
|
||||
|
||||
from point_rend import SemSegDatasetMapper, add_pointrend_config
|
||||
|
||||
|
||||
class Trainer(DefaultTrainer):
|
||||
"""
|
||||
We use the "DefaultTrainer" which contains a number pre-defined logic for
|
||||
standard training workflow. It may not work for you, especially if you
|
||||
are working on a new research project. In that case you can use the cleaner
|
||||
"SimpleTrainer", or write your own training loop.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
|
||||
"""
|
||||
Create evaluator(s) for a given dataset.
|
||||
This uses the special metadata "evaluator_type" associated with each builtin dataset.
|
||||
For your own dataset, you can simply create an evaluator manually in your
|
||||
script without worrying about the hacky if-else logic here.
|
||||
"""
|
||||
if output_folder is None:
|
||||
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
||||
evaluator_list = []
|
||||
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
||||
if evaluator_type == "lvis":
|
||||
return LVISEvaluator(dataset_name, cfg, True, output_folder)
|
||||
if evaluator_type == "coco":
|
||||
return COCOEvaluator(dataset_name, cfg, True, output_folder)
|
||||
if evaluator_type == "sem_seg":
|
||||
return SemSegEvaluator(
|
||||
dataset_name,
|
||||
distributed=True,
|
||||
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
|
||||
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
|
||||
output_dir=output_folder,
|
||||
)
|
||||
if evaluator_type == "cityscapes_instance":
|
||||
assert (
|
||||
torch.cuda.device_count() >= comm.get_rank()
|
||||
), "CityscapesEvaluator currently do not work with multiple machines."
|
||||
return CityscapesInstanceEvaluator(dataset_name)
|
||||
if evaluator_type == "cityscapes_sem_seg":
|
||||
assert (
|
||||
torch.cuda.device_count() >= comm.get_rank()
|
||||
), "CityscapesEvaluator currently do not work with multiple machines."
|
||||
return CityscapesSemSegEvaluator(dataset_name)
|
||||
if len(evaluator_list) == 0:
|
||||
raise NotImplementedError(
|
||||
"no Evaluator for the dataset {} with the type {}".format(
|
||||
dataset_name, evaluator_type
|
||||
)
|
||||
)
|
||||
if len(evaluator_list) == 1:
|
||||
return evaluator_list[0]
|
||||
return DatasetEvaluators(evaluator_list)
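
As the docstring says, a custom dataset does not need this if/else chain; an evaluator can be built directly and run through detectron2's standard inference loop. A small sketch, assuming a registered dataset named "my_dataset_val" (placeholder) and the same COCOEvaluator call style used above:

```python
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

def evaluate_custom_dataset(cfg, model, output_folder):
    # "my_dataset_val" is a placeholder for a dataset registered in DatasetCatalog/MetadataCatalog.
    evaluator = COCOEvaluator("my_dataset_val", cfg, True, output_folder)
    data_loader = build_detection_test_loader(cfg, "my_dataset_val")
    return inference_on_dataset(model, data_loader, evaluator)
```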
|
||||
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
|
||||
mapper = SemSegDatasetMapper(cfg, True)
|
||||
else:
|
||||
mapper = None
|
||||
return build_detection_train_loader(cfg, mapper=mapper)
|
||||
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
Create configs and perform basic setups.
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_pointrend_config(cfg)
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
||||
|
||||
def main(args):
|
||||
cfg = setup(args)
|
||||
|
||||
if args.eval_only:
|
||||
model = Trainer.build_model(cfg)
|
||||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
res = Trainer.test(cfg, model)
|
||||
if comm.is_main_process():
|
||||
verify_results(cfg, res)
|
||||
return res
|
||||
|
||||
trainer = Trainer(cfg)
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = default_argument_parser().parse_args()
|
||||
print("Command Line Args:", args)
|
||||
launch(
|
||||
main,
|
||||
args.num_gpus,
|
||||
num_machines=args.num_machines,
|
||||
machine_rank=args.machine_rank,
|
||||
dist_url=args.dist_url,
|
||||
args=(args,),
|
||||
)
|