Add at new repo again
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
|
||||
# TensorMask in Detectron2
|
||||
**A Foundation for Dense Object Segmentation**
|
||||
|
||||
Xinlei Chen, Ross Girshick, Kaiming He, Piotr Dollár
|
||||
|
||||
[[`arXiv`](https://arxiv.org/abs/1903.12174)] [[`BibTeX`](#CitingTensorMask)]
|
||||
|
||||
<div align="center">
|
||||
<img src="http://xinleic.xyz/images/tmask.png" width="700px" />
|
||||
</div>
|
||||
|
||||
In this repository, we release code for TensorMask in Detectron2.
|
||||
TensorMask is a dense sliding-window instance segmentation framework that, for the first time, achieves results close to the well-developed Mask R-CNN framework -- both qualitatively and quantitatively. It establishes a conceptually complementary direction for object instance segmentation research.
|
||||
|
||||
## Installation
|
||||
First install Detectron2 following the [documentation](https://detectron2.readthedocs.io/tutorials/install.html) and
|
||||
[set up the dataset](../../datasets). Then compile the TensorMask-specific op (`swap_align2nat`):
|
||||
```bash
|
||||
cd /path/to/detectron2/projects/TensorMask
|
||||
python setup.py build develop
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
To train a model, run:
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TensorMask/train_net.py --config-file <config.yaml>
|
||||
```
|
||||
|
||||
For example, to launch TensorMask BiPyramid training (1x schedule) with ResNet-50 backbone on 8 GPUs,
|
||||
one should execute:
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TensorMask/train_net.py --config-file configs/tensormask_R_50_FPN_1x.yaml --num-gpus 8
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Model evaluation can be done similarly (6x schedule with scale augmentation):
|
||||
```bash
|
||||
python /path/to/detectron2/projects/TensorMask/train_net.py --config-file configs/tensormask_R_50_FPN_6x.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint
|
||||
```
|
||||
|
||||
## Pretrained Models
|
||||
|
||||
| Backbone | lr sched | AP box | AP mask | download |
|
||||
| -------- | -------- | -- | --- | -------- |
|
||||
| R50 | 1x | 37.6 | 32.4 | <a href="https://dl.fbaipublicfiles.com/detectron2/TensorMask/tensormask_R_50_FPN_1x/152549419/model_final_8f325c.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/TensorMask/tensormask_R_50_FPN_1x/152549419/metrics.json">metrics</a> |
|
||||
| R50 | 6x | 41.4 | 35.8 | <a href="https://dl.fbaipublicfiles.com/detectron2/TensorMask/tensormask_R_50_FPN_6x/153538791/model_final_e8df31.pkl">model</a> \| <a href="https://dl.fbaipublicfiles.com/detectron2/TensorMask/tensormask_R_50_FPN_6x/153538791/metrics.json">metrics</a> |
|
||||
|
||||
|
||||
## <a name="CitingTensorMask"></a>Citing TensorMask
|
||||
|
||||
If you use TensorMask, please use the following BibTeX entry.
|
||||
|
||||
```
|
||||
@InProceedings{chen2019tensormask,
|
||||
title={Tensormask: A Foundation for Dense Object Segmentation},
|
||||
author={Chen, Xinlei and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
|
||||
booktitle={The International Conference on Computer Vision (ICCV)},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
@@ -0,0 +1,25 @@
|
||||
MODEL:
|
||||
META_ARCHITECTURE: "TensorMask"
|
||||
MASK_ON: True
|
||||
BACKBONE:
|
||||
NAME: "build_retinanet_resnet_fpn_backbone"
|
||||
RESNETS:
|
||||
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
||||
ANCHOR_GENERATOR:
|
||||
SIZES: [[44, 60], [88, 120], [176, 240], [352, 480], [704, 960], [1408, 1920]]
|
||||
ASPECT_RATIOS: [[1.0]]
|
||||
FPN:
|
||||
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
||||
FUSE_TYPE: "avg"
|
||||
TENSOR_MASK:
|
||||
ALIGNED_ON: True
|
||||
BIPYRAMID_ON: True
|
||||
DATASETS:
|
||||
TRAIN: ("coco_2017_train",)
|
||||
TEST: ("coco_2017_val",)
|
||||
SOLVER:
|
||||
IMS_PER_BATCH: 16
|
||||
BASE_LR: 0.02
|
||||
STEPS: (60000, 80000)
|
||||
MAX_ITER: 90000
|
||||
VERSION: 2
|
@@ -0,0 +1,5 @@
|
||||
_BASE_: "Base-TensorMask.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
RESNETS:
|
||||
DEPTH: 50
|
@@ -0,0 +1,11 @@
|
||||
_BASE_: "Base-TensorMask.yaml"
|
||||
MODEL:
|
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
|
||||
RESNETS:
|
||||
DEPTH: 50
|
||||
SOLVER:
|
||||
STEPS: (480000, 520000)
|
||||
MAX_ITER: 540000
|
||||
INPUT:
|
||||
MIN_SIZE_TRAIN_SAMPLING: "range"
|
||||
MIN_SIZE_TRAIN: (640, 800)
|
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import glob
|
||||
import os
|
||||
from setuptools import find_packages, setup
|
||||
import torch
|
||||
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
|
||||
|
||||
|
||||
def get_extensions():
    """
    Describe the native extension (``tensormask._C``) that backs the TensorMask ops.

    C++ sources are always compiled. CUDA sources are added, and the extension is
    built as a ``CUDAExtension``, when either CUDA is usable (``torch.cuda.is_available()``
    and ``CUDA_HOME`` is set) or the environment variable ``FORCE_CUDA`` equals ``"1"``.

    Returns:
        list: a single-element list holding the configured extension object.
    """
    project_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(project_dir, "tensormask", "layers", "csrc")

    # vision.cpp sits directly in csrc; the other .cpp files are collected from its
    # sub-directories ("**" without recursive=True matches a single directory level).
    sources = [os.path.join(extensions_dir, "vision.cpp")]
    sources = sources + glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))

    cuda_in_subdirs = glob.glob(os.path.join(extensions_dir, "**", "*.cu"))
    cuda_top_level = glob.glob(os.path.join(extensions_dir, "*.cu"))
    source_cuda = cuda_in_subdirs + cuda_top_level

    extra_compile_args = {"cxx": []}
    define_macros = []

    cuda_usable = torch.cuda.is_available() and CUDA_HOME is not None
    if cuda_usable or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        nvcc_flags = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
        extra_compile_args["nvcc"] = nvcc_flags

        # It's better if pytorch can do this by default ..
        # Forward the host compiler selected through $CC to nvcc.
        host_compiler = os.environ.get("CC", None)
        if host_compiler is not None:
            nvcc_flags.append("-ccbin={}".format(host_compiler))
    else:
        extension = CppExtension

    # Anchor every source at csrc; paths that are already absolute are left unchanged.
    sources = [os.path.join(extensions_dir, path) for path in sources]

    return [
        extension(
            "tensormask._C",
            sources,
            include_dirs=[extensions_dir],
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
|
||||
|
||||
|
||||
# Register the `tensormask` package. The native ops described by get_extensions()
# are compiled through torch's BuildExtension command; the "configs" and "tests"
# directories are excluded from package discovery.
setup(
    name="tensormask",
    version="0.1",
    author="FAIR",
    packages=find_packages(exclude=("configs", "tests")),
    python_requires=">=3.6",
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
|
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .config import add_tensormask_config
|
||||
from .arch import TensorMask
|
@@ -0,0 +1,904 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fvcore.nn import sigmoid_focal_loss_star_jit, smooth_l1_loss
|
||||
from torch import nn
|
||||
|
||||
from detectron2.layers import ShapeSpec, batched_nms, cat, paste_masks_in_image
|
||||
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
|
||||
from detectron2.modeling.backbone import build_backbone
|
||||
from detectron2.modeling.box_regression import Box2BoxTransform
|
||||
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
|
||||
from detectron2.modeling.meta_arch.retinanet import (
|
||||
permute_all_cls_and_box_to_N_HWA_K_and_concat,
|
||||
permute_to_N_HWA_K,
|
||||
)
|
||||
from detectron2.structures import Boxes, ImageList, Instances
|
||||
from detectron2.utils.logger import log_first_n
|
||||
|
||||
from tensormask.layers import SwapAlign2Nat
|
||||
|
||||
__all__ = ["TensorMask"]
|
||||
|
||||
|
||||
def _assignment_rule(
|
||||
gt_boxes,
|
||||
anchor_boxes,
|
||||
unit_lengths,
|
||||
min_anchor_size,
|
||||
scale_thresh=2.0,
|
||||
spatial_thresh=1.0,
|
||||
uniqueness_on=True,
|
||||
):
|
||||
"""
|
||||
Given two lists of boxes of N ground truth boxes and M anchor boxes,
|
||||
compute the assignment between the two, following the assignment rules in
|
||||
https://arxiv.org/abs/1903.12174.
|
||||
The box order must be (xmin, ymin, xmax, ymax), so please make sure to convert
|
||||
to BoxMode.XYXY_ABS before calling this function.
|
||||
|
||||
Args:
|
||||
gt_boxes, anchor_boxes (Boxes): two Boxes. Contains N & M boxes/anchors, respectively.
|
||||
unit_lengths (Tensor): Contains the unit lengths of M anchor boxes.
|
||||
min_anchor_size (float): Minimum size of the anchor, in pixels
|
||||
scale_thresh (float): The `scale` threshold: the maximum size of the anchor
|
||||
should not be greater than scale_thresh x max(h, w) of
|
||||
the ground truth box.
|
||||
spatial_thresh (float): The `spatial` threshold: the l2 distance between the
|
||||
center of the anchor and the ground truth box should not
|
||||
be greater than spatial_thresh x u where u is the unit length.
|
||||
|
||||
Returns:
|
||||
matches (Tensor[int64]): a vector of length M, where matches[i] is a matched
|
||||
ground-truth index in [0, N)
|
||||
match_labels (Tensor[int8]): a vector of length M, where pred_labels[i] indicates
|
||||
whether a prediction is a true or false positive or ignored
|
||||
"""
|
||||
gt_boxes, anchor_boxes = gt_boxes.tensor, anchor_boxes.tensor
|
||||
N = gt_boxes.shape[0]
|
||||
M = anchor_boxes.shape[0]
|
||||
if N == 0 or M == 0:
|
||||
return (
|
||||
gt_boxes.new_full((N,), 0, dtype=torch.int64),
|
||||
gt_boxes.new_full((N,), -1, dtype=torch.int8),
|
||||
)
|
||||
|
||||
# Containment rule
|
||||
lt = torch.min(gt_boxes[:, None, :2], anchor_boxes[:, :2]) # [N,M,2]
|
||||
rb = torch.max(gt_boxes[:, None, 2:], anchor_boxes[:, 2:]) # [N,M,2]
|
||||
union = cat([lt, rb], dim=2) # [N,M,4]
|
||||
|
||||
dummy_gt_boxes = torch.zeros_like(gt_boxes)
|
||||
anchor = dummy_gt_boxes[:, None, :] + anchor_boxes[:, :] # [N,M,4]
|
||||
|
||||
contain_matrix = torch.all(union == anchor, dim=2) # [N,M]
|
||||
|
||||
# Centrality rule, scale
|
||||
gt_size_lower = torch.max(gt_boxes[:, 2:] - gt_boxes[:, :2], dim=1)[0] # [N]
|
||||
gt_size_upper = gt_size_lower * scale_thresh # [N]
|
||||
# Fall back for small objects
|
||||
gt_size_upper[gt_size_upper < min_anchor_size] = min_anchor_size
|
||||
# Due to sampling of locations, the anchor sizes are deducted with sampling strides
|
||||
anchor_size = (
|
||||
torch.max(anchor_boxes[:, 2:] - anchor_boxes[:, :2], dim=1)[0] - unit_lengths
|
||||
) # [M]
|
||||
|
||||
size_diff_upper = gt_size_upper[:, None] - anchor_size # [N,M]
|
||||
scale_matrix = size_diff_upper >= 0 # [N,M]
|
||||
|
||||
# Centrality rule, spatial
|
||||
gt_center = (gt_boxes[:, 2:] + gt_boxes[:, :2]) / 2 # [N,2]
|
||||
anchor_center = (anchor_boxes[:, 2:] + anchor_boxes[:, :2]) / 2 # [M,2]
|
||||
offset_center = gt_center[:, None, :] - anchor_center[:, :] # [N,M,2]
|
||||
offset_center /= unit_lengths[:, None] # [N,M,2]
|
||||
spatial_square = spatial_thresh * spatial_thresh
|
||||
spatial_matrix = torch.sum(offset_center * offset_center, dim=2) <= spatial_square
|
||||
|
||||
assign_matrix = (contain_matrix & scale_matrix & spatial_matrix).int()
|
||||
|
||||
# assign_matrix is N (gt) x M (predicted)
|
||||
# Max over gt elements (dim 0) to find best gt candidate for each prediction
|
||||
matched_vals, matches = assign_matrix.max(dim=0)
|
||||
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
||||
|
||||
match_labels[matched_vals == 0] = 0
|
||||
match_labels[matched_vals == 1] = 1
|
||||
|
||||
# find all the elements that match to ground truths multiple times
|
||||
not_unique_idxs = assign_matrix.sum(dim=0) > 1
|
||||
if uniqueness_on:
|
||||
match_labels[not_unique_idxs] = 0
|
||||
else:
|
||||
match_labels[not_unique_idxs] = -1
|
||||
|
||||
return matches, match_labels
|
||||
|
||||
|
||||
# TODO make the paste_mask function in d2 core support mask list
|
||||
def _paste_mask_lists_in_image(masks, boxes, image_shape, threshold=0.5):
|
||||
"""
|
||||
Paste a list of masks that are of various resolutions (e.g., 28 x 28) into an image.
|
||||
The location, height, and width for pasting each mask is determined by their
|
||||
corresponding bounding boxes in boxes.
|
||||
|
||||
Args:
|
||||
masks (list(Tensor)): A list of Tensor of shape (1, Hmask_i, Wmask_i).
|
||||
Values are in [0, 1]. The list length, Bimg, is the
|
||||
number of detected object instances in the image.
|
||||
boxes (Boxes): A Boxes of length Bimg. boxes.tensor[i] and masks[i] correspond
|
||||
to the same object instance.
|
||||
image_shape (tuple): height, width
|
||||
threshold (float): A threshold in [0, 1] for converting the (soft) masks to
|
||||
binary masks.
|
||||
|
||||
Returns:
|
||||
img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
|
||||
number of detected object instances and Himage, Wimage are the image width
|
||||
and height. img_masks[i] is a binary mask for object instance i.
|
||||
"""
|
||||
if len(masks) == 0:
|
||||
return torch.empty((0, 1) + image_shape, dtype=torch.uint8)
|
||||
|
||||
# Loop over masks groups. Each group has the same mask prediction size.
|
||||
img_masks = []
|
||||
ind_masks = []
|
||||
mask_sizes = torch.tensor([m.shape[-1] for m in masks])
|
||||
unique_sizes = torch.unique(mask_sizes)
|
||||
for msize in unique_sizes.tolist():
|
||||
cur_ind = torch.where(mask_sizes == msize)[0]
|
||||
ind_masks.append(cur_ind)
|
||||
|
||||
cur_masks = cat([masks[i] for i in cur_ind])
|
||||
cur_boxes = boxes[cur_ind]
|
||||
img_masks.append(paste_masks_in_image(cur_masks, cur_boxes, image_shape, threshold))
|
||||
|
||||
img_masks = cat(img_masks)
|
||||
ind_masks = cat(ind_masks)
|
||||
|
||||
img_masks_out = torch.empty_like(img_masks)
|
||||
img_masks_out[ind_masks, :, :] = img_masks
|
||||
|
||||
return img_masks_out
|
||||
|
||||
|
||||
def _postprocess(results, result_mask_info, output_height, output_width, mask_threshold=0.5):
    """
    Rescale the raw TensorMask detections of one image to the requested output resolution.

    Images are usually resized before they enter the detector, so boxes and masks
    have to be mapped from the resolution the detector saw to the resolution the
    caller asks for.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` is the input resolution the detector sees.
            Its `pred_boxes` tensor is rescaled in place. It carries no `pred_masks`
            field; masks arrive through `result_mask_info`.
        result_mask_info (list[Tensor], Boxes): a pair of mask related results.
            The first item is a list of #detection tensors with the predicted masks,
            the second item holds the anchors the masks were predicted for.
            The anchor tensor is rescaled in place.
        output_height, output_width: the desired output resolution.
        mask_threshold (float): threshold used when the soft masks are pasted into
            the image and binarized.

    Returns:
        Instances: the detections at the output resolution, with empty boxes removed
            and, when masks are given, a `pred_masks` field.
    """
    seen_height, seen_width = results.image_size[0], results.image_size[1]
    scale_x = output_width / seen_width
    scale_y = output_height / seen_height
    results = Instances((output_height, output_width), **results.get_fields())

    # Boxes: scale x / y coordinates, clip to the new image and drop empty boxes.
    boxes = results.pred_boxes
    boxes.tensor[:, 0::2] *= scale_x
    boxes.tensor[:, 1::2] *= scale_y
    boxes.clip(results.image_size)
    keep = boxes.nonempty()
    results = results[keep]

    # Masks: paste them at the (rescaled) anchor locations of the kept detections.
    result_masks, result_anchors = result_mask_info
    if result_masks:
        result_anchors.tensor[:, 0::2] *= scale_x
        result_anchors.tensor[:, 1::2] *= scale_y
        kept_masks = [mask for (is_kept, mask) in zip(keep.tolist(), result_masks) if is_kept]
        results.pred_masks = _paste_mask_lists_in_image(
            kept_masks,
            result_anchors[keep],
            results.image_size,
            threshold=mask_threshold,
        )
    return results
|
||||
|
||||
|
||||
class TensorMaskAnchorGenerator(DefaultAnchorGenerator):
    """
    Anchor generator for TensorMask.

    Besides the anchors of every feature level it also produces, for each anchor,
    its unit length (the stride of the level) and a 5D index locating it.
    """

    def grid_anchors_with_unit_lengths_and_indexes(self, grid_sizes):
        """
        Args:
            grid_sizes: one (height, width) pair per feature level.

        Returns:
            Three lists with one entry per feature level:
            anchors as tensors of shape (HWA, 4), unit lengths as float32 tensors of
            shape (HWA,) filled with the level stride, and int64 index tensors of
            shape (HWA, 5) whose columns are (level, image, h, w, anchor); the image
            column is filled with zeros here.
        """
        anchors, unit_lengths, indexes = [], [], []
        per_level = zip(grid_sizes, self.strides, self.cell_anchors)
        for lvl, (size, stride, base_anchors) in enumerate(per_level):
            grid_height, grid_width = size
            device = base_anchors.device

            # Top-left corner of every cell, in input-image pixels.
            shifts_x = torch.arange(
                0, grid_width * stride, step=stride, dtype=torch.float32, device=device
            )
            shifts_y = torch.arange(
                0, grid_height * stride, step=stride, dtype=torch.float32, device=device
            )
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=2)

            # Stack anchors in shapes of (HWA, 4)
            level_anchors = (shifts.unsqueeze(2) + base_anchors.view(1, 1, -1, 4)).view(-1, 4)
            anchors.append(level_anchors)
            unit_lengths.append(
                torch.full((level_anchors.shape[0],), stride, dtype=torch.float32, device=device)
            )

            # create mask indexes using mesh grid
            index_axes = (
                torch.full((1,), lvl, dtype=torch.int64, device=device),
                torch.zeros((1,), dtype=torch.int64, device=device),
                torch.arange(0, grid_height, dtype=torch.int64, device=device),
                torch.arange(0, grid_width, dtype=torch.int64, device=device),
                torch.arange(0, base_anchors.shape[0], dtype=torch.int64, device=device),
            )
            grids = torch.meshgrid(*index_axes)
            indexes.append(torch.stack(grids, dim=5).view(-1, 5))

        return anchors, unit_lengths, indexes

    def forward(self, features):
        """
        Args:
            features: the per-level feature maps; `len(features[0])` is used as the
                number of images and the last two dimensions as the grid size.

        Returns:
            list[list[Boxes]]: a list of #image elements. Each is a list of #feature level Boxes.
                The Boxes contains anchors of this image on the specific feature level.
            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
                The tensor contains strides, or unit lengths for the anchors.
            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
                The Tensor contains indexes for the anchors, with the last dimension meaning
                (L, N, H, W, A), where L is level, N is image (not set yet), H is height,
                W is width, and A is anchor.
        """
        num_images = len(features[0])
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        (
            anchors_list,
            lengths_list,
            indexes_list,
        ) = self.grid_anchors_with_unit_lengths_and_indexes(grid_sizes)

        # Convert anchors from Tensor to Boxes
        anchors_per_im = [Boxes(level_anchors) for level_anchors in anchors_list]

        # TODO it can be simplified to not return duplicated information for
        # each image, just like detectron2's own AnchorGenerator
        anchors, unit_lengths, indexes = [], [], []
        for _ in range(num_images):
            anchors.append(copy.deepcopy(anchors_per_im))
        for _ in range(num_images):
            unit_lengths.append(copy.deepcopy(lengths_list))
        for _ in range(num_images):
            indexes.append(copy.deepcopy(indexes_list))

        return anchors, unit_lengths, indexes
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
class TensorMask(nn.Module):
|
||||
"""
|
||||
TensorMask model. Creates FPN backbone, anchors and a head for classification
|
||||
and box regression. Calculates and applies proper losses to class, box, and
|
||||
masks.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
    """
    Read the TensorMask options from `cfg` and build the backbone, the anchor
    generator, the prediction head and the box transform.
    """
    super().__init__()

    tm_cfg = cfg.MODEL.TENSOR_MASK
    anchor_sizes = cfg.MODEL.ANCHOR_GENERATOR.SIZES

    # fmt: off
    self.num_classes          = tm_cfg.NUM_CLASSES
    self.in_features          = tm_cfg.IN_FEATURES
    self.anchor_sizes         = anchor_sizes
    self.num_levels           = len(anchor_sizes)
    # Loss parameters:
    self.focal_loss_alpha     = tm_cfg.FOCAL_LOSS_ALPHA
    self.focal_loss_gamma     = tm_cfg.FOCAL_LOSS_GAMMA
    # Inference parameters:
    self.score_threshold      = tm_cfg.SCORE_THRESH_TEST
    self.topk_candidates      = tm_cfg.TOPK_CANDIDATES_TEST
    self.nms_threshold        = tm_cfg.NMS_THRESH_TEST
    self.detections_im        = cfg.TEST.DETECTIONS_PER_IMAGE
    # Mask parameters:
    self.mask_on              = cfg.MODEL.MASK_ON
    self.mask_loss_weight     = tm_cfg.MASK_LOSS_WEIGHT
    self.mask_pos_weight      = torch.tensor(tm_cfg.POSITIVE_WEIGHT, dtype=torch.float32)
    self.bipyramid_on         = tm_cfg.BIPYRAMID_ON
    # fmt: on

    # build the backbone
    self.backbone = build_backbone(cfg)

    backbone_shape = self.backbone.output_shape()
    feature_shapes = [backbone_shape[name] for name in self.in_features]
    finest_stride = [shape.stride for shape in feature_shapes][0]

    # build anchors
    self.anchor_generator = TensorMaskAnchorGenerator(cfg, feature_shapes)
    self.num_anchors = self.anchor_generator.num_cell_anchors[0]
    # Mask sizes and the minimum anchor size are derived from the anchors of the
    # first (index 0) level and the stride of the first input feature.
    anchors_min_level = anchor_sizes[0]
    self.mask_sizes = [size // finest_stride for size in anchors_min_level]
    self.min_anchor_size = min(anchors_min_level) - finest_stride

    # head of the TensorMask
    self.head = TensorMaskHead(
        cfg, self.num_levels, self.num_anchors, self.mask_sizes, feature_shapes
    )
    # box transform
    self.box2box_transform = Box2BoxTransform(weights=tm_cfg.BBOX_REG_WEIGHTS)
    # Per-channel normalization statistics, stored as (C, 1, 1) buffers.
    self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
    self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))
|
||||
|
||||
@property
|
||||
def device(self):
    """Return the device that holds the registered ``pixel_mean`` buffer."""
    return self.pixel_mean.device
|
||||
|
||||
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DetectionTransform` .
            Each item in the list contains the inputs for one image.
        For now, each item in the list is a dict that contains:
            image: Tensor, image in (C, H, W) format.
            instances: Instances (the legacy key "targets" is still accepted,
                with a warning)
            Other information that's included in the original dicts, such as:
                "height", "width" (int): the output resolution of the model, used in inference.
                    When absent, the size of the input image is used instead.
    Returns:
        In training mode, losses (dict[str: Tensor]): mapping from a named loss to a
            tensor storing the loss.
        Otherwise, a list with one dict per image whose "instances" entry holds the
            post-processed detections.
    """
    images = self.preprocess_image(batched_inputs)

    first_input = batched_inputs[0]
    if "instances" in first_input:
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
    elif "targets" in first_input:
        log_first_n(
            logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
        )
        gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
    else:
        gt_instances = None

    backbone_features = self.backbone(images.tensor)
    features = [backbone_features[name] for name in self.in_features]
    # apply the TensorMask head
    pred_logits, pred_deltas, pred_masks = self.head(features)
    # generate anchors based on features, is it image specific?
    anchors, unit_lengths, indexes = self.anchor_generator(features)

    if self.training:
        # get ground truths for class labels and box targets, it will label each anchor
        ground_truth = self.get_ground_truth(anchors, unit_lengths, indexes, gt_instances)
        gt_class_info, gt_delta_info, gt_mask_info, num_fg = ground_truth
        # compute the loss
        return self.losses(
            gt_class_info,
            gt_delta_info,
            gt_mask_info,
            num_fg,
            pred_logits,
            pred_deltas,
            pred_masks,
        )

    # do inference to get the output
    results = self.inference(pred_logits, pred_deltas, pred_masks, anchors, indexes, images)
    processed_results = []
    for results_im, input_im, image_size in zip(results, batched_inputs, images.image_sizes):
        height = input_im.get("height", image_size[0])
        width = input_im.get("width", image_size[1])
        # this is to do post-processing with the image size
        result_box, result_mask = results_im
        instances = _postprocess(result_box, result_mask, height, width)
        processed_results.append({"instances": instances})
    return processed_results
|
||||
|
||||
def losses(
    self,
    gt_class_info,
    gt_delta_info,
    gt_mask_info,
    num_fg,
    pred_logits,
    pred_deltas,
    pred_masks,
):
    """
    Compute the classification, box regression and (optionally) mask losses.
    Every loss is a sum over elements divided by max(1, num_fg).

    Args:
        For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg` parameters, see
            :meth:`TensorMask.get_ground_truth`.
        For `pred_logits`, `pred_deltas` and `pred_masks`, see
            :meth:`TensorMaskHead.forward`.

    Returns:
        losses (dict[str: Tensor]): mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The potential dict keys are:
            "loss_cls", "loss_box_reg" and "loss_mask".
    """
    gt_classes_target, gt_valid_inds = gt_class_info
    gt_deltas, gt_fg_inds = gt_delta_info
    gt_masks, gt_mask_inds = gt_mask_info
    loss_normalizer = torch.tensor(max(1, num_fg), dtype=torch.float32, device=self.device)

    # classification and regression
    pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_logits, pred_deltas, self.num_classes
    )
    focal_sum = sigmoid_focal_loss_star_jit(
        pred_logits[gt_valid_inds],
        gt_classes_target[gt_valid_inds],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )
    loss_cls = focal_sum / loss_normalizer

    if num_fg == 0:
        # Zero loss that is still computed from the box predictions.
        loss_box_reg = pred_deltas.sum() * 0
    else:
        box_sum = smooth_l1_loss(pred_deltas[gt_fg_inds], gt_deltas, beta=0.0, reduction="sum")
        loss_box_reg = box_sum / loss_normalizer
    losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}

    if not self.mask_on:
        return losses

    # mask prediction
    loss_mask = 0
    for lvl in range(self.num_levels):
        cur_level_factor = 2 ** lvl if self.bipyramid_on else 1
        for anc in range(self.num_anchors):
            cur_gt_mask_inds = gt_mask_inds[lvl][anc]
            if cur_gt_mask_inds is None:
                # No ground truth for this (level, anchor): add a zero term taken
                # from the predictions of this branch.
                loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
                continue

            cur_mask_size = self.mask_sizes[anc] * cur_level_factor
            # TODO maybe there are numerical issues when mask sizes are large
            # Per-element weight: the mask loss weight divided by the mask area.
            cur_size_divider = torch.tensor(
                self.mask_loss_weight / (cur_mask_size ** 2),
                dtype=torch.float32,
                device=self.device,
            )

            cur_pred_masks = pred_masks[lvl][anc][
                cur_gt_mask_inds[:, 0],  # N
                :,  # V x U
                cur_gt_mask_inds[:, 1],  # H
                cur_gt_mask_inds[:, 2],  # W
            ]

            loss_mask += F.binary_cross_entropy_with_logits(
                cur_pred_masks.view(-1, cur_mask_size, cur_mask_size),  # V, U
                gt_masks[lvl][anc].to(dtype=torch.float32),
                reduction="sum",
                weight=cur_size_divider,
                pos_weight=self.mask_pos_weight,
            )
    losses["loss_mask"] = loss_mask / loss_normalizer
    return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def get_ground_truth(self, anchors, unit_lengths, indexes, targets):
    """
    Label every anchor and gather the box / mask regression targets of the batch.

    Args:
        anchors (list[list[Boxes]]): a list of N=#image elements. Each is a
            list of #feature level Boxes. The Boxes contains anchors of
            this image on the specific feature level.
        unit_lengths (list[list[Tensor]]): a list of N=#image elements. Each is a
            list of #feature level Tensor. The tensor contains unit lengths for anchors of
            this image on the specific feature level.
        indexes (list[list[Tensor]]): a list of N=#image elements. Each is a
            list of #feature level Tensor. The tensor contains the 5D index of
            each anchor, the second dimension means (L, N, H, W, A), where L
            is level, N is image, H is height, W is width, and A is anchor.
        targets (list[Instances]): a list of N `Instances`s. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image. Specify `targets` during training only.

    Returns:
        gt_class_info (Tensor, Tensor): A pair of two tensors for classification.
            The first one is a float32 one-hot tensor of shape (R, #classes) storing
            ground-truth labels for each anchor. R is the total number of anchors in the batch.
            The second one is a boolean tensor of shape (R,), to indicate which
            anchors are valid for loss computation, which anchors are not.
        gt_delta_info (Tensor, Tensor): A pair of two tensors for boxes.
            The first one, of shape (F, 4). F=#foreground anchors.
            The last dimension represents ground-truth box2box transform
            targets (dx, dy, dw, dh) that map each anchor to its matched ground-truth box.
            Only foreground anchors have values in this tensor. It is `None` when no
            image of the batch has ground truth.
            The second one, of shape (R,), is a boolean tensor indicating which anchors
            are foreground ones used for box regression.
        gt_mask_info (list[list[Tensor]], list[list[Tensor]]): A pair of two lists for masks.
            The first one is a list of P=#feature level elements. Each is a
            list of A=#anchor tensors. Each tensor contains the ground truth
            masks of the same size and for the same feature level. Could be `None`.
            The second one is a list of P=#feature level elements. Each is a
            list of A=#anchor tensors. Each tensor contains the location of the ground truth
            masks of the same size and for the same feature level. The second dimension means
            (N, H, W), where N is image, H is height, and W is width. Could be `None`.
        num_fg (int): F=#foreground anchors, used later for loss normalization.
    """
    num_levels, num_anchors = self.num_levels, self.num_anchors
    gt_classes = []
    gt_deltas = []
    gt_masks = [[[] for _ in range(num_anchors)] for _ in range(num_levels)]
    gt_mask_inds = [[[] for _ in range(num_anchors)] for _ in range(num_levels)]

    # Flatten the per-level lists of every image.
    anchors = [Boxes.cat(per_level) for per_level in anchors]
    unit_lengths = [cat(per_level) for per_level in unit_lengths]
    indexes = [cat(per_level) for per_level in indexes]

    def _collect_masks(image_id, targets_im, fg_anchors, fg_gt_ids, fg_indexes):
        # Crop the matched ground-truth masks for the foreground anchors of one image,
        # grouped by (level, anchor), and record where each of them lives.
        for lvl in range(num_levels):
            on_level = fg_indexes[:, 0] == lvl
            if not torch.any(on_level):
                continue
            cur_level_factor = 2 ** lvl if self.bipyramid_on else 1
            for anc in range(num_anchors):
                selected = on_level & (fg_indexes[:, 4] == anc)
                if not torch.any(selected):
                    continue
                matched_targets = targets_im[fg_gt_ids[selected]]
                gt_masks[lvl][anc].append(
                    matched_targets.gt_masks.crop_and_resize(
                        fg_anchors[selected].tensor,
                        self.mask_sizes[anc] * cur_level_factor,
                    )
                )
                # Select (N, H, W) dimensions
                locations = fg_indexes[selected, 1:4]
                # Set the image index to the current image
                locations[:, 0] = image_id
                gt_mask_inds[lvl][anc].append(locations)

    num_fg = 0
    per_image = zip(anchors, unit_lengths, indexes, targets)
    for i, (anchors_im, unit_lengths_im, indexes_im, targets_im) in enumerate(per_image):
        # Initialize all anchors with the background label (== num_classes)
        gt_classes_i = torch.full_like(
            unit_lengths_im, self.num_classes, dtype=torch.int64, device=self.device
        )
        # Images without ground truth keep every anchor as background.
        if len(targets_im) > 0:
            # Compute the pairwise matrix
            gt_matched_inds, anchor_labels = _assignment_rule(
                targets_im.gt_boxes, anchors_im, unit_lengths_im, self.min_anchor_size
            )
            # Find the foreground instances
            fg_inds = anchor_labels == 1
            fg_anchors = anchors_im[fg_inds]
            num_fg += len(fg_anchors)
            # Find the ground truths for foreground instances
            gt_fg_matched_inds = gt_matched_inds[fg_inds]
            # Assign labels for foreground instances
            gt_classes_i[fg_inds] = targets_im.gt_classes[gt_fg_matched_inds]
            # Anchors with label -1 are ignored, others are left as negative
            gt_classes_i[anchor_labels == -1] = -1

            # Boxes: regression offsets, for foregrounds only
            matched_gt_boxes = targets_im[gt_fg_matched_inds].gt_boxes
            gt_deltas.append(
                self.box2box_transform.get_deltas(fg_anchors.tensor, matched_gt_boxes.tensor)
            )

            # Masks: compute masks for each level and each anchor
            if self.mask_on:
                _collect_masks(
                    i, targets_im, fg_anchors, gt_fg_matched_inds, indexes_im[fg_inds, :]
                )
        gt_classes.append(gt_classes_i)

    # Classes and boxes
    gt_classes = cat(gt_classes)
    gt_valid_inds = gt_classes >= 0
    gt_fg_inds = gt_valid_inds & (gt_classes < self.num_classes)
    gt_classes_target = torch.zeros(
        (gt_classes.shape[0], self.num_classes), dtype=torch.float32, device=self.device
    )
    gt_classes_target[gt_fg_inds, gt_classes[gt_fg_inds]] = 1
    gt_deltas = cat(gt_deltas) if gt_deltas else None

    # Masks: concatenate per (level, anchor); empty groups become None
    gt_masks = [[cat(group) if group else None for group in level] for level in gt_masks]
    gt_mask_inds = [
        [cat(group) if group else None for group in level] for level in gt_mask_inds
    ]
    return (
        (gt_classes_target, gt_valid_inds),
        (gt_deltas, gt_fg_inds),
        (gt_masks, gt_mask_inds),
        num_fg,
    )
|
||||
|
||||
def inference(self, pred_logits, pred_deltas, pred_masks, anchors, indexes, images):
    """
    Run inference on every image of the batch.

    Arguments:
        pred_logits, pred_deltas, pred_masks: Same as the output of
            :meth:`TensorMaskHead.forward`.
        anchors, indexes: Same as the input of :meth:`TensorMask.get_ground_truth`;
            both are iterated per image, and each per-image entry is concatenated
            before being handed to :meth:`inference_single_image`.
        images (ImageList): the input images; only `image_sizes` is read here.

    Returns:
        results (list): a list of #images elements, each being the value returned
            by :meth:`inference_single_image` for that image.
    """
    assert len(anchors) == len(images)

    # Reshape every level with `permute_to_N_HWA_K`, then join all levels along
    # dim 1 so that indexing dim 0 yields all candidates of a single image.
    logits_all = cat([permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits], dim=1)
    deltas_all = cat([permute_to_N_HWA_K(x, 4) for x in pred_deltas], dim=1)

    results = []
    for img_idx, (anchors_im, indexes_im) in enumerate(zip(anchors, indexes)):
        if self.mask_on:
            # Slice this image out of every (level, anchor) mask tensor.
            masks_im = [
                [per_anchor[img_idx] for per_anchor in per_level] for per_level in pred_masks
            ]
        else:
            masks_im = [None] * self.num_levels

        results.append(
            self.inference_single_image(
                logits_all[img_idx],
                deltas_all[img_idx],
                masks_im,
                Boxes.cat(anchors_im),
                cat(indexes_im),
                tuple(images.image_sizes[img_idx]),
            )
        )
    return results
|
||||
|
||||
def inference_single_image(
    self, pred_logits, pred_deltas, pred_masks, anchors, indexes, image_size
):
    """
    Single-image inference. Threshold the scores, keep the top candidates, decode
    their boxes, apply non-maximum suppression (NMS) and, if masks are enabled,
    gather the mask predicted for every surviving detection.

    Arguments:
        pred_logits (Tensor): class logits of one image. It is flattened, and every
            run of `num_classes` consecutive entries is treated as one candidate.
            The sigmoid is applied in place on the flattened tensor, which may
            share memory with the argument.
        pred_deltas (Tensor): box regression deltas, indexed by candidate.
        pred_masks (list[list[Tensor]]): indexed as `pred_masks[lvl][anc]`; each
            tensor is read at `[:, h, w]` and viewed as one square mask. Only
            used when `mask_on` is set.
        anchors (Boxes): the anchors of this image, indexed by candidate.
        indexes (Tensor): one row per candidate, unpacked as
            (level, <unused>, h, w, anchor).
        image_size (tuple(H, W)): a tuple of the image height and width.

    Returns:
        tuple: `(results, (result_masks, result_anchors))` where `results` is an
            `Instances` with `pred_boxes`, `scores` and `pred_classes`,
            `result_masks` is a list with one (1, M, M) sigmoid mask per detection
            and `result_anchors` holds the matching anchors. When `mask_on` is
            off the second element is `([], None)`.
    """
    scores = pred_logits.flatten().sigmoid_()

    # Candidates are selected jointly across all levels (faster, and it does not
    # seem to affect the accuracy): threshold first, then keep the best-scoring ones.
    above_thresh = torch.where(scores > self.score_threshold)[0]
    num_topk = min(self.topk_candidates, above_thresh.shape[0])
    sorted_scores, order = scores[above_thresh].sort(descending=True)
    pred_prob = sorted_scores[:num_topk]
    top_idxs = above_thresh[order[:num_topk]]

    # A flat index encodes (candidate, class): split it into the two parts.
    cls_idxs = top_idxs % self.num_classes
    top_idxs //= self.num_classes

    # Decode boxes from the selected anchors and suppress duplicates per class.
    top_anchors = anchors[top_idxs]
    pred_boxes = self.box2box_transform.apply_deltas(
        pred_deltas[top_idxs], top_anchors.tensor
    )
    keep = batched_nms(pred_boxes, pred_prob, cls_idxs, self.nms_threshold)
    keep = keep[: self.detections_im]

    results = Instances(image_size)
    results.pred_boxes = Boxes(pred_boxes[keep])
    results.scores = pred_prob[keep]
    results.pred_classes = cls_idxs[keep]

    if not self.mask_on:
        return results, ([], None)

    # Look up where every kept detection came from and read its mask there.
    result_anchors = top_anchors[keep]
    result_masks = []
    for lvl, _, h, w, anc in indexes[top_idxs][keep].tolist():
        side = self.mask_sizes[anc] * (2 ** lvl if self.bipyramid_on else 1)
        mask_logits = pred_masks[lvl][anc][:, h, w].view(1, side, side)
        result_masks.append(torch.sigmoid(mask_logits))

    return results, (result_masks, result_anchors)
|
||||
|
||||
def preprocess_image(self, batched_inputs):
    """
    Normalize, pad and batch the input images.

    Every `x["image"]` is moved to `self.device` and normalized with
    `self.pixel_mean` / `self.pixel_std`; the results are batched by
    `ImageList.from_tensors` using the backbone's size divisibility.
    """
    normalized = [
        (x["image"].to(self.device) - self.pixel_mean) / self.pixel_std
        for x in batched_inputs
    ]
    return ImageList.from_tensors(normalized, self.backbone.size_divisibility)
|
||||
|
||||
|
||||
class TensorMaskHead(nn.Module):
    """
    Prediction head of TensorMask: a classification tower, a box regression tower
    and, when `cfg.MODEL.MASK_ON` is set, a mask tower with one prediction layer
    per mask size. The same layers are applied to every input feature level.
    """

    def __init__(self, cfg, num_levels, num_anchors, mask_sizes, input_shape: List[ShapeSpec]):
        """
        TensorMask head.

        Arguments:
            cfg: config node; `MODEL.MASK_ON` and `MODEL.TENSOR_MASK.*` are read.
            num_levels (int): number of feature levels; one `SwapAlign2Nat` op is
                created per level in the aligned bipyramid setting.
            num_anchors (int): number of anchors per location, used to size the
                classification and box outputs.
            mask_sizes (list[int]): mask side lengths; a prediction layer with
                `size * size` output channels is created for each entry.
            input_shape (List[ShapeSpec]): shapes of the input feature maps; only
                the channel count of the first entry is used.

        Raises:
            ValueError: if masks are enabled with `BIPYRAMID_ON` but without
                `ALIGNED_ON`.
        """
        super().__init__()
        # fmt: off
        self.in_features = cfg.MODEL.TENSOR_MASK.IN_FEATURES
        in_channels      = input_shape[0].channels
        num_classes      = cfg.MODEL.TENSOR_MASK.NUM_CLASSES
        cls_channels     = cfg.MODEL.TENSOR_MASK.CLS_CHANNELS
        num_convs        = cfg.MODEL.TENSOR_MASK.NUM_CONVS
        # box parameters
        bbox_channels    = cfg.MODEL.TENSOR_MASK.BBOX_CHANNELS
        # mask parameters
        self.mask_on     = cfg.MODEL.MASK_ON
        self.mask_sizes  = mask_sizes
        mask_channels    = cfg.MODEL.TENSOR_MASK.MASK_CHANNELS
        self.align_on    = cfg.MODEL.TENSOR_MASK.ALIGNED_ON
        self.bipyramid_on = cfg.MODEL.TENSOR_MASK.BIPYRAMID_ON
        # fmt: on

        # The fusing layer used by the bipyramid branch of forward() is only built
        # together with the aligned representation. Reject the unsupported
        # combination here instead of failing with an AttributeError on the first
        # forward pass.
        if self.mask_on and self.bipyramid_on and not self.align_on:
            raise ValueError(
                "MODEL.TENSOR_MASK.BIPYRAMID_ON requires MODEL.TENSOR_MASK.ALIGNED_ON "
                "when MODEL.MASK_ON is enabled"
            )

        # class subnet
        self.cls_subnet, cur_channels = self._make_tower(in_channels, cls_channels, num_convs)
        self.cls_score = nn.Conv2d(
            cur_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        )
        modules_list = [self.cls_subnet, self.cls_score]

        # box subnet
        self.bbox_subnet, cur_channels = self._make_tower(in_channels, bbox_channels, num_convs)
        self.bbox_pred = nn.Conv2d(
            cur_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
        )
        modules_list.extend([self.bbox_subnet, self.bbox_pred])

        # mask subnet
        if self.mask_on:
            self.mask_subnet, cur_channels = self._make_tower(
                in_channels, mask_channels, num_convs
            )
            modules_list.append(self.mask_subnet)

            # One 1x1 prediction layer per mask size, looked up by name in forward().
            for mask_size in self.mask_sizes:
                cur_mask_module = "mask_pred_%02d" % mask_size
                self.add_module(
                    cur_mask_module,
                    nn.Conv2d(
                        cur_channels, mask_size * mask_size, kernel_size=1, stride=1, padding=0
                    ),
                )
                modules_list.append(getattr(self, cur_mask_module))

            if self.align_on:
                if self.bipyramid_on:
                    # One op per level, with lambda doubling at every level.
                    for lvl in range(num_levels):
                        cur_mask_module = "align2nat_%02d" % lvl
                        lambda_val = 2 ** lvl
                        setattr(self, cur_mask_module, SwapAlign2Nat(lambda_val))
                    # Also the fusing layer, stay at the same channel size
                    mask_fuse = [
                        nn.Conv2d(cur_channels, cur_channels, kernel_size=3, stride=1, padding=1),
                        nn.ReLU(),
                    ]
                    self.mask_fuse = nn.Sequential(*mask_fuse)
                    modules_list.append(self.mask_fuse)
                else:
                    self.align2nat = SwapAlign2Nat(1)

        # Initialization
        for modules in modules_list:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # Use prior in model initialization to improve stability:
        # sigmoid(bias_value) == 0.01.
        bias_value = -(math.log((1 - 0.01) / 0.01))
        torch.nn.init.constant_(self.cls_score.bias, bias_value)

    @staticmethod
    def _make_tower(in_channels, out_channels, num_convs):
        """Stack `num_convs` 3x3 conv + ReLU pairs; return (tower, its output channels)."""
        layers = []
        cur_channels = in_channels
        for _ in range(num_convs):
            layers.append(
                nn.Conv2d(cur_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            cur_channels = out_channels
            layers.append(nn.ReLU())
        return nn.Sequential(*layers), cur_channels

    def forward(self, features):
        """
        Arguments:
            features (list[Tensor]): FPN feature map tensors in high to low resolution.
                Each tensor in the list correspond to different feature levels.

        Returns:
            pred_logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
                The tensor predicts the classification probability
                at each spatial position for each of the A anchors and K object
                classes.
            pred_deltas (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
                The tensor predicts 4-vector (dx,dy,dw,dh) box
                regression values for every anchor. These values are the
                relative offset between the anchor and the ground truth box.
            pred_masks (list(list[Tensor])): #lvl lists, each holding one tensor per
                entry of `mask_sizes`, produced by that size's prediction layer
                (and the SwapAlign2Nat op when `align_on` is set).
                `None` when `mask_on` is off.
        """
        pred_logits = [self.cls_score(self.cls_subnet(x)) for x in features]
        pred_deltas = [self.bbox_pred(self.bbox_subnet(x)) for x in features]

        pred_masks = None
        if self.mask_on:
            mask_feats = [self.mask_subnet(x) for x in features]

            if self.bipyramid_on:
                # Bring every level to the resolution of the first (finest) level
                # and fuse it with that level's features.
                mask_feat_high_res = mask_feats[0]
                H, W = mask_feat_high_res.shape[-2:]
                mask_feats_up = []
                for lvl, mask_feat in enumerate(mask_feats):
                    lambda_val = 2.0 ** lvl
                    mask_feat_up = mask_feat
                    if lvl > 0:
                        mask_feat_up = F.interpolate(
                            mask_feat, scale_factor=lambda_val, mode="bilinear", align_corners=False
                        )
                    # Crop to (H, W) in case the upsampled map is larger.
                    mask_feats_up.append(
                        self.mask_fuse(mask_feat_up[:, :, :H, :W] + mask_feat_high_res)
                    )
                mask_feats = mask_feats_up

            pred_masks = []
            for lvl, mask_feat in enumerate(mask_feats):
                cur_masks = []
                for mask_size in self.mask_sizes:
                    cur_mask_module = getattr(self, "mask_pred_%02d" % mask_size)
                    cur_mask = cur_mask_module(mask_feat)
                    if self.align_on:
                        if self.bipyramid_on:
                            cur_mask_module = getattr(self, "align2nat_%02d" % lvl)
                            cur_mask = cur_mask_module(cur_mask)
                        else:
                            cur_mask = self.align2nat(cur_mask)
                    cur_masks.append(cur_mask)
                pred_masks.append(cur_masks)
        return pred_logits, pred_deltas, pred_masks
|
@@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from detectron2.config import CfgNode as CN
|
||||
|
||||
|
||||
def add_tensormask_config(cfg):
    """
    Add config for TensorMask.

    Creates the `cfg.MODEL.TENSOR_MASK` node and fills it with default values.
    """
    cfg.MODEL.TENSOR_MASK = CN()
    tm = cfg.MODEL.TENSOR_MASK

    # Names of the input feature maps
    tm.IN_FEATURES = ["p2", "p3", "p4", "p5", "p6", "p7"]

    # Convolutions to use in the towers
    tm.NUM_CONVS = 4

    # Number of foreground classes.
    tm.NUM_CLASSES = 80
    # Channel size for the classification tower
    tm.CLS_CHANNELS = 256

    # Test-time score threshold
    tm.SCORE_THRESH_TEST = 0.05
    # Only the top (1000 * #levels) candidate boxes across all levels are
    # considered jointly during test (to improve speed)
    tm.TOPK_CANDIDATES_TEST = 6000
    # Test-time NMS threshold
    tm.NMS_THRESH_TEST = 0.5

    # Box parameters
    # Channel size for the box tower
    tm.BBOX_CHANNELS = 128
    # Weights on (dx, dy, dw, dh)
    tm.BBOX_REG_WEIGHTS = (1.5, 1.5, 0.75, 0.75)

    # Loss parameters
    tm.FOCAL_LOSS_GAMMA = 3.0
    tm.FOCAL_LOSS_ALPHA = 0.3

    # Mask parameters
    # Channel size for the mask tower
    tm.MASK_CHANNELS = 128
    # Mask loss weight
    tm.MASK_LOSS_WEIGHT = 2.0
    # weight on positive pixels within the mask
    tm.POSITIVE_WEIGHT = 1.5
    # Whether to predict in the aligned representation
    tm.ALIGNED_ON = False
    # Whether to use the bipyramid architecture
    tm.BIPYRAMID_ON = False
|
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .swap_align2nat import SwapAlign2Nat, swap_align2nat
|
||||
|
||||
# Export every public (non-underscore) name defined or imported above.
__all__ = [name for name in globals() if not name.startswith("_")]
|
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
#pragma once
|
||||
#include <torch/types.h>
|
||||
|
||||
namespace tensormask {
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
at::Tensor SwapAlign2Nat_forward_cuda(
|
||||
const at::Tensor& X,
|
||||
const int lambda_val,
|
||||
const float pad_val);
|
||||
|
||||
at::Tensor SwapAlign2Nat_backward_cuda(
|
||||
const at::Tensor& gY,
|
||||
const int lambda_val,
|
||||
const int batch_size,
|
||||
const int channel,
|
||||
const int height,
|
||||
const int width);
|
||||
#endif
|
||||
|
||||
// Forward pass of SwapAlign2Nat: dispatches to the CUDA implementation.
// Raises an error for CPU tensors (no CPU implementation exists) and for CUDA
// tensors when the extension was built without WITH_CUDA.
inline at::Tensor SwapAlign2Nat_forward(
    const at::Tensor& X,
    const int lambda_val,
    const float pad_val) {
  // Query the device directly; Tensor::type() is deprecated.
  if (X.device().is_cuda()) {
#ifdef WITH_CUDA
    return SwapAlign2Nat_forward_cuda(X, lambda_val, pad_val);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
|
||||
|
||||
// Backward pass of SwapAlign2Nat: dispatches to the CUDA implementation.
// `batch_size`, `channel`, `height` and `width` describe the forward input,
// i.e. the shape of the gradient that is returned.
// Raises an error for CPU tensors (no CPU implementation exists) and for CUDA
// tensors when the extension was built without WITH_CUDA.
inline at::Tensor SwapAlign2Nat_backward(
    const at::Tensor& gY,
    const int lambda_val,
    const int batch_size,
    const int channel,
    const int height,
    const int width) {
  // Query the device directly; Tensor::type() is deprecated.
  if (gY.device().is_cuda()) {
#ifdef WITH_CUDA
    return SwapAlign2Nat_backward_cuda(
        gY, lambda_val, batch_size, channel, height, width);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
|
||||
|
||||
} // namespace tensormask
|
@@ -0,0 +1,526 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
|
||||
// TODO make it in a common file
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
// Read element (idx, v, u, y, x) of a tensor laid out as [idx][V][U][H][W].
// Returns `pad_val` when any of the four coordinates is out of range.
template <typename T>
__device__ inline T get_pixel_val(
    const T* tensor,
    const int idx,
    const int H,
    const int W,
    const int y,
    const int x,
    const int V,
    const int U,
    const int v,
    const int u,
    const T pad_val) {
  const bool inside = (y >= 0) && (y < H) && (x >= 0) && (x < W) &&
      (v >= 0) && (v < V) && (u >= 0) && (u < U);
  if (!inside) {
    return pad_val;
  }
  return tensor[(((idx * V + v) * U + u) * H + y) * W + x];
}
|
||||
|
||||
// Atomically add `val` to element (idx, v, u, y, x) of a tensor laid out as
// [idx][V][U][H][W]. Nothing is written when `val` is zero or when any of the
// four coordinates is out of range.
template <typename T>
__device__ inline void add_pixel_val(
    T* tensor,
    const T val,
    const int idx,
    const int H,
    const int W,
    const int y,
    const int x,
    const int V,
    const int U,
    const int v,
    const int u) {
  const bool inside = (y >= 0) && (y < H) && (x >= 0) && (x < W) &&
      (v >= 0) && (v < V) && (u >= 0) && (u < U);
  if (inside && !(val == 0.)) {
    atomicAdd(tensor + ((((idx * V + v) * U + u) * H + y) * W + x), val);
  }
}
|
||||
|
||||
// Forward kernel. Every output element (n, v, u, y, x) is interpolated linearly
// along four axes from the 2^4 = 16 neighbouring input elements; neighbours
// outside the input contribute `pad_val`.
template <typename T>
__global__ void SwapAlign2NatForwardFeat(
    const int nthreads,
    const T* bottom_data,
    const int Vout,
    const int Uout,
    const float hVout,
    const float hUout,
    const int Vin,
    const int Uin,
    const float lambda,
    const int Hin,
    const int Win,
    const int Hout,
    const int Wout,
    const T pad_val,
    T* top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat output index into (n, v, u, y, x), x varying fastest.
    int n = index;
    const int x = n % Wout;
    n /= Wout;
    const int y = n % Hout;
    n /= Hout;
    const int u = n % Uout;
    n /= Uout;
    const int v = n % Vout;
    n /= Vout;

    // For each axis: fractional source coordinate, its floor/ceil neighbours
    // and the two interpolation weights ([0] -> floor, [1] -> ceil).
    const float ox = x * lambda + u - hUout + 0.5;
    const int xs[2] = {static_cast<int>(floor(ox)), static_cast<int>(ceil(ox))};
    const float xwc = ox - xs[0];
    const float xwf = 1. - xwc;
    const float xw[2] = {xwf, xwc};

    const float oy = y * lambda + v - hVout + 0.5;
    const int ys[2] = {static_cast<int>(floor(oy)), static_cast<int>(ceil(oy))};
    const float ywc = oy - ys[0];
    const float ywf = 1. - ywc;
    const float yw[2] = {ywf, ywc};

    const float ou = (u + 0.5) / lambda - 0.5;
    const int us[2] = {static_cast<int>(floor(ou)), static_cast<int>(ceil(ou))};
    const float uwc = ou - us[0];
    const float uwf = 1. - uwc;
    const float uw[2] = {uwf, uwc};

    const float ov = (v + 0.5) / lambda - 0.5;
    const int vs[2] = {static_cast<int>(floor(ov)), static_cast<int>(ceil(ov))};
    const float vwc = ov - vs[0];
    const float vwf = 1. - vwc;
    const float vw[2] = {vwf, vwc};

    // Visit the 16 corners with y as the slowest and u as the fastest axis and
    // accumulate left to right, starting from the first term.
    T val = 0;
#pragma unroll
    for (int k = 0; k < 16; ++k) {
      const int iy = (k >> 3) & 1;
      const int ix = (k >> 2) & 1;
      const int iv = (k >> 1) & 1;
      const int iu = k & 1;
      const T term = yw[iy] * xw[ix] * vw[iv] * uw[iu] *
          get_pixel_val(
                         bottom_data,
                         n,
                         Hin,
                         Win,
                         ys[iy],
                         xs[ix],
                         Vin,
                         Uin,
                         vs[iv],
                         us[iu],
                         pad_val);
      val = (k == 0) ? term : val + term;
    }

    top_data[index] = val;
  }
}
|
||||
|
||||
// Backward kernel. The gradient of every output element (n, v, u, y, x) is
// scattered to the same 16 neighbouring input elements the forward kernel reads,
// scaled by the same interpolation weights. Writes go through `add_pixel_val`,
// which skips out-of-range targets and zero contributions.
template <typename T>
__global__ void SwapAlign2NatBackwardFeat(
    const int nthreads,
    const T* top_diff,
    const int Vout,
    const int Uout,
    const float hVout,
    const float hUout,
    const int Vin,
    const int Uin,
    const float lambda,
    const int Hin,
    const int Win,
    const int Hout,
    const int Wout,
    T* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat output index into (n, v, u, y, x), x varying fastest.
    int n = index;
    const int x = n % Wout;
    n /= Wout;
    const int y = n % Hout;
    n /= Hout;
    const int u = n % Uout;
    n /= Uout;
    const int v = n % Vout;
    n /= Vout;

    // For each axis: fractional source coordinate, its floor/ceil neighbours
    // and the two interpolation weights ([0] -> floor, [1] -> ceil).
    const float ox = x * lambda + u - hUout + 0.5;
    const int xs[2] = {static_cast<int>(floor(ox)), static_cast<int>(ceil(ox))};
    const float xwc = ox - xs[0];
    const float xwf = 1. - xwc;
    const float xw[2] = {xwf, xwc};

    const float oy = y * lambda + v - hVout + 0.5;
    const int ys[2] = {static_cast<int>(floor(oy)), static_cast<int>(ceil(oy))};
    const float ywc = oy - ys[0];
    const float ywf = 1. - ywc;
    const float yw[2] = {ywf, ywc};

    const float ou = (u + 0.5) / lambda - 0.5;
    const int us[2] = {static_cast<int>(floor(ou)), static_cast<int>(ceil(ou))};
    const float uwc = ou - us[0];
    const float uwf = 1. - uwc;
    const float uw[2] = {uwf, uwc};

    const float ov = (v + 0.5) / lambda - 0.5;
    const int vs[2] = {static_cast<int>(floor(ov)), static_cast<int>(ceil(ov))};
    const float vwc = ov - vs[0];
    const float vwf = 1. - vwc;
    const float vw[2] = {vwf, vwc};

    const T grad = top_diff[index];

    // Visit the 16 corners with y as the slowest and u as the fastest axis.
#pragma unroll
    for (int k = 0; k < 16; ++k) {
      const int iy = (k >> 3) & 1;
      const int ix = (k >> 2) & 1;
      const int iv = (k >> 1) & 1;
      const int iu = k & 1;
      add_pixel_val(
          bottom_diff,
          yw[iy] * xw[ix] * vw[iv] * uw[iu] * grad,
          n,
          Hin,
          Win,
          ys[iy],
          xs[ix],
          Vin,
          Uin,
          vs[iv],
          us[iu]);
    }
  }
}
|
||||
|
||||
namespace tensormask {
|
||||
|
||||
// CUDA forward of SwapAlign2Nat.
//   X          : 4D CUDA tensor (N, C, Hin, Win); C must be a square number V*V.
//   lambda_val : integer ratio >= 1.
//   pad_val    : value used for samples that fall outside of X.
// Returns a tensor of shape
//   (N, (lambda*V)^2, ceil(Hin / lambda), ceil(Win / lambda)).
at::Tensor SwapAlign2Nat_forward_cuda(
    const at::Tensor& X,
    const int lambda_val,
    const float pad_val) {
  AT_ASSERTM(X.device().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(X.ndimension() == 4, "input must be a 4D tensor");
  AT_ASSERTM(lambda_val >= 1, "lambda should be greater or equal to 1");
  const int N = X.size(0);
  const int C = X.size(1);
  const int Vin = static_cast<int>(sqrt(static_cast<float>(C)));
  // For a square channel count U equals V. Assigning it directly avoids the
  // `C / Vin` division, which is a divide-by-zero for an input with 0 channels.
  const int Uin = Vin;
  AT_ASSERTM(C == Vin * Uin, "#channels should be a square number");
  const int Vout = lambda_val * Vin;
  const int Uout = lambda_val * Uin;
  const int Hin = X.size(2);
  const int Win = X.size(3);
  const float lambda = static_cast<float>(lambda_val);
  const int Hout = static_cast<int>(ceil(Hin / lambda));
  const int Wout = static_cast<int>(ceil(Win / lambda));
  const float hVout = Vout / 2.;
  const float hUout = Uout / 2.;

  at::cuda::CUDAGuard device_guard(X.device());

  at::Tensor Y = at::empty({N, Vout * Uout, Hout, Wout}, X.options());

  // Nothing to compute for an empty output.
  if (Y.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return Y;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // One thread per output element, capped at 4096 blocks of 512 threads; the
  // kernel loop covers the remainder.
  dim3 grid(std::min(at::cuda::ATenCeilDiv(Y.numel(), 512L), 4096L));
  dim3 block(512);

  auto X_ = X.contiguous();
  AT_DISPATCH_FLOATING_TYPES(X_.scalar_type(), "SwapAlign2Nat_forward", [&] {
    SwapAlign2NatForwardFeat<scalar_t><<<grid, block, 0, stream>>>(
        Y.numel(),
        X_.data_ptr<scalar_t>(),
        Vout,
        Uout,
        hVout,
        hUout,
        Vin,
        Uin,
        lambda,
        Hin,
        Win,
        Hout,
        Wout,
        pad_val,
        Y.data_ptr<scalar_t>());
  });
  // No device-wide synchronization here: the kernel is queued on the current
  // stream, so later work on that stream is ordered after it, and the backward
  // pass does not synchronize either.
  AT_CUDA_CHECK(cudaGetLastError());
  return Y;
}
|
||||
|
||||
// CUDA backward of SwapAlign2Nat.
//   gY         : 4D CUDA tensor holding the gradient w.r.t. the forward output.
//   lambda_val : the lambda used in the forward pass (>= 1).
//   batch_size, channel, height, width : shape of the forward input.
// Returns the gradient w.r.t. the forward input, of shape
//   (batch_size, channel, height, width).
at::Tensor SwapAlign2Nat_backward_cuda(
    const at::Tensor& gY,
    const int lambda_val,
    const int batch_size,
    const int channel,
    const int height,
    const int width) {
  AT_ASSERTM(gY.device().is_cuda(), "input gradient must be a CUDA tensor");
  AT_ASSERTM(gY.ndimension() == 4, "input gradient must be a 4D tensor");
  AT_ASSERTM(lambda_val >= 1, "lambda should be greater or equal to 1");
  const int Vin = static_cast<int>(sqrt(static_cast<float>(channel)));
  // Guard the division: channel == 0 gives Vin == 0.
  const int Uin = Vin > 0 ? channel / Vin : 0;
  const int Vout = lambda_val * Vin;
  const int Uout = lambda_val * Uin;
  const float hVout = Vout / 2.;
  const float hUout = Uout / 2.;
  const int Hout = gY.size(2);
  const int Wout = gY.size(3);
  // The kernel derives the batch index from the flat index of gY and scatters
  // into gX with it, so a gradient with more batches or channels than the
  // forward output would write out of bounds.
  AT_ASSERTM(
      gY.size(0) == batch_size && gY.size(1) == Vout * Uout,
      "input gradient shape must match the forward output");

  at::cuda::CUDAGuard device_guard(gY.device());

  // Zero-initialized because the kernel accumulates into it.
  at::Tensor gX = at::zeros({batch_size, channel, height, width}, gY.options());

  // handle possibly empty gradients
  if (gY.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return gX;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // One thread per gradient element, capped at 4096 blocks of 512 threads; the
  // kernel loop covers the remainder.
  dim3 grid(std::min(at::cuda::ATenCeilDiv(gY.numel(), 512L), 4096L));
  dim3 block(512);

  auto gY_ = gY.contiguous();
  AT_DISPATCH_FLOATING_TYPES(gY_.scalar_type(), "SwapAlign2Nat_backward", [&] {
    SwapAlign2NatBackwardFeat<scalar_t><<<grid, block, 0, stream>>>(
        gY_.numel(),
        gY_.data_ptr<scalar_t>(),
        Vout,
        Uout,
        hVout,
        hUout,
        Vin,
        Uin,
        static_cast<float>(lambda_val),
        height,
        width,
        Hout,
        Wout,
        gX.data_ptr<scalar_t>());
  });
  AT_CUDA_CHECK(cudaGetLastError());
  return gX;
}
|
||||
|
||||
} // namespace tensormask
|
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include "SwapAlign2Nat/SwapAlign2Nat.h"
|
||||
|
||||
namespace tensormask {
|
||||
|
||||
// Python bindings of the extension module: exposes the forward and backward
// dispatch functions of SwapAlign2Nat.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("swap_align2nat_forward", &SwapAlign2Nat_forward, "SwapAlign2Nat_forward");
  m.def("swap_align2nat_backward", &SwapAlign2Nat_backward, "SwapAlign2Nat_backward");
}
|
||||
|
||||
} // namespace tensormask
|
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from tensormask import _C
|
||||
|
||||
|
||||
class _SwapAlign2Nat(Function):
    """Autograd wrapper around the compiled SwapAlign2Nat forward/backward ops."""

    @staticmethod
    def forward(ctx, X, lambda_val, pad_val):
        # Keep what backward() needs: lambda and the shape of the input.
        ctx.lambda_val = lambda_val
        ctx.input_shape = X.size()
        return _C.swap_align2nat_forward(X, lambda_val, pad_val)

    @staticmethod
    @once_differentiable
    def backward(ctx, gY):
        batch, channels, height, width = ctx.input_shape
        grad_input = _C.swap_align2nat_backward(
            gY, ctx.lambda_val, batch, channels, height, width
        )
        # One entry per forward() argument: no gradient for lambda_val and pad_val.
        return grad_input, None, None


# Functional form of the op: swap_align2nat(X, lambda_val, pad_val).
swap_align2nat = _SwapAlign2Nat.apply
|
||||
|
||||
|
||||
class SwapAlign2Nat(nn.Module):
    """
    The op `SwapAlign2Nat` described in https://arxiv.org/abs/1903.12174.
    Given an input tensor that predicts masks of shape (N, C=VxU, H, W),
    the op returns masks of shape (N, V'xU', H', W') where the unit lengths
    of (V, U) and (H, W) are swapped, and the mask representation is
    transformed from aligned to natural.

    Args:
        lambda_val (int): the relative unit length ratio between (V, U) and (H, W),
                    as we always have larger unit lengths for (V, U) than (H, W),
                    lambda_val is always >= 1.
        pad_val (float): padding value for the values falling outside of the input
                    tensor, default set to -6 as sigmoid(-6) is ~0, indicating
                    that there are no masks outside of the tensor.
    """

    def __init__(self, lambda_val, pad_val=-6.0):
        super().__init__()
        self.lambda_val = lambda_val
        self.pad_val = pad_val

    def forward(self, X):
        """Apply the op to `X` with the stored `lambda_val` and `pad_val`."""
        return swap_align2nat(X, self.lambda_val, self.pad_val)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"lambda_val={self.lambda_val!s}, pad_val={self.pad_val!s})"
        )
|
@@ -0,0 +1 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
from torch.autograd import gradcheck
|
||||
|
||||
from tensormask.layers.swap_align2nat import SwapAlign2Nat
|
||||
|
||||
|
||||
class SwapAlign2NatTest(unittest.TestCase):
    """Tests for the SwapAlign2Nat layer."""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_swap_align2nat_gradcheck_cuda(self):
        # gradcheck compares analytical and numerical gradients, hence float64.
        tensor_opts = {"dtype": torch.float64, "device": torch.device("cuda")}
        m = SwapAlign2Nat(2).to(**tensor_opts)
        x = torch.rand(2, 4, 10, 10, requires_grad=True, **tensor_opts)

        self.assertTrue(gradcheck(m, x), "gradcheck failed for SwapAlign2Nat CUDA")

    def _swap_align2nat(self, tensor, lambda_val):
        """
        The basic setup for testing Swap_Align: run the op (with zero padding) on a
        single numpy array on the GPU and return the result as a numpy array.
        """
        op = SwapAlign2Nat(lambda_val, pad_val=0.0)
        # Add the batch dimension the op expects.
        batched = torch.from_numpy(tensor[None, :, :, :].astype("float32"))
        result = op.forward(batched.cuda()).cpu().numpy()
        return result[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main()
|
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
TensorMask Training Script.
|
||||
|
||||
This script is a simplified version of the training script in detectron2/tools.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import detectron2.utils.comm as comm
|
||||
from detectron2.checkpoint import DetectionCheckpointer
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
||||
from detectron2.evaluation import COCOEvaluator, verify_results
|
||||
|
||||
from tensormask import add_tensormask_config
|
||||
|
||||
|
||||
class Trainer(DefaultTrainer):
    """``DefaultTrainer`` subclass that evaluates with ``COCOEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create a ``COCOEvaluator`` for ``dataset_name``.

        When ``output_folder`` is None, ``<cfg.OUTPUT_DIR>/inference`` is used.
        """
        target_dir = (
            os.path.join(cfg.OUTPUT_DIR, "inference") if output_folder is None else output_folder
        )
        # NOTE(review): the third positional argument is presumably the
        # ``distributed`` flag -- confirm against the installed detectron2 version.
        return COCOEvaluator(dataset_name, cfg, True, target_dir)
|
||||
|
||||
|
||||
def setup(args):
    """
    Build the frozen config for this run and perform the default setup.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts`` are read here.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    # Extend the base config before merging any user-provided values.
    add_tensormask_config(config)
    # Merge order: the config file first, then the command-line ``opts`` list.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
|
||||
|
||||
|
||||
def main(args):
    """
    Entry point executed in each launched process.

    Trains a model by default; with ``args.eval_only`` set, it builds the model,
    loads ``cfg.MODEL.WEIGHTS`` and returns the evaluation results.
    """
    cfg = setup(args)

    if not args.eval_only:
        # Training path.
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation-only path.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    results = Trainer.test(cfg, model)
    # Only the main process verifies the results.
    if comm.is_main_process():
        verify_results(cfg, results)
    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the command-line flags, echo them, then hand ``main`` to ``launch``
    # together with the GPU / machine settings taken from those flags.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines, machine_rank=args.machine_rank,
        dist_url=args.dist_url, args=(args,),
    )
|
Reference in New Issue
Block a user