Add at new repo again

This commit is contained in:
2025-01-28 21:48:35 +00:00
commit 6e660ddb3c
564 changed files with 75575 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import add_tridentnet_config
from .trident_backbone import (
TridentBottleneckBlock,
build_trident_resnet_backbone,
make_trident_stage,
)
from .trident_rpn import TridentRPN
from .trident_rcnn import TridentRes5ROIHeads, TridentStandardROIHeads

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.config import CfgNode as CN
def add_tridentnet_config(cfg):
    """
    Extend ``cfg`` in place with the TridentNet options and their defaults.

    A fresh config node is attached as ``cfg.MODEL.TRIDENT`` and then filled
    with the keys below.

    Args:
        cfg: config object that already exposes a ``MODEL`` node.
    """
    cfg.MODEL.TRIDENT = CN()
    trident = cfg.MODEL.TRIDENT

    # Number of parallel branches in TridentNet.
    trident.NUM_BRANCH = 3

    # Dilation rate of each branch (one entry per branch).
    trident.BRANCH_DILATIONS = [1, 2, 3]

    # Name of the ResNet stage built from trident blocks. The default, "res4",
    # follows the TridentNet paper.
    trident.TRIDENT_STAGE = "res4"

    # Branch index used at test time (TridentNet Fast inference):
    #   - -1 aggregates the results of all branches during inference;
    #   - any other value runs only that branch for fast inference. The
    #     recommended setting is the middle branch.
    trident.TEST_BRANCH_IDX = 1

View File

