Add at new repo again
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .config import add_tridentnet_config
|
||||
from .trident_backbone import (
|
||||
TridentBottleneckBlock,
|
||||
build_trident_resnet_backbone,
|
||||
make_trident_stage,
|
||||
)
|
||||
from .trident_rpn import TridentRPN
|
||||
from .trident_rcnn import TridentRes5ROIHeads, TridentStandardROIHeads
|
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from detectron2.config import CfgNode as CN
|
||||
|
||||
|
||||
def add_tridentnet_config(cfg):
|
||||
"""
|
||||
Add config for tridentnet.
|
||||
"""
|
||||
_C = cfg
|
||||
|
||||
_C.MODEL.TRIDENT = CN()
|
||||
|
||||
# Number of branches for TridentNet.
|
||||
_C.MODEL.TRIDENT.NUM_BRANCH = 3
|
||||
# Specify the dilations for each branch.
|
||||
_C.MODEL.TRIDENT.BRANCH_DILATIONS = [1, 2, 3]
|
||||
# Specify the stage for applying trident blocks. Default stage is Res4 according to the
|
||||
# TridentNet paper.
|
||||
_C.MODEL.TRIDENT.TRIDENT_STAGE = "res4"
|
||||
# Specify the test branch index TridentNet Fast inference:
|
||||
# - use -1 to aggregate results of all branches during inference.
|
||||
# - otherwise, only using specified branch for fast inference. Recommended setting is
|
||||
# to use the middle branch.
|
||||
_C.MODEL.TRIDENT.TEST_BRANCH_IDX = 1
|
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import fvcore.nn.weight_init as weight_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm
|
||||
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase, make_stage
|
||||
from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock
|
||||
|
||||
from .trident_conv import TridentConv
|
||||
|
||||
__all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"]
|
||||
|
||||
|
||||
class TridentBottleneckBlock(ResNetBlockBase):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
*,
|
||||
bottleneck_channels,
|
||||
stride=1,
|
||||
num_groups=1,
|
||||
norm="BN",
|
||||
stride_in_1x1=False,
|
||||
num_branch=3,
|
||||
dilations=(1, 2, 3),
|
||||
concat_output=False,
|
||||
test_branch_idx=-1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_branch (int): the number of branches in TridentNet.
|
||||
dilations (tuple): the dilations of multiple branches in TridentNet.
|
||||
concat_output (bool): if concatenate outputs of multiple branches in TridentNet.
|
||||
Use 'True' for the last trident block.
|
||||
"""
|
||||
super().__init__(in_channels, out_channels, stride)
|
||||
|
||||
assert num_branch == len(dilations)
|
||||
|
||||
self.num_branch = num_branch
|
||||
self.concat_output = concat_output
|
||||
self.test_branch_idx = test_branch_idx
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
norm=get_norm(norm, out_channels),
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
||||
|
||||
self.conv1 = Conv2d(
|
||||
in_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=1,
|
||||
stride=stride_1x1,
|
||||
bias=False,
|
||||
norm=get_norm(norm, bottleneck_channels),
|
||||
)
|
||||
|
||||
self.conv2 = TridentConv(
|
||||
bottleneck_channels,
|
||||
bottleneck_channels,
|
||||
kernel_size=3,
|
||||
stride=stride_3x3,
|
||||
paddings=dilations,
|
||||
bias=False,
|
||||
groups=num_groups,
|
||||
dilations=dilations,
|
||||
num_branch=num_branch,
|
||||
test_branch_idx=test_branch_idx,
|
||||
norm=get_norm(norm, bottleneck_channels),
|
||||
)
|
||||
|
||||
self.conv3 = Conv2d(
|
||||
bottleneck_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm=get_norm(norm, out_channels),
|
||||
)
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
||||
if layer is not None: # shortcut can be None
|
||||
weight_init.c2_msra_fill(layer)
|
||||
|
||||
def forward(self, x):
|
||||
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
|
||||
if not isinstance(x, list):
|
||||
x = [x] * num_branch
|
||||
out = [self.conv1(b) for b in x]
|
||||
out = [F.relu_(b) for b in out]
|
||||
|
||||
out = self.conv2(out)
|
||||
out = [F.relu_(b) for b in out]
|
||||
|
||||
out = [self.conv3(b) for b in out]
|
||||
|
||||
if self.shortcut is not None:
|
||||
shortcut = [self.shortcut(b) for b in x]
|
||||
else:
|
||||
shortcut = x
|
||||
|
||||
out = [out_b + shortcut_b for out_b, shortcut_b in zip(out, shortcut)]
|
||||
out = [F.relu_(b) for b in out]
|
||||
if self.concat_output:
|
||||
out = torch.cat(out)
|
||||
return out
|
||||
|
||||
|
||||
def make_trident_stage(block_class, num_blocks, first_stride, **kwargs):
|
||||
"""
|
||||
Create a resnet stage by creating many blocks for TridentNet.
|
||||
"""
|
||||
blocks = []
|
||||
for i in range(num_blocks - 1):
|
||||
blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs))
|
||||
kwargs["in_channels"] = kwargs["out_channels"]
|
||||
blocks.append(block_class(stride=1, concat_output=True, **kwargs))
|
||||
return blocks
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def build_trident_resnet_backbone(cfg, input_shape):
|
||||
"""
|
||||
Create a ResNet instance from config for TridentNet.
|
||||
|
||||
Returns:
|
||||
ResNet: a :class:`ResNet` instance.
|
||||
"""
|
||||
# need registration of new blocks/stems?
|
||||
norm = cfg.MODEL.RESNETS.NORM
|
||||
stem = BasicStem(
|
||||
in_channels=input_shape.channels,
|
||||
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
|
||||
norm=norm,
|
||||
)
|
||||
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
|
||||
|
||||
if freeze_at >= 1:
|
||||
for p in stem.parameters():
|
||||
p.requires_grad = False
|
||||
stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)
|
||||
|
||||
# fmt: off
|
||||
out_features = cfg.MODEL.RESNETS.OUT_FEATURES
|
||||
depth = cfg.MODEL.RESNETS.DEPTH
|
||||
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
||||
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
||||
bottleneck_channels = num_groups * width_per_group
|
||||
in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
|
||||
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
|
||||
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
||||
res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
||||
deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
|
||||
deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
|
||||
deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
|
||||
num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
branch_dilations = cfg.MODEL.TRIDENT.BRANCH_DILATIONS
|
||||
trident_stage = cfg.MODEL.TRIDENT.TRIDENT_STAGE
|
||||
test_branch_idx = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX
|
||||
# fmt: on
|
||||
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
|
||||
|
||||
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
|
||||
|
||||
stages = []
|
||||
|
||||
res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
|
||||
out_stage_idx = [res_stage_idx[f] for f in out_features]
|
||||
trident_stage_idx = res_stage_idx[trident_stage]
|
||||
max_stage_idx = max(out_stage_idx)
|
||||
for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
|
||||
dilation = res5_dilation if stage_idx == 5 else 1
|
||||
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
|
||||
stage_kargs = {
|
||||
"num_blocks": num_blocks_per_stage[idx],
|
||||
"first_stride": first_stride,
|
||||
"in_channels": in_channels,
|
||||
"bottleneck_channels": bottleneck_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_groups": num_groups,
|
||||
"norm": norm,
|
||||
"stride_in_1x1": stride_in_1x1,
|
||||
"dilation": dilation,
|
||||
}
|
||||
if stage_idx == trident_stage_idx:
|
||||
assert not deform_on_per_stage[
|
||||
idx
|
||||
], "Not support deformable conv in Trident blocks yet."
|
||||
stage_kargs["block_class"] = TridentBottleneckBlock
|
||||
stage_kargs["num_branch"] = num_branch
|
||||
stage_kargs["dilations"] = branch_dilations
|
||||
stage_kargs["test_branch_idx"] = test_branch_idx
|
||||
stage_kargs.pop("dilation")
|
||||
elif deform_on_per_stage[idx]:
|
||||
stage_kargs["block_class"] = DeformBottleneckBlock
|
||||
stage_kargs["deform_modulated"] = deform_modulated
|
||||
stage_kargs["deform_num_groups"] = deform_num_groups
|
||||
else:
|
||||
stage_kargs["block_class"] = BottleneckBlock
|
||||
blocks = (
|
||||
make_trident_stage(**stage_kargs)
|
||||
if stage_idx == trident_stage_idx
|
||||
else make_stage(**stage_kargs)
|
||||
)
|
||||
in_channels = out_channels
|
||||
out_channels *= 2
|
||||
bottleneck_channels *= 2
|
||||
|
||||
if freeze_at >= stage_idx:
|
||||
for block in blocks:
|
||||
block.freeze()
|
||||
stages.append(blocks)
|
||||
return ResNet(stem, stages, out_features=out_features)
|
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from detectron2.layers.wrappers import _NewEmptyTensorOp
|
||||
|
||||
|
||||
class TridentConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
paddings=0,
|
||||
dilations=1,
|
||||
groups=1,
|
||||
num_branch=1,
|
||||
test_branch_idx=-1,
|
||||
bias=False,
|
||||
norm=None,
|
||||
activation=None,
|
||||
):
|
||||
super(TridentConv, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
self.num_branch = num_branch
|
||||
self.stride = _pair(stride)
|
||||
self.groups = groups
|
||||
self.with_bias = bias
|
||||
if isinstance(paddings, int):
|
||||
paddings = [paddings] * self.num_branch
|
||||
if isinstance(dilations, int):
|
||||
dilations = [dilations] * self.num_branch
|
||||
self.paddings = [_pair(padding) for padding in paddings]
|
||||
self.dilations = [_pair(dilation) for dilation in dilations]
|
||||
self.test_branch_idx = test_branch_idx
|
||||
self.norm = norm
|
||||
self.activation = activation
|
||||
|
||||
assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
|
||||
)
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0)
|
||||
|
||||
def forward(self, inputs):
|
||||
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
|
||||
assert len(inputs) == num_branch
|
||||
|
||||
if inputs[0].numel() == 0:
|
||||
output_shape = [
|
||||
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
|
||||
for i, p, di, k, s in zip(
|
||||
inputs[0].shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
|
||||
)
|
||||
]
|
||||
output_shape = [input[0].shape[0], self.weight.shape[0]] + output_shape
|
||||
return [_NewEmptyTensorOp.apply(input, output_shape) for input in inputs]
|
||||
|
||||
if self.training or self.test_branch_idx == -1:
|
||||
outputs = [
|
||||
F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups)
|
||||
for input, dilation, padding in zip(inputs, self.dilations, self.paddings)
|
||||
]
|
||||
else:
|
||||
outputs = [
|
||||
F.conv2d(
|
||||
inputs[0],
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.paddings[self.test_branch_idx],
|
||||
self.dilations[self.test_branch_idx],
|
||||
self.groups,
|
||||
)
|
||||
]
|
||||
|
||||
if self.norm is not None:
|
||||
outputs = [self.norm(x) for x in outputs]
|
||||
if self.activation is not None:
|
||||
outputs = [self.activation(x) for x in outputs]
|
||||
return outputs
|
||||
|
||||
def extra_repr(self):
|
||||
tmpstr = "in_channels=" + str(self.in_channels)
|
||||
tmpstr += ", out_channels=" + str(self.out_channels)
|
||||
tmpstr += ", kernel_size=" + str(self.kernel_size)
|
||||
tmpstr += ", num_branch=" + str(self.num_branch)
|
||||
tmpstr += ", test_branch_idx=" + str(self.test_branch_idx)
|
||||
tmpstr += ", stride=" + str(self.stride)
|
||||
tmpstr += ", paddings=" + str(self.paddings)
|
||||
tmpstr += ", dilations=" + str(self.dilations)
|
||||
tmpstr += ", groups=" + str(self.groups)
|
||||
tmpstr += ", bias=" + str(self.with_bias)
|
||||
return tmpstr
|
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from detectron2.layers import batched_nms
|
||||
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
|
||||
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads
|
||||
from detectron2.structures import Instances
|
||||
|
||||
|
||||
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
|
||||
"""
|
||||
Merge detection results from different branches of TridentNet.
|
||||
Return detection results by applying non-maximum suppression (NMS) on bounding boxes
|
||||
and keep the unsuppressed boxes and other instances (e.g mask) if any.
|
||||
|
||||
Args:
|
||||
instances (list[Instances]): A list of N * num_branch instances that store detection
|
||||
results. Contain N images and each image has num_branch instances.
|
||||
num_branch (int): Number of branches used for merging detection results for each image.
|
||||
nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
|
||||
topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
|
||||
all detections.
|
||||
|
||||
Returns:
|
||||
results: (list[Instances]): A list of N instances, one for each image in the batch,
|
||||
that stores the topk most confidence detections after merging results from multiple
|
||||
branches.
|
||||
"""
|
||||
if num_branch == 1:
|
||||
return instances
|
||||
|
||||
batch_size = len(instances) // num_branch
|
||||
results = []
|
||||
for i in range(batch_size):
|
||||
instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])
|
||||
|
||||
# Apply per-class NMS
|
||||
keep = batched_nms(
|
||||
instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
|
||||
)
|
||||
keep = keep[:topk_per_image]
|
||||
result = instance[keep]
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class TridentRes5ROIHeads(Res5ROIHeads):
|
||||
"""
|
||||
The TridentNet ROIHeads in a typical "C4" R-CNN model.
|
||||
See :class:`Res5ROIHeads`.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super().__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`Res5ROIHeads.forward`.
|
||||
"""
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
all_targets = targets * num_branch if targets is not None else None
|
||||
pred_instances, losses = super().forward(images, features, proposals, all_targets)
|
||||
del images, all_targets, targets
|
||||
|
||||
if self.training:
|
||||
return pred_instances, losses
|
||||
else:
|
||||
pred_instances = merge_branch_instances(
|
||||
pred_instances,
|
||||
num_branch,
|
||||
self.box_predictor.test_nms_thresh,
|
||||
self.box_predictor.test_topk_per_image,
|
||||
)
|
||||
|
||||
return pred_instances, {}
|
||||
|
||||
|
||||
@ROI_HEADS_REGISTRY.register()
|
||||
class TridentStandardROIHeads(StandardROIHeads):
|
||||
"""
|
||||
The `StandardROIHeads` for TridentNet.
|
||||
See :class:`StandardROIHeads`.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super(TridentStandardROIHeads, self).__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, proposals, targets=None):
|
||||
"""
|
||||
See :class:`Res5ROIHeads.forward`.
|
||||
"""
|
||||
# Use 1 branch if using trident_fast during inference.
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
# Duplicate targets for all branches in TridentNet.
|
||||
all_targets = targets * num_branch if targets is not None else None
|
||||
pred_instances, losses = super().forward(images, features, proposals, all_targets)
|
||||
del images, all_targets, targets
|
||||
|
||||
if self.training:
|
||||
return pred_instances, losses
|
||||
else:
|
||||
pred_instances = merge_branch_instances(
|
||||
pred_instances,
|
||||
num_branch,
|
||||
self.box_predictor.test_nms_thresh,
|
||||
self.box_predictor.test_topk_per_image,
|
||||
)
|
||||
|
||||
return pred_instances, {}
|
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
|
||||
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY
|
||||
from detectron2.modeling.proposal_generator.rpn import RPN
|
||||
from detectron2.structures import ImageList
|
||||
|
||||
|
||||
@PROPOSAL_GENERATOR_REGISTRY.register()
|
||||
class TridentRPN(RPN):
|
||||
"""
|
||||
Trident RPN subnetwork.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, input_shape):
|
||||
super(TridentRPN, self).__init__(cfg, input_shape)
|
||||
|
||||
self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
|
||||
self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1
|
||||
|
||||
def forward(self, images, features, gt_instances=None):
|
||||
"""
|
||||
See :class:`RPN.forward`.
|
||||
"""
|
||||
num_branch = self.num_branch if self.training or not self.trident_fast else 1
|
||||
# Duplicate images and gt_instances for all branches in TridentNet.
|
||||
all_images = ImageList(
|
||||
torch.cat([images.tensor] * num_branch), images.image_sizes * num_branch
|
||||
)
|
||||
all_gt_instances = gt_instances * num_branch if gt_instances is not None else None
|
||||
|
||||
return super(TridentRPN, self).forward(all_images, features, all_gt_instances)
|
Reference in New Issue
Block a user