41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
import torch
|
|
from torch.nn import functional as F
|
|
|
|
|
|
def _mark_label_boundaries(edge, label, here, there, ignore_index):
    """Set ``edge[here]`` to 1 wherever ``label[here]`` differs from ``label[there]``.

    ``here`` and ``there`` are equal-shaped index tuples that pair every pixel
    with its neighbour in one direction. Pairs in which either pixel equals
    ``ignore_index`` are skipped. ``edge`` is modified in place through the
    view ``edge[here]``.
    """
    first = label[here]
    second = label[there]
    boundary = (first != second) & (first != ignore_index) & (second != ignore_index)
    edge[here][boundary] = 1


def generate_edge_tensor(label, edge_width=3, *, ignore_index=255, device=None):
    """Build a binary boundary map from a label map and thicken it.

    A pixel is marked when its label differs from a neighbour's label in one
    of four directions (vertical, horizontal, and both diagonals), unless
    either label equals ``ignore_index``. The raw map is then dilated with an
    ``edge_width x edge_width`` box filter so boundaries become thicker.

    Args:
        label: Label tensor of shape (H, W) or (N, H, W). It is converted to
            float32 for the comparisons.
        edge_width: Side length of the square dilation kernel (1 = no
            dilation). Output spatial size equals input size for any width.
        ignore_index: Label value whose neighbour pairs never form an edge.
        device: Device for the computation and the result. If None, uses
            ``label.device`` when ``label`` is already on CUDA or CUDA is
            unavailable; otherwise the current CUDA device (the historical
            behaviour).

    Returns:
        Float32 tensor of 0s and 1s. Shape is (H, W) when the batch size is 1
        (including 2-D input); otherwise (N, H, W).
    """
    if device is None:
        if label.is_cuda or not torch.cuda.is_available():
            device = label.device
        else:
            # Preserve the original always-CUDA behaviour when a GPU exists.
            device = torch.device("cuda")

    label = label.to(device=device, dtype=torch.float32)
    if label.dim() == 2:
        label = label.unsqueeze(0)

    edge = torch.zeros(label.shape, dtype=torch.float32, device=device)

    full = slice(None)
    head = slice(None, -1)  # every index except the last
    tail = slice(1, None)   # every index except the first

    # Vertical neighbours (row i-1, row i): mark the lower pixel.
    _mark_label_boundaries(edge, label, (full, tail, full), (full, head, full), ignore_index)
    # Horizontal neighbours (col j, col j+1): mark the left pixel.
    _mark_label_boundaries(edge, label, (full, full, head), (full, full, tail), ignore_index)
    # Diagonal neighbours (i, j)-(i+1, j+1): mark (i, j).
    _mark_label_boundaries(edge, label, (full, head, head), (full, tail, tail), ignore_index)
    # Anti-diagonal neighbours (i, j+1)-(i+1, j): mark (i, j+1).
    _mark_label_boundaries(edge, label, (full, head, tail), (full, tail, head), ignore_index)

    # Dilate with a box filter. Pad explicitly so the output keeps the input's
    # spatial size for any kernel width (a fixed padding=1 only works for 3);
    # even widths get one extra pixel of padding after.
    kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float32, device=device)
    pad_before = (edge_width - 1) // 2
    pad_after = edge_width - 1 - pad_before
    with torch.no_grad():
        padded = F.pad(edge.unsqueeze(1), (pad_before, pad_after, pad_before, pad_after))
        edge = F.conv2d(padded, kernel)

    # Any pixel touched by the dilation becomes 1.
    edge = (edge != 0).to(torch.float32)

    # Drop the channel dim; drop the batch dim only for a single image, so
    # spatial dims of size 1 are never squeezed away.
    edge = edge.squeeze(1)
    if edge.shape[0] == 1:
        edge = edge.squeeze(0)
    return edge
|