@@ -0,0 +1,223 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase, make_stage
from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock
from .trident_conv import TridentConv
__all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"]
class TridentBottleneckBlock(ResNetBlockBase):
    """
    Bottleneck residual block whose 3x3 convolution is a :class:`TridentConv`.

    The 1x1 convolutions (and the optional projection shortcut) are applied to
    every branch in turn, while the 3x3 convolution processes the whole list of
    branch features at once.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        num_branch=3,
        dilations=(1, 2, 3),
        concat_output=False,
        test_branch_idx=-1,
    ):
        """
        Args:
            in_channels (int): channels of the block input.
            out_channels (int): channels of the block output.
            bottleneck_channels (int): channels of the inner 3x3 convolution.
            stride (int): stride of the block, placed on the first 1x1 or on the
                3x3 convolution depending on `stride_in_1x1`.
            num_groups (int): number of groups of the 3x3 convolution.
            norm (str): normalization spec forwarded to `get_norm`.
            stride_in_1x1 (bool): if True the stride goes on the first 1x1
                convolution, otherwise on the 3x3 convolution.
            num_branch (int): the number of branches in TridentNet.
            dilations (tuple): the dilations of multiple branches in TridentNet;
                must contain `num_branch` entries. They are also used as the
                per-branch paddings of the 3x3 convolution.
            concat_output (bool): if concatenate outputs of multiple branches in
                TridentNet. Use 'True' for the last trident block.
            test_branch_idx (int): branch used at inference time; -1 keeps all
                branches.
        """
        super().__init__(in_channels, out_channels, stride)

        assert num_branch == len(dilations)

        self.num_branch = num_branch
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx

        # A projection shortcut is only needed when the channel count changes.
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if in_channels != out_channels
            else None
        )

        # Decide which convolution carries the spatial stride.
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = TridentConv(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            paddings=dilations,
            bias=False,
            groups=num_groups,
            dilations=dilations,
            num_branch=num_branch,
            test_branch_idx=test_branch_idx,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if layer is None:  # no projection shortcut was created
                continue
            weight_init.c2_msra_fill(layer)

    def forward(self, x):
        """
        Run the block on one tensor or on a list of per-branch tensors.

        Args:
            x: a single tensor, which is replicated once per active branch, or
                a list holding one tensor per active branch.

        Returns:
            A list with one tensor per active branch, or a single tensor made by
            concatenating that list with `torch.cat` when `concat_output` is set.
        """
        all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if all_branches else 1
        if not isinstance(x, list):
            x = [x] * num_branch

        feats = [F.relu_(self.conv1(branch)) for branch in x]
        feats = [F.relu_(branch) for branch in self.conv2(feats)]
        feats = [self.conv3(branch) for branch in feats]

        if self.shortcut is None:
            residual = x
        else:
            residual = [self.shortcut(branch) for branch in x]

        merged = [F.relu_(main + skip) for main, skip in zip(feats, residual)]
        return torch.cat(merged) if self.concat_output else merged
def make_trident_stage(block_class, num_blocks, first_stride, **kwargs):
    """
    Create a resnet stage by creating many blocks for TridentNet.

    The first block of the stage uses `first_stride`; every later block uses
    stride 1 and takes `out_channels` as its `in_channels`. The last block is
    additionally built with ``concat_output=True``.

    Args:
        block_class: callable invoked with keyword arguments to build one block.
        num_blocks (int): number of blocks in the stage. At least one block is
            always created.
        first_stride (int): stride of the first block of the stage.
        **kwargs: extra keyword arguments forwarded to every block; must contain
            ``in_channels`` and ``out_channels`` when more than one block is built.

    Returns:
        list: the created blocks, in order.
    """
    blocks = []
    for i in range(num_blocks - 1):
        blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs))
        kwargs["in_channels"] = kwargs["out_channels"]

    # If no block was created above, the final block is also the first one and
    # must carry the stage stride; previously `first_stride` was dropped here.
    last_stride = 1 if blocks else first_stride
    blocks.append(block_class(stride=last_stride, concat_output=True, **kwargs))
    return blocks
@BACKBONE_REGISTRY.register()
def build_trident_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config for TridentNet.

    Stages res2 up to the deepest stage listed in
    ``cfg.MODEL.RESNETS.OUT_FEATURES`` are built. The stage named by
    ``cfg.MODEL.TRIDENT.TRIDENT_STAGE`` is made of trident blocks; the others
    use deformable or plain bottleneck blocks according to
    ``cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE``.

    Args:
        cfg: config providing the ``MODEL.RESNETS``, ``MODEL.BACKBONE`` and
            ``MODEL.TRIDENT`` options read below.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    resnets = cfg.MODEL.RESNETS
    norm = resnets.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    if freeze_at >= 1:
        for param in stem.parameters():
            param.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    # fmt: off
    out_features        = resnets.OUT_FEATURES
    depth               = resnets.DEPTH
    num_groups          = resnets.NUM_GROUPS
    width_per_group     = resnets.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels         = resnets.STEM_OUT_CHANNELS
    out_channels        = resnets.RES2_OUT_CHANNELS
    stride_in_1x1       = resnets.STRIDE_IN_1X1
    res5_dilation       = resnets.RES5_DILATION
    deform_on_per_stage = resnets.DEFORM_ON_PER_STAGE
    deform_modulated    = resnets.DEFORM_MODULATED
    deform_num_groups   = resnets.DEFORM_NUM_GROUPS
    trident             = cfg.MODEL.TRIDENT
    num_branch          = trident.NUM_BRANCH
    branch_dilations    = trident.BRANCH_DILATIONS
    trident_stage       = trident.TRIDENT_STAGE
    test_branch_idx     = trident.TEST_BRANCH_IDX
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    # Blocks per stage (res2..res5) for each supported depth.
    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [res_stage_idx[name] for name in out_features]
    trident_stage_idx = res_stage_idx[trident_stage]
    max_stage_idx = max(out_stage_idx)

    stages = []
    for stage_idx in range(2, max_stage_idx + 1):
        idx = stage_idx - 2  # position in the per-stage lists
        dilation = res5_dilation if stage_idx == 5 else 1
        # res2 never downsamples; res5 does not either when it is dilated.
        keeps_resolution = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_resolution else 2

        common_kwargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "first_stride": first_stride,
            "in_channels": in_channels,
            "bottleneck_channels": bottleneck_channels,
            "out_channels": out_channels,
            "num_groups": num_groups,
            "norm": norm,
            "stride_in_1x1": stride_in_1x1,
        }

        if stage_idx == trident_stage_idx:
            assert not deform_on_per_stage[
                idx
            ], "Not support deformable conv in Trident blocks yet."
            # Trident blocks take per-branch dilations instead of `dilation`.
            blocks = make_trident_stage(
                block_class=TridentBottleneckBlock,
                num_branch=num_branch,
                dilations=branch_dilations,
                test_branch_idx=test_branch_idx,
                **common_kwargs,
            )
        elif deform_on_per_stage[idx]:
            blocks = make_stage(
                block_class=DeformBottleneckBlock,
                dilation=dilation,
                deform_modulated=deform_modulated,
                deform_num_groups=deform_num_groups,
                **common_kwargs,
            )
        else:
            blocks = make_stage(
                block_class=BottleneckBlock,
                dilation=dilation,
                **common_kwargs,
            )

        # Channel widths double from one stage to the next.
        in_channels, out_channels, bottleneck_channels = (
            out_channels,
            out_channels * 2,
            bottleneck_channels * 2,
        )

        if stage_idx <= freeze_at:
            for block in blocks:
                block.freeze()
        stages.append(blocks)

    return ResNet(stem, stages, out_features=out_features)

View File

@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from detectron2.layers.wrappers import _NewEmptyTensorOp
class TridentConv(nn.Module):
    """
    2D convolution evaluated once per branch with a single shared weight.

    Every branch uses the same ``weight``/``bias`` but its own padding and
    dilation. In training mode, or when ``test_branch_idx == -1``, all branches
    are evaluated; otherwise only the branch selected by ``test_branch_idx`` is
    evaluated on a single input.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        paddings=0,
        dilations=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        """
        Args:
            in_channels (int): channels of every input tensor.
            out_channels (int): channels of every output tensor.
            kernel_size (int or tuple): kernel size, expanded with `_pair`.
            stride (int or tuple): stride shared by all branches.
            paddings (int or sequence): padding per branch; an int is repeated
                for every branch.
            dilations (int or sequence): dilation per branch; an int is repeated
                for every branch.
            groups (int): number of convolution groups.
            num_branch (int): number of branches.
            test_branch_idx (int): branch evaluated outside training; -1 keeps
                all branches.
            bias (bool): whether to learn a bias term.
            norm (callable or None): applied to every output if given.
            activation (callable or None): applied after `norm` if given.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # num_branch, paddings and dilations must all describe the same branches.
        assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def _empty_output_shape(self, x, padding, dilation):
        """Return the [N, C_out, H_out, W_out] shape the convolution would give for `x`."""
        spatial = [
            (size + 2 * p - (d * (k - 1) + 1)) // s + 1
            for size, p, d, k, s in zip(
                x.shape[-2:], padding, dilation, self.kernel_size, self.stride
            )
        ]
        return [x.shape[0], self.weight.shape[0]] + spatial

    def forward(self, inputs):
        """
        Convolve each branch input with the shared weight.

        Args:
            inputs (list): one tensor per active branch (a single tensor when
                only the test branch is evaluated).

        Returns:
            list: one output tensor per input. When the first input is empty,
            empty tensors of the matching output shape are returned and `norm`
            / `activation` are not applied.
        """
        use_all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if use_all_branches else 1
        assert len(inputs) == num_branch

        # (padding, dilation) of every branch that is evaluated in this call.
        if use_all_branches:
            branch_params = list(zip(self.paddings, self.dilations))
        else:
            branch_params = [
                (self.paddings[self.test_branch_idx], self.dilations[self.test_branch_idx])
            ]

        if inputs[0].numel() == 0:
            # The previous code read the non-existent `self.padding` /
            # `self.dilation` and indexed the builtin `input`, so this path
            # always raised. The shape is now derived per branch.
            return [
                _NewEmptyTensorOp.apply(x, self._empty_output_shape(x, padding, dilation))
                for x, (padding, dilation) in zip(inputs, branch_params)
            ]

        outputs = [
            F.conv2d(x, self.weight, self.bias, self.stride, padding, dilation, self.groups)
            for x, (padding, dilation) in zip(inputs, branch_params)
        ]
        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs

    def extra_repr(self):
        """Return the layer configuration shown in the module's repr."""
        fields = [
            ("in_channels", self.in_channels),
            ("out_channels", self.out_channels),
            ("kernel_size", self.kernel_size),
            ("num_branch", self.num_branch),
            ("test_branch_idx", self.test_branch_idx),
            ("stride", self.stride),
            ("paddings", self.paddings),
            ("dilations", self.dilations),
            ("groups", self.groups),
            ("bias", self.with_bias),
        ]
        return ", ".join("{}={}".format(name, value) for name, value in fields)

