Add at new repo again
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
|
||||
# TridentNet in Detectron2
|
||||
**Scale-Aware Trident Networks for Object Detection**
|
||||
|
||||
Yanghao Li\*, Yuntao Chen\*, Naiyan Wang, Zhaoxiang Zhang
|
||||
|
||||
[[`TridentNet`](https://github.com/TuSimple/simpledet/tree/master/models/tridentnet)] [[`arXiv`](https://arxiv.org/abs/1901.01892)] [[`BibTeX`](#CitingTridentNet)]
|
||||
|
||||
<div align="center">
|
||||
<img src="https://drive.google.com/uc?export=view&id=10THEPdIPmf3ooMyNzrfZbpWihEBvixwt" width="700px" />
|
||||
</div>
|
||||
|
||||
In this repository, we implement TridentNet-Fast in Detectron2.
|
||||
Trident Network (TridentNet) aims to generate scale-specific feature maps with a uniform representational power. We construct a parallel multi-branch architecture in which each branch shares the same transformation parameters but with different receptive fields. TridentNet-Fast is a fast approximation version of TridentNet that could achieve significant improvements without any additional parameters and computational cost.
|
||||
|
||||
## Training
|
||||
|
||||
To train a model, run
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TridentNet/train_net.py --config-file <config.yaml>
|
||||
```
|
||||
|
||||
For example, to launch end-to-end TridentNet training with ResNet-50 backbone on 8 GPUs,
|
||||
one should execute:
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TridentNet/train_net.py --config-file configs/tridentnet_fast_R_50_C4_1x.yaml --num-gpus 8
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Model evaluation can be done similarly:
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TridentNet/train_net.py --config-file configs/tridentnet_fast_R_50_C4_1x.yaml --eval-only MODEL.WEIGHTS model.pth
|
||||
```
|
||||
|
||||
## Results on MS-COCO in Detectron2
|
||||
|
||||
|Model|Backbone|Head|lr sched|AP|AP50|AP75|APs|APm|APl|download|
|
||||
|-----|--------|----|--------|--|----|----|---|---|---|--------|
|
||||
|Faster|R50-C4|C5-512ROI|1X|35.7|56.1|38.0|19.2|40.9|48.7|<a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_C4_1x/137257644/model_final_721ade.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_C4_1x/137257644/metrics.json">metrics</a>|
|
||||
|TridentFast|R50-C4|C5-128ROI|1X|38.0|58.1|40.8|19.5|42.2|54.6|<a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_50_C4_1x/148572687/model_final_756cda.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_50_C4_1x/148572687/metrics.json">metrics</a>|
|
||||
|Faster|R50-C4|C5-512ROI|3X|38.4|58.7|41.3|20.7|42.7|53.1|<a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_C4_3x/137849393/model_final_f97cb7.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_C4_3x/137849393/metrics.json">metrics</a>|
|
||||
|TridentFast|R50-C4|C5-128ROI|3X|40.6|60.8|43.6|23.4|44.7|57.1|<a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_50_C4_3x/148572287/model_final_e1027c.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_50_C4_3x/148572287/metrics.json">metrics</a>|
|
||||
|Faster|R101-C4|C5-512ROI|3X|41.1|61.4|44.0|22.2|45.5|55.9|<a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_101_C4_3x/138204752/model_final_298dad.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_101_C4_3x/138204752/metrics.json">metrics</a>|
|
||||
|TridentFast|R101-C4|C5-128ROI|3X|43.6|63.4|47.0|24.3|47.8|60.0|<a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_101_C4_3x/148572198/model_final_164568.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/TridentNet/tridentnet_fast_R_101_C4_3x/148572198/metrics.json">metrics</a>|
|
||||
|
||||
|
||||
## <a name="CitingTridentNet"></a>Citing TridentNet
|
||||
|
||||
If you use TridentNet, please use the following BibTeX entry.
|
||||
|
||||
```
|
||||
@InProceedings{li2019scale,
|
||||
title={Scale-Aware Trident Networks for Object Detection},
|
||||
author={Li, Yanghao and Chen, Yuntao and Wang, Naiyan and Zhang, Zhaoxiang},
|
||||
journal={The International Conference on Computer Vision (ICCV)},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
@@ -0,0 +1,29 @@
|
||||
MODEL:
|
||||
META_ARCHITECTURE: "GeneralizedRCNN"
|
||||
BACKBONE:
|
||||
NAME: "build_trident_resnet_backbone"
|
||||
ROI_HEADS:
|
||||
NAME: "TridentRes5ROIHeads"
|
||||
POSITIVE_FRACTION: 0.5
|
||||
BATCH_SIZE_PER_IMAGE: 128
|
||||
PROPOSAL_APPEND_GT: False
|
||||
PROPOSAL_GENERATOR:
|
||||
NAME: "TridentRPN"
|
||||
RPN:
|
||||
POST_NMS_TOPK_TRAIN: 500
|
||||
TRIDENT:
|
||||
NUM_BRANCH: 3
|
||||
BRANCH_DILATIONS: [1, 2, 3]
|
||||
TEST_BRANCH_IDX: 1
|
||||
TRIDENT_STAGE: "res4"
|
||||
DATASETS:
|
||||
TRAIN: ("coco_2017_train",)
|
||||
TEST: ("coco_2017_val",)
|
||||
SOLVER:
|
||||
IMS_PER_BATCH: 16
|
||||
BASE_LR: 0.02
|
||||
STEPS: (60000, 80000)
|
||||
MAX_ITER: 90000
|
||||
INPUT:
|
||||
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
||||
VERSION: 2
|
@@ -0,0 +1,9 @@
|
||||
_BASE_: "Base-TridentNet-Fast-C4.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
|
||||
MASK_ON: False
|
||||
RESNETS:
|
||||
DEPTH: 101
|
||||
SOLVER:
|
||||
STEPS: (210000, 250000)
|
||||
MAX_ITER: 270000
|
@@ -0,0 +1,6 @@
|
||||
_BASE_: "Base-TridentNet-Fast-C4.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
MASK_ON: False
|
||||
RESNETS:
|
||||
DEPTH: 50
|
@@ -0,0 +1,9 @@
|
||||
_BASE_: "Base-TridentNet-Fast-C4.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
MASK_ON: False
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
SOLVER:
|
||||
STEPS: (210000, 250000)
|
||||
MAX_ITER: 270000
|
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
TridentNet Training Script.
|
||||
|
||||
This script is a simplified version of the training script in detectron2/tools.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from detectron2.checkpoint import DetectionCheckpointer
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
||||
from detectron2.evaluation import COCOEvaluator
|
||||
|
||||
from tridentnet import add_tridentnet_config
|
||||
|
||||
|
||||
class Trainer(DefaultTrainer):
|
||||
@classmethod
|
||||
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
|
||||
if output_folder is None:
|
||||
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
||||
return COCOEvaluator(dataset_name, cfg, True, output_folder)
|
||||
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
Create configs and perform basic setups.
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_tridentnet_config(cfg)
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
||||
|
||||
def main(args):
|
||||
cfg = setup(args)
|
||||
|
||||
if args.eval_only:
|
||||
model = Trainer.build_model(cfg)
|
||||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
res = Trainer.test(cfg, model)
|
||||
return res
|
||||
|
||||
trainer = Trainer(cfg)
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = default_argument_parser().parse_args()
|
||||
print("Command Line Args:", args)
|
||||
launch(
|
||||
main,
|
||||
args.num_gpus,
|
||||
num_machines=args.num_machines,
|
||||
machine_rank=args.machine_rank,
|
||||
dist_url=args.dist_url,
|
||||
args=(args,),
|
||||
)
|
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .config import add_tridentnet_config
|
||||
from .trident_backbone import (
|
||||
TridentBottleneckBlock,
|
||||
build_trident_resnet_backbone,
|
||||
make_trident_stage,
|
||||
)
|
||||
from .trident_rpn import TridentRPN
|
||||
from .trident_rcnn import TridentRes5ROIHeads, TridentStandardROIHeads
|
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from detectron2.config import CfgNode as CN
|
||||
|
||||
|
||||
def add_tridentnet_config(cfg):
|
||||
"""
|
||||
Add config for tridentnet.
|
||||
"""
|
||||
_C = cfg
|
||||
|
||||
_C.MODEL.TRIDENT = CN()
|
||||
|
||||
# Number of branches for TridentNet.
|
||||
_C.MODEL.TRIDENT.NUM_BRANCH = 3
|
||||
# Specify the dilations for each branch.
|
||||
_C.MODEL.TRIDENT.BRANCH_DILATIONS = [1, 2, 3]
|
||||
# Specify the stage for applying trident blocks. Default stage is Res4 according to the
|
||||
# TridentNet paper.
|
||||
_C.MODEL.TRIDENT.TRIDENT_STAGE = "res4"
|
||||
# Specify the test branch index TridentNet Fast inference:
|
||||
# - use -1 to aggregate results of all branches during inference.
|
||||
# - otherwise, only using specified branch for fast inference. Recommended setting is
|
||||
# to use the middle branch.
|
||||
_C.MODEL.TRIDENT.TEST_BRANCH_IDX = 1
|
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import fvcore.nn.weight_init as weight_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm
|
||||
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase, make_stage
|
||||
from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock
|
||||
|
||||
from .trident_conv import TridentConv
|
||||
|
||||
__all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"]
|
||||
|
||||
|
||||
class TridentBottleneckBlock(ResNetBlockBase):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
*,
|
||||
bottleneck_channels,
|
||||
stride=1,
|
||||
num_groups=1,
|
||||
norm="BN",
|
||||
stride_in_1x1=False,
|
||||
num_branch=3,
|
||||
dilations=(1, 2, 3),
|
||||
concat_output=False,
|
||||
test_branch_idx=-1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_branch (int): the number of branches in TridentNet.
|
||||
dilations (tuple): the dilations of multiple branches in TridentNet.
|
||||
concat_output (bool): if concatenate outputs of multiple branches in TridentNet.
|
||||
Use 'True' for the last trident block.
|
||||
"""
|
||||
super().__init__(in_channels, out_channels, stride)
|
||||
|
||||
assert num_branch == len(dilations)
|
||||
|
||||
self.num_branch = num_branch
|
||||
self.concat_output = concat_output
|
||||
self.test_branch_idx = test_branch_idx
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
norm=get_norm(norm, out_channels),
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
||||
|
||||
self.conv1 = Conv2d(
|
||||
in_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=1,
|
||||
stride=stride_1x1,
|
||||
bias=False,
|
||||
norm=get_norm(norm, bottleneck_channels),
|
||||
)
|
||||
|
||||
self.conv2 = TridentConv(
|
||||
bottleneck_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=3,
|
||||
stride=stride_3x3,
|
||||
paddings=dilations,
|
||||
bias=False,
|
||||
groups=num_groups,
|
||||
dilations=dilations,
|
||||
num_branch=num_branch,
|
||||
test_branch_idx=test_branch_idx,
|
||||
norm=get_norm(norm, bottleneck_channels),
|
||||
)
|
||||
|
||||
self.conv3 = Conv2d(
|
||||
bottleneck_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm=get_norm(norm, out_channels),
|
||||
)
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
||||
if layer is not None: # shortcut can be None
|
||||
weight_init.c2_msra_fill(layer)
|
||||
|
||||
def forward(self, x):
|
||||
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
|
||||
if not isinstance(x, list):
|
||||
x = [x] * num_branch
|
||||
out = [self.conv1(b) for b in x]
|
||||
out = [F.relu_(b) for b in out]
|
||||
|
||||
out = self.conv2(out)
|
||||
out = [F.relu_(b) for b in out]
|
||||
|
||||
out = [self.conv3(b) for b in out]
|
||||
|
||||
if self.shortcut is not None:
|
||||
shortcut = [self.shortcut(b) for b in x]
|
||||
else:
|
||||
shortcut = x
|
||||
|
||||
out = [out_b + shortcut_b for out_b, shortcut_b in zip(out, shortcut)]
|
||||
out = [F.relu_(b) for b in out]
|
||||
if self.concat_output:
|
||||
out = torch.cat(out)
|
||||
return out
|
||||
|
||||
|
||||
def make_trident_stage(block_class, num_blocks, first_stride, **kwargs):
|
||||
"""
|
||||
Create a resnet stage by creating many blocks for TridentNet.
|
||||
"""
|
||||
blocks = []
|
||||
for i in range(num_blocks - 1):
|
||||
blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs))
|
||||
kwargs["in_channels"] = kwargs["out_channels"]
|
||||
blocks.append(block_class(stride=1, concat_output=True, **kwargs))
|
||||
return blocks
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def build_trident_resnet_backbone(cfg, input_shape):
|
||||
"""
|
||||
Create a ResNet instance from config for TridentNet.
|
||||
|
||||
Returns:
|
||||
ResNet: a :class:`ResNet` instance.
|
||||
"""
|
||||
# need registration of new blocks/stems?
|
||||
norm = cfg.MODEL.RESNETS.NORM
|
||||
stem = BasicStem(
|
||||
in_channels=input_shape.channels,
|
||||
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
|
||||
norm=norm,
|
||||
)
|
||||
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
|
||||
|
||||
if freeze_at >= 1:
|
||||
for p in stem.parameters():
|
||||
p.requires_grad = False
|
||||
stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)
|
||||
|
||||
# fmt: off
|
||||
out_features = cfg.MODEL.RESNETS.OUT_FEATURES
|
||||
depth = cfg.MODEL.RESNETS.DEPTH
|
||||
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
||||
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
||||
bottleneck_channels = num_groups * width_per_group
|
||||
in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
|
||||
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
|
||||
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
||||
res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
||||
deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
|
||||
deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
|
||||
deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
|
||||
num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
branch_dilations = cfg.MODEL.TRIDENT.BRANCH_DILATIONS
|
||||
trident_stage = cfg.MODEL.TRIDENT.TRIDENT_STAGE
|
||||
test_branch_idx = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX
|
||||
# fmt: on
|
||||
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
|
||||
|
||||
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
|
||||
|
||||
stages = []
|
||||
|
||||
res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
|
||||
out_stage_idx = [res_stage_idx[f] for f in out_features]
|
||||
trident_stage_idx = res_stage_idx[trident_stage]
|
||||
max_stage_idx = max(out_stage_idx)
|
||||
for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
|
||||
dilation = res5_dilation if stage_idx == 5 else 1
|
||||
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
|
||||
stage_kargs = {
|
||||
"num_blocks": num_blocks_per_stage[idx],
|
||||
"first_stride": first_stride,
|
||||
"in_channels": in_channels,
|
||||
"bottleneck_channels": bottleneck_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_groups": num_groups,
|
||||
"norm": norm,
|
||||
"stride_in_1x1": stride_in_1x1,
|
||||
"dilation": dilation,
|
||||
}
|
||||
if stage_idx == trident_stage_idx:
|
||||
assert not deform_on_per_stage[
|
||||
idx
|
||||
], "Not support deformable conv in Trident blocks yet."
|
||||
stage_kargs["block_class"] = TridentBottleneckBlock
|
||||
stage_kargs["num_branch"] = num_branch
|
||||
stage_kargs["dilations"] = branch_dilations
|
||||
stage_kargs["test_branch_idx"] = test_branch_idx
|
||||
stage_kargs.pop("dilation")
|
||||
elif deform_on_per_stage[idx]:
|
||||
stage_kargs["block_class"] = DeformBottleneckBlock
|
||||
stage_kargs["deform_modulated"] = deform_modulated
|
||||
stage_kargs["deform_num_groups"] = deform_num_groups
|
||||
else:
|
||||
stage_kargs["block_class"] = BottleneckBlock
|
||||
blocks = (
|
||||
make_trident_stage(**stage_kargs)
|
||||
if stage_idx == trident_stage_idx
|
||||
else make_stage(**stage_kargs)
|
||||
)
|
||||
in_channels = out_channels
|
||||
out_channels *= 2
|
||||
bottleneck_channels *= 2
|
||||
|
||||
if freeze_at >= stage_idx:
|
||||
for block in blocks:
|
||||
block.freeze()
|
||||
stages.append(blocks)
|
||||
return ResNet(stem, stages, out_features=out_features)
|
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from detectron2.layers.wrappers import _NewEmptyTensorOp
|
||||
|
||||
|
||||
class TridentConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
paddings=0,
|
||||
dilations=1,
|
||||
groups=1,
|
||||
num_branch=1,
|
||||
test_branch_idx=-1,
|
||||
bias=False,
|
||||
norm=None,
|
||||
activation=None,
|
||||
):
|
||||
super(TridentConv, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
self.num_branch = num_branch
|
||||
self.stride = _pair(stride)
|
||||
self.groups = groups
|
||||
self.with_bias = bias
|
||||
if isinstance(paddings, int):
|
||||
paddings = [paddings] * self.num_branch
|
||||
if isinstance(dilations, int):
|
||||
dilations = [dilations] * self.num_branch
|
||||
self.paddings = [_pair(padding) for padding in paddings]
|
||||
self.dilations = [_pair(dilation) for dilation in dilations]
|
||||
self.test_branch_idx = test_branch_idx
|
||||
self.norm = norm
|
||||
self.activation = activation
|
||||
|
||||
assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
|
||||
)
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0)
|
||||
|
||||
def forward(self, inputs):
|
||||
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
|
||||
assert len(inputs) == num_branch
|
||||
|
||||
if inputs[0].numel() == 0:
|
||||
output_shape = [
|
||||
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
|
||||
for i, p, di, k, s in zip(
|
||||
inputs[0].shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
|
||||
)
|
||||
]
|
||||
output_shape = [input[0].shape[0], self.weight.shape[0]] + output_shape
|
||||
return [_NewEmptyTensorOp.apply(input, output_shape) for input in inputs]
|
||||
|
||||
if self.training or self.test_branch_idx == -1:
|
||||
outputs = [
|
||||
F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups)
|
||||
for input, dilation, padding in zip(inputs, self.dilations, self.paddings)
|
||||
]
|
||||
else:
|
||||
outputs = [
|
||||
F.conv2d(
|
||||
inputs[0],
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.paddings[self.test_branch_idx],
|
||||
self.dilations[self.test_branch_idx],
|
||||
self.groups,
|
||||
)
|
||||
]
|
||||
|
||||
if self.norm is not None:
|
||||
outputs = [self.norm(x) for x in outputs]
|
||||
if self.activation is not None:
|
||||
outputs = [self.activation(x) for x in outputs]
|
||||
return outputs
|
||||
|
||||
def extra_repr(self):
|
||||
tmpstr = "in_channels=" + str(self.in_channels)
|
||||
tmpstr += ", out_channels=" + str(self.out_channels)
|
||||
tmpstr += ", kernel_size=" + str(self.kernel_size)
|
||||
tmpstr += ", num_branch=" + str(self.num_branch)
|
||||
tmpstr += ", test_branch_idx=" + str(self.test_branch_idx)
|
||||
tmpstr += ", stride=" + str(self.stride)
|
||||
tmpstr += ", paddings=" + str(self.paddings)
|
||||
tmpstr += ", dilations=" + str(self.dilations)
|
||||
tmpstr += ", groups=" + str(self.groups)
|
||||
tmpstr += ", bias=" + str(self.with_bias)
|
||||
return tmpstr
|
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from detectron2.layers import batched_nms
|
||||
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
|
||||
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads
|
||||
from detectron2.structures import Instances
|
||||
|
||||
|
||||
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Merge detection results from different branches of TridentNet.
|
||||
Return detection results by applying non-maximum suppression (NMS) on bounding boxes
|
||||
and keep the unsuppressed boxes and other instances (e.g mask) if any.
|
||||
|
||||
Args:
|
||||
instances (list[Instances]): A list of N * num_branch instances that store detection
|
||||
results. Contain N images and each image has num_branch instances.
|
||||
num_branch (int): Number of branches used for merging detection results for each image.
|
||||
nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
|
||||
topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
|
||||
all detections.
|
||||
|
||||
Returns:
|
||||
results: (list[Instances]): A list of N instances, one for each image in the batch,
|
||||
that stores the topk most confidence detections after merging results from multiple
|
||||
branches.
|
||||
"""
|
||||
if num_branch == 1:
|
||||
return instances
|
||||
|
||||
batch_size = len(instances) // num_branch
|
||||
results = []
|
||||
for i in range(batch_size):
|
||||
instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])
|
||||
|
||||
# Apply per-class NMS
|
||||
keep = batched_nms(
|
||||
instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
|
||||
)
|
||||
keep = keep[:topk_per_image]
|
||||
result = instance[keep]
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class TridentRes5ROIHeads(Res5ROIHeads):
|
||||
"""
|
||||
The TridentNet ROIHeads in a typical "C4" R-CNN model.
|
||||
See :class:`Res5ROIHeads`.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super().__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`Res5ROIHeads.forward`.
|
||||
"""
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
all_targets = targets * num_branch if targets is not None else None
|
||||
pred_instances, losses = super().forward(images, features, proposals, all_targets)
|
||||
del images, all_targets, targets
|
||||
|
||||
if self.training:
|
||||
return pred_instances, losses
|
||||
else:
|
||||
pred_instances = merge_branch_instances(
|
||||
pred_instances,
|
||||
num_branch,
|
||||
self.box_predictor.test_nms_thresh,
|
||||
self.box_predictor.test_topk_per_image,
|
||||
)
|
||||
|
||||
return pred_instances, {}
|
||||
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class TridentStandardROIHeads(StandardROIHeads):
|
||||
"""
|
||||
The `StandardROIHeads` for TridentNet.
|
||||
See :class:`StandardROIHeads`.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super(TridentStandardROIHeads, self).__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`Res5ROIHeads.forward`.
|
||||
"""
|
||||
# Use 1 branch if using trident_fast during inference.
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
# Duplicate targets for all branches in TridentNet.
|
||||
all_targets = targets * num_branch if targets is not None else None
|
||||
pred_instances, losses = super().forward(images, features, proposals, all_targets)
|
||||
del images, all_targets, targets
|
||||
|
||||
if self.training:
|
||||
return pred_instances, losses
|
||||
else:
|
||||
pred_instances = merge_branch_instances(
|
||||
pred_instances,
|
||||
num_branch,
|
||||
self.box_predictor.test_nms_thresh,
|
||||
self.box_predictor.test_topk_per_image,
|
||||
)
|
||||
|
||||
return pred_instances, {}
|
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
|
||||
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY
|
||||
from detectron2.modeling.proposal_generator.rpn import RPN
|
||||
from detectron2.structures import ImageList
|
||||
|
||||
|
||||
@PROPOSAL_GENERATOR_REGISTRY.register()
|
||||
class TridentRPN(RPN):
|
||||
"""
|
||||
Trident RPN subnetwork.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super(TridentRPN, self).__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, gt_instances=None):
|
||||
"""
|
||||
See :class:`RPN.forward`.
|
||||
"""
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
# Duplicate images and gt_instances for all branches in TridentNet.
|
||||
all_images = ImageList(
|
||||
torch.cat([images.tensor] * num_branch), images.image_sizes * num_branch
|
||||
)
|
||||
all_gt_instances = gt_instances * num_branch if gt_instances is not None else None
|
||||
|
||||
return super(TridentRPN, self).forward(all_images, features, all_gt_instances)
|
Reference in New Issue
Block a user