#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  : Peike Li
@Contact : peike.li@yahoo.com
@File    : ocnet.py
@Time    : 8/4/19 3:36 PM
@Desc    :
@License : This source code is licensed under the license found in the
           LICENSE file in the root directory of this source tree.
"""

import functools

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from modules import InPlaceABNSync

BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')


class _SelfAttentionBlock(nn.Module):
    '''
    The basic implementation of the self-attention block / non-local block.
    Input:
        N X C X H X W
    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        scale          : the scale used to downsample the input feature maps (to save memory)
    Return:
        N X C X H X W
        position-aware context features (without concatenation or addition with the input)
    '''

    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
        super(_SelfAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = in_channels
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            InPlaceABNSync(self.key_channels),
        )
        # the query transform shares its weights with the key transform
        self.f_query = self.f_key
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
                           kernel_size=1, stride=1, padding=0)
        # zero-initialise the output projection so the attention context starts at zero
        # (nn.init.constant_ is the in-place replacement for the deprecated nn.init.constant)
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        # project to value/query/key and flatten the spatial dimensions
        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)

        # scaled dot-product attention over all spatial positions
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # aggregate the values and restore the N x C x H x W layout
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)
        if self.scale > 1:
            # F.interpolate replaces the deprecated F.upsample
            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
        return context


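# Worked shape example for _SelfAttentionBlock (illustrative only; the sizes are
# hypothetical and not taken from this repository). For x of shape 1 x 512 x 32 x 64,
# key_channels = value_channels = 256 and scale = 1 (so HW = 2048):
#   query   : 1 x 2048 x 256    (N x HW x C_k)
#   key     : 1 x 256  x 2048   (N x C_k x HW)
#   sim_map : 1 x 2048 x 2048   (softmax-normalised over the last dimension)
#   value   : 1 x 2048 x 256    (N x HW x C_v)
#   context : 1 x 512  x 32 x 64 after self.W (out_channels defaults to in_channels)

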
class SelfAttentionBlock2D(_SelfAttentionBlock):
    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
        super(SelfAttentionBlock2D, self).__init__(in_channels,
                                                   key_channels,
                                                   value_channels,
                                                   out_channels,
                                                   scale)


class BaseOC_Module(nn.Module):
    """
    Implementation of the BaseOC module.
    Parameters:
        in_features / out_features: the number of channels of the input / output feature maps.
        dropout: we use 0.05 as the default value.
        sizes: multiple attention scales can be applied; here we only use one.
    Return:
        features fused with the object context information.
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Module, self).__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
            InPlaceABNSync(out_channels),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        # sum the context maps produced at every scale, then fuse with the input by concatenation
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output


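# Usage sketch for BaseOC_Module (illustrative only; the channel sizes below are
# hypothetical, not taken from this repository). On top of a 1/8-resolution backbone
# feature map of shape N x 512 x H x W it could be applied as:
#   oc = BaseOC_Module(in_channels=512, out_channels=512,
#                      key_channels=256, value_channels=256,
#                      dropout=0.05, sizes=([1]))
#   out = oc(feats)   # feats: N x 512 x H x W  ->  out: N x 512 x H x W

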
class BaseOC_Context_Module(nn.Module):
    """
    Output only the context features.
    Parameters:
        in_features / out_features: the number of channels of the input / output feature maps.
        dropout: specify the dropout ratio.
        fusion: we provide two different fusion methods, "concat" or "add".
        sizes: we find that learning the attention weights directly, even on 1/8-resolution feature maps, is hard.
    Return:
        features after "concat" or "add".
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Context_Module, self).__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            InPlaceABNSync(out_channels),
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(context)
        return output


class ASP_OC_Module(nn.Module):
    def __init__(self, features, out_features=256, dilations=(12, 24, 36)):
        super(ASP_OC_Module, self).__init__()
        # branch 1: 3x3 conv followed by the object-context block
        self.context = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=True),
                                     InPlaceABNSync(out_features),
                                     BaseOC_Context_Module(in_channels=out_features, out_channels=out_features,
                                                           key_channels=out_features // 2, value_channels=out_features,
                                                           dropout=0, sizes=([2])))
        # branch 2: 1x1 conv
        self.conv2 = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
                                   InPlaceABNSync(out_features))
        # branches 3-5: 3x3 atrous convs with increasing dilation rates
        self.conv3 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
            InPlaceABNSync(out_features))
        self.conv4 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
            InPlaceABNSync(out_features))
        self.conv5 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
            InPlaceABNSync(out_features))

        # fuse the five concatenated branches back to out_features channels
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(out_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
        )

    def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
        # concatenate the branch outputs element-wise when the input is a list/tuple of tensors
        assert (len(feat1) == len(feat2))
        z = []
        for i in range(len(feat1)):
            z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), 1))
        return z

    def forward(self, x):
        # Variable is a deprecated alias of torch.Tensor, so this branch matches plain tensors
        if isinstance(x, Variable):
            _, _, h, w = x.size()
        elif isinstance(x, tuple) or isinstance(x, list):
            _, _, h, w = x[0].size()
        else:
            raise RuntimeError('unknown input type')

        feat1 = self.context(x)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)

        if isinstance(x, Variable):
            out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
        elif isinstance(x, tuple) or isinstance(x, list):
            out = self._cat_each(feat1, feat2, feat3, feat4, feat5)
        else:
            raise RuntimeError('unknown input type')
        output = self.conv_bn_dropout(out)
        return output
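

if __name__ == '__main__':
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # file). It assumes the inplace_abn extension behind InPlaceABNSync is built and
    # that a CUDA device is available; the 2048-channel input mimics a ResNet
    # stage-5 feature map but is otherwise an arbitrary choice.
    model = ASP_OC_Module(features=2048, out_features=256).cuda().eval()
    dummy = torch.randn(1, 2048, 32, 32).cuda()
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 256, 32, 32])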