View File

@@ -0,0 +1,116 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.layers import batched_nms
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads
from detectron2.structures import Instances
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.
    Return detection results by applying non-maximum suppression (NMS) on bounding boxes
    and keep the unsuppressed boxes and other instances (e.g mask) if any.

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store detection
            results. Contain N images and each image has num_branch instances. The entry of
            image ``i`` in branch ``j`` is read from index ``i + N * j``.
        num_branch (int): Number of branches used for merging detection results for each image.
        nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
            all detections.

    Returns:
        results: (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confidence detections after merging results from multiple
            branches. When ``num_branch == 1`` the input list is returned unchanged.
    """
    if num_branch == 1:
        return instances

    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather the detections of image i from every branch.
        per_image = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])

        # Apply per-class NMS
        keep = batched_nms(
            per_image.pred_boxes.tensor, per_image.scores, per_image.pred_classes, nms_thresh
        )
        # A negative value means "keep everything"; slicing with it would
        # instead drop detections from the end (e.g. keep[:-1]).
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]
        results.append(per_image[keep])

    return results
@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    The TridentNet ROIHeads in a typical "C4" R-CNN model.
    See :class:`Res5ROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        """Build the parent heads and read the TridentNet branch settings from `cfg`."""
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # "Fast" mode: a single branch is selected for inference.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        The targets are repeated once per active branch before the parent
        forward is called. Outside training, the per-branch predictions are
        merged by `merge_branch_instances` and an empty loss dict is returned.
        """
        # Only one branch is active for fast inference.
        single_branch = self.trident_fast and not self.training
        num_branch = 1 if single_branch else self.num_branch

        all_targets = None if targets is None else targets * num_branch
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        merged = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return merged, {}
@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.
    See :class:`StandardROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        """Build the parent heads and read the TridentNet branch settings from `cfg`."""
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # "Fast" mode: a single branch is selected for inference.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        Outside training, the per-branch predictions are merged by
        `merge_branch_instances` and an empty loss dict is returned.
        """
        # Use 1 branch if using trident_fast during inference.
        single_branch = self.trident_fast and not self.training
        num_branch = 1 if single_branch else self.num_branch

        # Duplicate targets for all branches in TridentNet.
        all_targets = None if targets is None else targets * num_branch
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        merged = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return merged, {}

View File

@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY
from detectron2.modeling.proposal_generator.rpn import RPN
from detectron2.structures import ImageList
@PROPOSAL_GENERATOR_REGISTRY.register()
class TridentRPN(RPN):
    """
    Trident RPN subnetwork.
    """

    def __init__(self, cfg, input_shape):
        """Build the parent RPN and read the TridentNet branch settings from `cfg`."""
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # "Fast" mode: a single branch is selected for inference.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, gt_instances=None):
        """
        See :class:`RPN.forward`.

        The image tensor, image sizes and ground-truth instances are repeated
        once per active branch before being handed to the parent RPN.
        """
        single_branch = self.trident_fast and not self.training
        num_branch = 1 if single_branch else self.num_branch

        # Duplicate images and gt_instances for all branches in TridentNet.
        stacked_tensor = torch.cat([images.tensor] * num_branch)
        repeated_sizes = images.image_sizes * num_branch
        all_images = ImageList(stacked_tensor, repeated_sizes)

        all_gt_instances = None if gt_instances is None else gt_instances * num_branch

        return super().forward(all_images, features, all_gt_instances)