Add at new repo again

This commit is contained in:
2025-01-28 21:48:35 +00:00
commit 6e660ddb3c
564 changed files with 75575 additions and 0 deletions

View File

@@ -0,0 +1,135 @@
# PointRend: Image Segmentation as Rendering
Alexander Kirillov, Yuxin Wu, Kaiming He, Ross Girshick
[[`arXiv`](https://arxiv.org/abs/1912.08193)] [[`BibTeX`](#CitingPointRend)]
<div align="center">
<img src="https://alexander-kirillov.github.io/images/kirillov2019pointrend.jpg"/>
</div><br/>
In this repository, we release code for PointRend in Detectron2. PointRend can be flexibly applied to both instance and semantic segmentation tasks by building on top of existing state-of-the-art models.
## Installation
Install Detectron 2 following [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). You are ready to go!
## Quick start and visualization
This [Colab Notebook](https://colab.research.google.com/drive/1isGPL5h5_cKoPPhVL9XhMokRtHDvmMVL) tutorial contains examples of PointRend usage and visualizations of its point sampling stages.
## Training
To train a model with 8 GPUs run:
```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --num-gpus 8
```
## Evaluation
Model evaluation can be done similarly:
```bash
cd /path/to/detectron2/projects/PointRend
python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint
```
# Pretrained Models
## Instance Segmentation
#### COCO
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Mask<br/>head</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">lr<br/>sched</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mask<br/>AP</th>
<th valign="bottom">mask<br/>AP&ast;</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">1&times;</td>
<td align="center">224&times;224</td>
<td align="center">36.2</td>
<td align="center">39.7</td>
<td align="center">164254221</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/model_final_88c6f8.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco/164254221/metrics.json">metrics</a></td>
</tr>
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">3&times;</td>
<td align="center">224&times;224</td>
<td align="center">38.3</td>
<td align="center">41.6</td>
<td align="center">164955410</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_3c3198.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/metrics.json">metrics</a></td>
</tr>
</tbody></table>
AP&ast; is COCO mask AP evaluated against the higher-quality LVIS annotations; see the paper for details. Run `python detectron2/datasets/prepare_cocofied_lvis.py` to prepare GT files for AP&ast; evaluation. Since LVIS annotations are not exhaustive `lvis-api` and not `cocoapi` should be used to evaluate AP&ast;.
#### Cityscapes
Cityscapes model is trained with ImageNet pretraining.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Mask<br/>head</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">lr<br/>sched</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mask<br/>AP</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml">PointRend</a></td>
<td align="center">R50-FPN</td>
<td align="center">1&times;</td>
<td align="center">224&times;224</td>
<td align="center">35.9</td>
<td align="center">164255101</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/model_final_318a02.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes/164255101/metrics.json">metrics</a></td>
</tr>
</tbody></table>
## Semantic Segmentation
#### Cityscapes
Cityscapes model is trained with ImageNet pretraining.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Method</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">Output<br/>resolution</th>
<th valign="bottom">mIoU</th>
<th valign="bottom">model id</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<tr><td align="left"><a href="configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml">SemanticFPN + PointRend</a></td>
<td align="center">R101-FPN</td>
<td align="center">1024&times;2048</td>
<td align="center">78.6</td>
<td align="center">186480235</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/model_final_5f3665.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/PointRend/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes/186480235/metrics.json">metrics</a></td>
</tr>
</tbody></table>
## <a name="CitingPointRend"></a>Citing PointRend
If you use PointRend, please use the following BibTeX entry.
```BibTeX
@InProceedings{kirillov2019pointrend,
title={{PointRend}: Image Segmentation as Rendering},
author={Alexander Kirillov and Yuxin Wu and Kaiming He and Ross Girshick},
journal={ArXiv:1912.08193},
year={2019}
}
```

View File

@@ -0,0 +1,21 @@
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
ROI_HEADS:
NAME: "PointRendROIHeads"
IN_FEATURES: ["p2", "p3", "p4", "p5"]
ROI_BOX_HEAD:
TRAIN_ON_PRED_BOXES: True
ROI_MASK_HEAD:
NAME: "CoarseMaskHead"
FC_DIM: 1024
NUM_FC: 2
OUTPUT_SIDE_RESOLUTION: 7
IN_FEATURES: ["p2"]
POINT_HEAD_ON: True
POINT_HEAD:
FC_DIM: 256
NUM_FC: 3
IN_FEATURES: ["p2"]
INPUT:
# PointRend for instance segmenation does not work with "polygon" mask_format.
MASK_FORMAT: "bitmask"

View File

@@ -0,0 +1,23 @@
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
ROI_HEADS:
NUM_CLASSES: 8
POINT_HEAD:
NUM_CLASSES: 8
DATASETS:
TEST: ("cityscapes_fine_instance_seg_val",)
TRAIN: ("cityscapes_fine_instance_seg_train",)
SOLVER:
BASE_LR: 0.01
IMS_PER_BATCH: 8
MAX_ITER: 24000
STEPS: (18000,)
INPUT:
MAX_SIZE_TEST: 2048
MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024
MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024)

View File

@@ -0,0 +1,9 @@
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")

View File

@@ -0,0 +1,13 @@
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
SOLVER:
STEPS: (210000, 250000)
MAX_ITER: 270000
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")

View File

@@ -0,0 +1,20 @@
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
MASK_ON: true
RESNETS:
DEPTH: 50
ROI_HEADS:
NUM_CLASSES: 1
POINT_HEAD:
NUM_CLASSES: 1
SOLVER:
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 1
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
DATASETS:
TRAIN: ("CIHP_train",)
TEST: ("CIHP_val",)

View File

@@ -0,0 +1,28 @@
_BASE_: Base-PointRend-RCNN-FPN.yaml
MODEL:
WEIGHTS: "./X-101-32x8d.pkl"
PIXEL_STD: [57.375, 57.120, 58.395]
MASK_ON: true
RESNETS:
STRIDE_IN_1X1: False # this is a C2 model
NUM_GROUPS: 32
WIDTH_PER_GROUP: 8
DEPTH: 101
ROI_HEADS:
NUM_CLASSES: 1
POINT_HEAD:
NUM_CLASSES: 1
SOLVER:
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 1
# To add COCO AP evaluation against the higher-quality LVIS annotations.
# DATASETS:
# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied")
INPUT:
MIN_SIZE_TRAIN: (640, 864)
MIN_SIZE_TRAIN_SAMPLING: "range"
MAX_SIZE_TRAIN: 1440
DATASETS:
TRAIN: ("CIHP_train",)
TEST: ("CIHP_val",)

View File

@@ -0,0 +1,19 @@
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
BACKBONE:
FREEZE_AT: 0
SEM_SEG_HEAD:
NAME: "PointRendSemSegHead"
POINT_HEAD:
NUM_CLASSES: 54
FC_DIM: 256
NUM_FC: 3
IN_FEATURES: ["p2"]
TRAIN_NUM_POINTS: 1024
SUBDIVISION_STEPS: 2
SUBDIVISION_NUM_POINTS: 8192
COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
DATASETS:
TRAIN: ("coco_2017_train_panoptic_stuffonly",)
TEST: ("coco_2017_val_panoptic_stuffonly",)

View File

@@ -0,0 +1,33 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
RESNETS:
DEPTH: 101
SEM_SEG_HEAD:
NUM_CLASSES: 19
POINT_HEAD:
NUM_CLASSES: 19
TRAIN_NUM_POINTS: 2048
SUBDIVISION_NUM_POINTS: 8192
DATASETS:
TRAIN: ("cityscapes_fine_sem_seg_train",)
TEST: ("cityscapes_fine_sem_seg_val",)
SOLVER:
BASE_LR: 0.01
STEPS: (40000, 55000)
MAX_ITER: 65000
IMS_PER_BATCH: 32
INPUT:
MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
MIN_SIZE_TRAIN_SAMPLING: "choice"
MIN_SIZE_TEST: 1024
MAX_SIZE_TRAIN: 4096
MAX_SIZE_TEST: 2048
CROP:
ENABLED: True
TYPE: "absolute"
SIZE: (512, 1024)
SINGLE_CATEGORY_MAX_AREA: 0.75
COLOR_AUG_SSD: True
DATALOADER:
NUM_WORKERS: 16

View File

@@ -0,0 +1,5 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
RESNETS:
DEPTH: 50

View File

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
PointRend Training Script.
This script is a simplified version of the training script in detectron2/tools.
"""
import os
import torch
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
DatasetEvaluators,
LVISEvaluator,
SemSegEvaluator,
verify_results,
)
from point_rend import SemSegDatasetMapper, add_pointrend_config
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# Register Custom Dataset
from detectron2.data.datasets import register_coco_instances
register_coco_instances("CIHP_train", {}, "/data03/v_xuyunqiu/multi_parsing/data/msrcnn_finetune_annotations/CIHP_train.json", "/data03/v_xuyunqiu/data/instance-level_human_parsing/Training/Images")
register_coco_instances("CIHP_val", {}, "/data03/v_xuyunqiu/multi_parsing/data/msrcnn_finetune_annotations/CIHP_val.json", "/data03/v_xuyunqiu/data/instance-level_human_parsing/Validation/Images")
class Trainer(DefaultTrainer):
"""
We use the "DefaultTrainer" which contains a number pre-defined logic for
standard training workflow. They may not work for you, especially if you
are working on a new research project. In that case you can use the cleaner
"SimpleTrainer", or write your own training loop.
"""
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
"""
Create evaluator(s) for a given dataset.
This uses the special metadata "evaluator_type" associated with each builtin dataset.
For your own dataset, you can simply create an evaluator manually in your
script and do not have to worry about the hacky if-else logic here.
"""
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
evaluator_list = []
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
if evaluator_type == "lvis":
return LVISEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "coco":
return COCOEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "sem_seg":
return SemSegEvaluator(
dataset_name,
distributed=True,
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
output_dir=output_folder,
)
if evaluator_type == "cityscapes_instance":
assert (
torch.cuda.device_count() >= comm.get_rank()
), "CityscapesEvaluator currently do not work with multiple machines."
return CityscapesInstanceEvaluator(dataset_name)
if evaluator_type == "cityscapes_sem_seg":
assert (
torch.cuda.device_count() >= comm.get_rank()
), "CityscapesEvaluator currently do not work with multiple machines."
return CityscapesSemSegEvaluator(dataset_name)
if len(evaluator_list) == 0:
raise NotImplementedError(
"no Evaluator for the dataset {} with the type {}".format(
dataset_name, evaluator_type
)
)
if len(evaluator_list) == 1:
return evaluator_list[0]
return DatasetEvaluators(evaluator_list)
@classmethod
def build_train_loader(cls, cfg):
if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
mapper = SemSegDatasetMapper(cfg, True)
else:
mapper = None
return build_detection_train_loader(cfg, mapper=mapper)
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
if args.eval_only:
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
res = Trainer.test(cfg, model)
if comm.is_main_process():
verify_results(cfg, res)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)

View File

@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import add_pointrend_config
from .coarse_mask_head import CoarseMaskHead
from .roi_heads import PointRendROIHeads
from .dataset_mapper import SemSegDatasetMapper
from .semantic_seg import PointRendSemSegHead

View File

@@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import Conv2d, ShapeSpec
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
@ROI_MASK_HEAD_REGISTRY.register()
class CoarseMaskHead(nn.Module):
"""
A mask head with fully connected layers. Given pooled features it first reduces channels and
spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
to the standard box head.
"""
def __init__(self, cfg, input_shape: ShapeSpec):
"""
The following attributes are parsed from config:
conv_dim: the output dimension of the conv layers
fc_dim: the feature dimenstion of the FC layers
num_fc: the number of FC layers
output_side_resolution: side resolution of the output square mask prediction
"""
super(CoarseMaskHead, self).__init__()
# fmt: off
self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
self.input_channels = input_shape.channels
self.input_h = input_shape.height
self.input_w = input_shape.width
# fmt: on
self.conv_layers = []
if self.input_channels > conv_dim:
self.reduce_channel_dim_conv = Conv2d(
self.input_channels,
conv_dim,
kernel_size=1,
stride=1,
padding=0,
bias=True,
activation=F.relu,
)
self.conv_layers.append(self.reduce_channel_dim_conv)
self.reduce_spatial_dim_conv = Conv2d(
conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
)
self.conv_layers.append(self.reduce_spatial_dim_conv)
input_dim = conv_dim * self.input_h * self.input_w
input_dim //= 4
self.fcs = []
for k in range(num_fc):
fc = nn.Linear(input_dim, self.fc_dim)
self.add_module("coarse_mask_fc{}".format(k + 1), fc)
self.fcs.append(fc)
input_dim = self.fc_dim
output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution
self.prediction = nn.Linear(self.fc_dim, output_dim)
# use normal distribution initialization for mask prediction layer
nn.init.normal_(self.prediction.weight, std=0.001)
nn.init.constant_(self.prediction.bias, 0)
for layer in self.conv_layers:
weight_init.c2_msra_fill(layer)
for layer in self.fcs:
weight_init.c2_xavier_fill(layer)
def forward(self, x):
# unlike BaseMaskRCNNHead, this head only outputs intermediate
# features, because the features will be used later by PointHead.
N = x.shape[0]
x = x.view(N, self.input_channels, self.input_h, self.input_w)
for layer in self.conv_layers:
x = layer(x)
x = torch.flatten(x, start_dim=1)
for layer in self.fcs:
x = F.relu(layer(x))
return self.prediction(x).view(
N, self.num_classes, self.output_side_resolution, self.output_side_resolution
)

View File

@@ -0,0 +1,98 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import random
import cv2
from fvcore.transforms.transform import Transform
class ColorAugSSDTransform(Transform):
"""
A color related data augmentation used in Single Shot Multibox Detector (SSD).
Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
SSD: Single Shot MultiBox Detector. ECCV 2016.
Implementation based on:
https://github.com/weiliu89/caffe/blob
/4817bf8b4200b35ada8ed0dc378dceaf38c539e4
/src/caffe/util/im_transforms.cpp
https://github.com/chainer/chainercv/blob
/7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv
/links/model/ssd/transforms.py
"""
def __init__(
self,
img_format,
brightness_delta=32,
contrast_low=0.5,
contrast_high=1.5,
saturation_low=0.5,
saturation_high=1.5,
hue_delta=18,
):
super().__init__()
assert img_format in ["BGR", "RGB"]
self.is_rgb = img_format == "RGB"
del img_format
self._set_attributes(locals())
def apply_coords(self, coords):
return coords
def apply_segmentation(self, segmentation):
return segmentation
def apply_image(self, img, interp=None):
if self.is_rgb:
img = img[:, :, [2, 1, 0]]
img = self.brightness(img)
if random.randrange(2):
img = self.contrast(img)
img = self.saturation(img)
img = self.hue(img)
else:
img = self.saturation(img)
img = self.hue(img)
img = self.contrast(img)
if self.is_rgb:
img = img[:, :, [2, 1, 0]]
return img
def convert(self, img, alpha=1, beta=0):
img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
def brightness(self, img):
if random.randrange(2):
return self.convert(
img, beta=random.uniform(-self.brightness_delta, self.brightness_delta)
)
return img
def contrast(self, img):
if random.randrange(2):
return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high))
return img
def saturation(self, img):
if random.randrange(2):
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img[:, :, 1] = self.convert(
img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high)
)
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img
def hue(self, img):
if random.randrange(2):
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img[:, :, 0] = (
img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)
) % 180
return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
return img

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.config import CfgNode as CN
def add_pointrend_config(cfg):
"""
Add config for PointRend.
"""
# We retry random cropping until no single category in semantic segmentation GT occupies more
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
# Color augmentatition from SSD paper for semantic segmentation model during training.
cfg.INPUT.COLOR_AUG_SSD = False
# Names of the input feature maps to be used by a coarse mask head.
cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024
cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2
# The side size of a coarse mask head prediction.
cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7
# True if point head is used.
cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False
cfg.MODEL.POINT_HEAD = CN()
cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead"
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80
# Names of the input feature maps to be used by a mask point head.
cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",)
# Number of points sampled during training for a mask point head.
cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
# original paper.
cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3
# Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in
# the original paper.
cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75
# Number of subdivision steps during inference.
cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5
# Maximum number of points selected at each subdivision step (N).
cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28
cfg.MODEL.POINT_HEAD.FC_DIM = 256
cfg.MODEL.POINT_HEAD.NUM_FC = 3
cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
# If True, then coarse prediction features are used as inout for each layer in PointRend's MLP.
cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"

View File

@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from fvcore.transforms.transform import CropTransform
from PIL import Image
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from .color_augmentation import ColorAugSSDTransform
"""
This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models.
Unlike the default DatasetMapper this mapper uses cropping as the last transformation.
"""
__all__ = ["SemSegDatasetMapper"]
class SemSegDatasetMapper:
"""
A callable which takes a dataset dict in Detectron2 Dataset format,
and map it into a format used by semantic segmentation models.
The callable currently does the following:
1. Read the image from "file_name"
2. Applies geometric transforms to the image and annotation
3. Find and applies suitable cropping to the image and annotation
4. Prepare image and annotation to Tensors
"""
def __init__(self, cfg, is_train=True):
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
else:
self.crop_gen = None
self.tfm_gens = utils.build_transform_gen(cfg, is_train)
if cfg.INPUT.COLOR_AUG_SSD:
self.tfm_gens.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
logging.getLogger(__name__).info(
"Color augmnetation used in training: " + str(self.tfm_gens[-1])
)
# fmt: off
self.img_format = cfg.INPUT.FORMAT
self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
# fmt: on
self.is_train = is_train
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
utils.check_image_size(dataset_dict, image)
assert "sem_seg_file_name" in dataset_dict
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.is_train:
with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
sem_seg_gt = Image.open(f)
sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
if self.crop_gen:
image, sem_seg_gt = crop_transform(
image,
sem_seg_gt,
self.crop_gen,
self.single_category_max_area,
self.ignore_value,
)
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
if not self.is_train:
dataset_dict.pop("sem_seg_file_name", None)
return dataset_dict
return dataset_dict
def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
"""
Find a cropping window such that no single category occupies more than
`single_category_max_area` in `sem_seg`. The function retries random cropping 10 times max.
"""
if single_category_max_area >= 1.0:
crop_tfm = crop_gen.get_transform(image)
sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
else:
h, w = sem_seg.shape
crop_size = crop_gen.get_crop_size((h, w))
for _ in range(10):
y0 = np.random.randint(h - crop_size[0] + 1)
x0 = np.random.randint(w - crop_size[1] + 1)
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
cnt = cnt[labels != ignore_value]
if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
break
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
image = crop_tfm.apply_image(image)
return image, sem_seg_temp

View File

@@ -0,0 +1,216 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.structures import Boxes
"""
Shape shorthand in this module:
N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the
number of images for semantic segmenation.
R: number of ROIs, combined over all images, in the minibatch
P: number of points
"""
def point_sample(input, point_coords, **kwargs):
"""
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
[0, 1] x [0, 1] square.
Args:
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
[0, 1] x [0, 1] normalized point coordinates.
Returns:
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
features for points in `point_coords`. The features are obtained via bilinear
interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
"""
add_dim = False
if point_coords.dim() == 3:
add_dim = True
point_coords = point_coords.unsqueeze(2)
output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
if add_dim:
output = output.squeeze(3)
return output
def generate_regular_grid_point_coords(R, side_size, device):
"""
Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.
Args:
R (int): The number of grids to sample, one for each region.
side_size (int): The side size of the regular grid.
device (torch.device): Desired device of returned tensor.
Returns:
(Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
for the regular grids.
"""
aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)
return r.view(1, -1, 2).expand(R, -1, -1)
def get_uncertain_point_coords_with_randomness(
coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
"""
Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties
are calculated for each point using 'uncertainty_func' function that takes point's logit
prediction as input.
See PointRend paper for details.
Args:
coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
class-specific or class-agnostic prediction.
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
contains logit predictions for P points and returns their uncertainties as a Tensor of
shape (N, 1, P).
num_points (int): The number of points P to sample.
oversample_ratio (int): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.
Returns:
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
sampled points.
"""
assert oversample_ratio >= 1
assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
num_boxes = coarse_logits.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
# It is crucial to calculate uncertainty based on the sampled prediction value for the points.
# Calculating uncertainties of the coarse predictions first and sampling them for points leads
# to incorrect results.
# To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
# two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
# However, if we calculate uncertainties for the coarse predictions first,
# both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
point_uncertainties = uncertainty_func(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
num_boxes, num_uncertain_points, 2
)
if num_random_points > 0:
point_coords = cat(
[
point_coords,
torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
],
dim=1,
)
return point_coords
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
"""
Find `num_points` most uncertain points from `uncertainty_map` grid.
Args:
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
values for a set of points on a regular H x W grid.
num_points (int): The number of points P to select.
Returns:
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
[0, H x W) of the most uncertain points.
point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
coordinates of the most uncertain points from the H x W grid.
"""
R, _, H, W = uncertainty_map.shape
h_step = 1.0 / float(H)
w_step = 1.0 / float(W)
num_points = min(H * W, num_points)
point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]
point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)
point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
return point_indices, point_coords
def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
"""
Get features from feature maps in `features_list` that correspond to specific point coordinates
inside each bounding box from `boxes`.
Args:
features_list (list[Tensor]): A list of feature map tensors to get features from.
feature_scales (list[float]): A list of scales for tensors in `features_list`.
boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all
together.
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
Returns:
point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled
from all features maps in feature_list for P sampled points for all R boxes in `boxes`.
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level
coordinates of P points.
"""
cat_boxes = Boxes.cat(boxes)
num_boxes = [len(b) for b in boxes]
point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)
point_features = []
for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):
point_features_per_image = []
for idx_feature, feature_map in enumerate(features_list):
h, w = feature_map.shape[-2:]
scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature]
point_coords_scaled = point_coords_wrt_image_per_image / scale
point_features_per_image.append(
point_sample(
feature_map[idx_img].unsqueeze(0),
point_coords_scaled.unsqueeze(0),
align_corners=False,
)
.squeeze(0)
.transpose(1, 0)
)
point_features.append(cat(point_features_per_image, dim=1))
return cat(point_features, dim=0), point_coords_wrt_image
def get_point_coords_wrt_image(boxes_coords, point_coords):
"""
Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates.
Args:
boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes.
coordinates.
point_coords (Tensor): A tensor of shape (R, P, 2) that contains
[0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
Returns:
point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains
image-normalized coordinates of P sampled points.
"""
with torch.no_grad():
point_coords_wrt_image = point_coords.clone()
point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (
boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
)
point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (
boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
)
point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]
point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]
return point_coords_wrt_image

View File

@@ -0,0 +1,154 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import ShapeSpec, cat
from detectron2.structures import BitMasks
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry
from .point_features import point_sample
POINT_HEAD_REGISTRY = Registry("POINT_HEAD")
POINT_HEAD_REGISTRY.__doc__ = """
Registry for point heads, which makes prediction for a given set of per-point features.
The registered object will be called with `obj(cfg, input_shape)`.
"""
def roi_mask_point_loss(mask_logits, instances, points_coord):
"""
Compute the point-based loss for instance segmentation mask predictions.
Args:
mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or
class-agnostic, where R is the total number of predicted masks in all images, C is the
number of foreground classes, and P is the number of points sampled for each mask.
The values are logits.
instances (list[Instances]): A list of N Instances, where N is the number of images
in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, i_th
elememt of the list contains R_i objects and R_1 + ... + R_N is equal to R.
The ground-truth labels (class, box, mask, ...) associated with each instance are stored
in fields.
points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of
predicted masks and P is the number of points for each mask. The coordinates are in
the image pixel coordinate space, i.e. [0, H] x [0, W].
Returns:
point_loss (Tensor): A scalar tensor containing the loss.
"""
assert len(instances) == 0 or isinstance(
instances[0].gt_masks, BitMasks
), "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'."
with torch.no_grad():
cls_agnostic_mask = mask_logits.size(1) == 1
total_num_masks = mask_logits.size(0)
gt_classes = []
gt_mask_logits = []
idx = 0
for instances_per_image in instances:
if not cls_agnostic_mask:
gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
gt_classes.append(gt_classes_per_image)
gt_bit_masks = instances_per_image.gt_masks.tensor
h, w = instances_per_image.gt_masks.image_size
scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)
points_coord_grid_sample_format = (
points_coord[idx : idx + len(instances_per_image)] / scale
)
idx += len(instances_per_image)
gt_mask_logits.append(
point_sample(
gt_bit_masks.to(torch.float32).unsqueeze(1),
points_coord_grid_sample_format,
align_corners=False,
).squeeze(1)
)
gt_mask_logits = cat(gt_mask_logits)
# torch.mean (in binary_cross_entropy_with_logits) doesn't
# accept empty tensors, so handle it separately
if gt_mask_logits.numel() == 0:
return mask_logits.sum() * 0
if cls_agnostic_mask:
mask_logits = mask_logits[:, 0]
else:
indices = torch.arange(total_num_masks)
gt_classes = cat(gt_classes, dim=0)
mask_logits = mask_logits[indices, gt_classes]
# Log the training accuracy (using gt classes and 0.0 threshold for the logits)
mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel()
get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy)
point_loss = F.binary_cross_entropy_with_logits(
mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean"
)
return point_loss
@POINT_HEAD_REGISTRY.register()
class StandardPointHead(nn.Module):
"""
A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head
takes both fine-grained and coarse prediction features as its input.
"""
def __init__(self, cfg, input_shape: ShapeSpec):
"""
The following attributes are parsed from config:
fc_dim: the output dimension of each FC layers
num_fc: the number of FC layers
coarse_pred_each_layer: if True, coarse prediction features are concatenated to each
layer's input
"""
super(StandardPointHead, self).__init__()
# fmt: off
num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES
fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM
num_fc = cfg.MODEL.POINT_HEAD.NUM_FC
cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK
self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER
input_channels = input_shape.channels
# fmt: on
fc_dim_in = input_channels + num_classes
self.fc_layers = []
for k in range(num_fc):
fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True)
self.add_module("fc{}".format(k + 1), fc)
self.fc_layers.append(fc)
fc_dim_in = fc_dim
fc_dim_in += num_classes if self.coarse_pred_each_layer else 0
num_mask_classes = 1 if cls_agnostic_mask else num_classes
self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0)
for layer in self.fc_layers:
weight_init.c2_msra_fill(layer)
# use normal distribution initialization for mask prediction layer
nn.init.normal_(self.predictor.weight, std=0.001)
if self.predictor.bias is not None:
nn.init.constant_(self.predictor.bias, 0)
def forward(self, fine_grained_features, coarse_features):
x = torch.cat((fine_grained_features, coarse_features), dim=1)
for layer in self.fc_layers:
x = F.relu(layer(x))
if self.coarse_pred_each_layer:
x = cat((x, coarse_features), dim=1)
return self.predictor(x)
def build_point_head(cfg, input_channels):
"""
Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`.
"""
head_name = cfg.MODEL.POINT_HEAD.NAME
return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels)

View File

@@ -0,0 +1,227 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
from detectron2.layers import ShapeSpec, cat, interpolate
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.mask_head import (
build_mask_head,
mask_rcnn_inference,
mask_rcnn_loss,
)
from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals
from .point_features import (
generate_regular_grid_point_coords,
get_uncertain_point_coords_on_grid,
get_uncertain_point_coords_with_randomness,
point_sample,
point_sample_fine_grained_features,
)
from .point_head import build_point_head, roi_mask_point_loss
def calculate_uncertainty(logits, classes):
"""
We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
foreground class in `classes`.
Args:
logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
class-agnostic, where R is the total number of predicted masks in all images and C is
the number of foreground classes. The values are logits.
classes (list): A list of length R that contains either predicted of ground truth class
for eash predicted mask.
Returns:
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
the most uncertain locations having the highest uncertainty score.
"""
if logits.shape[1] == 1:
gt_class_logits = logits.clone()
else:
gt_class_logits = logits[
torch.arange(logits.shape[0], device=logits.device), classes
].unsqueeze(1)
return -(torch.abs(gt_class_logits))
@ROI_HEADS_REGISTRY.register()
class PointRendROIHeads(StandardROIHeads):
"""
The RoI heads class for PointRend instance segmentation models.
In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact.
To avoid namespace conflict with other heads we use names starting from `mask_` for all
variables that correspond to the mask head in the class's namespace.
"""
def __init__(self, cfg, input_shape):
# TODO use explicit args style
super().__init__(cfg, input_shape)
self._init_mask_head(cfg, input_shape)
def _init_mask_head(self, cfg, input_shape):
# fmt: off
self.mask_on = cfg.MODEL.MASK_ON
if not self.mask_on:
return
self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
# fmt: on
in_channels = np.sum([input_shape[f].channels for f in self.mask_coarse_in_features])
self.mask_coarse_head = build_mask_head(
cfg,
ShapeSpec(
channels=in_channels,
width=self.mask_coarse_side_size,
height=self.mask_coarse_side_size,
),
)
self._init_point_head(cfg, input_shape)
def _init_point_head(self, cfg, input_shape):
# fmt: off
self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
if not self.mask_point_on:
return
assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
# next two parameters are use in the adaptive subdivions inference procedure
self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
# fmt: on
in_channels = np.sum([input_shape[f].channels for f in self.mask_point_in_features])
self.mask_point_head = build_point_head(
cfg, ShapeSpec(channels=in_channels, width=1, height=1)
)
def _forward_mask(self, features, instances):
"""
Forward logic of the mask prediction branch.
Args:
features (dict[str, Tensor]): #level input features for mask prediction
instances (list[Instances]): the per-image instances to train/predict masks.
In training, they can be the proposals.
In inference, they can be the predicted boxes.
Returns:
In training, a dict of losses.
In inference, update `instances` with new fields "pred_masks" and return it.
"""
if not self.mask_on:
return {} if self.training else instances
if self.training:
proposals, _ = select_foreground_proposals(instances, self.num_classes)
proposal_boxes = [x.proposal_boxes for x in proposals]
mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes)
losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)}
losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals))
return losses
else:
pred_boxes = [x.pred_boxes for x in instances]
mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes)
mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances)
mask_rcnn_inference(mask_logits, instances)
return instances
def _forward_mask_coarse(self, features, boxes):
"""
Forward logic of the coarse mask head.
"""
point_coords = generate_regular_grid_point_coords(
np.sum(len(x) for x in boxes), self.mask_coarse_side_size, boxes[0].device
)
mask_coarse_features_list = [features[k] for k in self.mask_coarse_in_features]
features_scales = [self._feature_scales[k] for k in self.mask_coarse_in_features]
# For regular grids of points, this function is equivalent to `len(features_list)' calls
# of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results.
mask_features, _ = point_sample_fine_grained_features(
mask_coarse_features_list, features_scales, boxes, point_coords
)
return self.mask_coarse_head(mask_features)
def _forward_mask_point(self, features, mask_coarse_logits, instances):
"""
Forward logic of the mask point head.
"""
if not self.mask_point_on:
return {} if self.training else mask_coarse_logits
mask_features_list = [features[k] for k in self.mask_point_in_features]
features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]
if self.training:
proposal_boxes = [x.proposal_boxes for x in instances]
gt_classes = cat([x.gt_classes for x in instances])
with torch.no_grad():
point_coords = get_uncertain_point_coords_with_randomness(
mask_coarse_logits,
lambda logits: calculate_uncertainty(logits, gt_classes),
self.mask_point_train_num_points,
self.mask_point_oversample_ratio,
self.mask_point_importance_sample_ratio,
)
fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features(
mask_features_list, features_scales, proposal_boxes, point_coords
)
coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False)
point_logits = self.mask_point_head(fine_grained_features, coarse_features)
return {
"loss_mask_point": roi_mask_point_loss(
point_logits, instances, point_coords_wrt_image
)
}
else:
pred_boxes = [x.pred_boxes for x in instances]
pred_classes = cat([x.pred_classes for x in instances])
# The subdivision code will fail with the empty list of boxes
if len(pred_classes) == 0:
return mask_coarse_logits
mask_logits = mask_coarse_logits.clone()
for subdivions_step in range(self.mask_point_subdivision_steps):
mask_logits = interpolate(
mask_logits, scale_factor=2, mode="bilinear", align_corners=False
)
# If `mask_point_subdivision_num_points` is larger or equal to the
# resolution of the next step, then we can skip this step
H, W = mask_logits.shape[-2:]
if (
self.mask_point_subdivision_num_points >= 4 * H * W
and subdivions_step < self.mask_point_subdivision_steps - 1
):
continue
uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
point_indices, point_coords = get_uncertain_point_coords_on_grid(
uncertainty_map, self.mask_point_subdivision_num_points
)
fine_grained_features, _ = point_sample_fine_grained_features(
mask_features_list, features_scales, pred_boxes, point_coords
)
coarse_features = point_sample(
mask_coarse_logits, point_coords, align_corners=False
)
point_logits = self.mask_point_head(fine_grained_features, coarse_features)
# put mask point predictions to the right places on the upsampled grid.
R, C, H, W = mask_logits.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
mask_logits = (
mask_logits.reshape(R, C, H * W)
.scatter_(2, point_indices, point_logits)
.view(R, C, H, W)
)
return mask_logits

View File

@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from typing import Dict
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import ShapeSpec, cat
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from .point_features import (
get_uncertain_point_coords_on_grid,
get_uncertain_point_coords_with_randomness,
point_sample,
)
from .point_head import build_point_head
def calculate_uncertainty(sem_seg_logits):
"""
For each location of the prediction `sem_seg_logits` we estimate uncerainty as the
difference between top first and top second predicted logits.
Args:
mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
C is the number of foreground classes. The values are logits.
Returns:
scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
the most uncertain locations having the highest uncertainty score.
"""
top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
"""
A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
and a point head set in `MODEL.POINT_HEAD.NAME`.
"""
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
super().__init__()
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
)(cfg, input_shape)
self._init_point_head(cfg, input_shape)
def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
# fmt: off
assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
feature_channels = {k: v.channels for k, v in input_shape.items()}
self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
# fmt: on
in_channels = np.sum([feature_channels[f] for f in self.in_features])
self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
def forward(self, features, targets=None):
coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)
if self.training:
losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)
with torch.no_grad():
point_coords = get_uncertain_point_coords_with_randomness(
coarse_sem_seg_logits,
calculate_uncertainty,
self.train_num_points,
self.oversample_ratio,
self.importance_sample_ratio,
)
coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
fine_grained_features = cat(
[
point_sample(features[in_feature], point_coords, align_corners=False)
for in_feature in self.in_features
]
)
point_logits = self.point_head(fine_grained_features, coarse_features)
point_targets = (
point_sample(
targets.unsqueeze(1).to(torch.float),
point_coords,
mode="nearest",
align_corners=False,
)
.squeeze(1)
.to(torch.long)
)
losses["loss_sem_seg_point"] = F.cross_entropy(
point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
)
return None, losses
else:
sem_seg_logits = coarse_sem_seg_logits.clone()
for _ in range(self.subdivision_steps):
sem_seg_logits = F.interpolate(
sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
)
uncertainty_map = calculate_uncertainty(sem_seg_logits)
point_indices, point_coords = get_uncertain_point_coords_on_grid(
uncertainty_map, self.subdivision_num_points
)
fine_grained_features = cat(
[
point_sample(features[in_feature], point_coords, align_corners=False)
for in_feature in self.in_features
]
)
coarse_features = point_sample(
coarse_sem_seg_logits, point_coords, align_corners=False
)
point_logits = self.point_head(fine_grained_features, coarse_features)
# put sem seg point predictions to the right places on the upsampled grid.
N, C, H, W = sem_seg_logits.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
sem_seg_logits = (
sem_seg_logits.reshape(N, C, H * W)
.scatter_(2, point_indices, point_logits)
.view(N, C, H, W)
)
return sem_seg_logits, {}

View File

@@ -0,0 +1,2 @@
python finetune_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_parsing.yaml --num-gpus 1
#python finetune_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_parsing.yaml --num-gpus 1

View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
PointRend Training Script.
This script is a simplified version of the training script in detectron2/tools.
"""
import os
import torch
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
DatasetEvaluators,
LVISEvaluator,
SemSegEvaluator,
verify_results,
)
from point_rend import SemSegDatasetMapper, add_pointrend_config
class Trainer(DefaultTrainer):
"""
We use the "DefaultTrainer" which contains a number pre-defined logic for
standard training workflow. They may not work for you, especially if you
are working on a new research project. In that case you can use the cleaner
"SimpleTrainer", or write your own training loop.
"""
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
"""
Create evaluator(s) for a given dataset.
This uses the special metadata "evaluator_type" associated with each builtin dataset.
For your own dataset, you can simply create an evaluator manually in your
script and do not have to worry about the hacky if-else logic here.
"""
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
evaluator_list = []
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
if evaluator_type == "lvis":
return LVISEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "coco":
return COCOEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "sem_seg":
return SemSegEvaluator(
dataset_name,
distributed=True,
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
output_dir=output_folder,
)
if evaluator_type == "cityscapes_instance":
assert (
torch.cuda.device_count() >= comm.get_rank()
), "CityscapesEvaluator currently do not work with multiple machines."
return CityscapesInstanceEvaluator(dataset_name)
if evaluator_type == "cityscapes_sem_seg":
assert (
torch.cuda.device_count() >= comm.get_rank()
), "CityscapesEvaluator currently do not work with multiple machines."
return CityscapesSemSegEvaluator(dataset_name)
if len(evaluator_list) == 0:
raise NotImplementedError(
"no Evaluator for the dataset {} with the type {}".format(
dataset_name, evaluator_type
)
)
if len(evaluator_list) == 1:
return evaluator_list[0]
return DatasetEvaluators(evaluator_list)
@classmethod
def build_train_loader(cls, cfg):
if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
mapper = SemSegDatasetMapper(cfg, True)
else:
mapper = None
return build_detection_train_loader(cfg, mapper=mapper)
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
if args.eval_only:
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
res = Trainer.test(cfg, model)
if comm.is_main_process():
verify_results(cfg, res)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)