Add to new repo again
@@ -0,0 +1,200 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  : Peike Li
@Contact : peike.li@yahoo.com
@File    : datasets.py
@Time    : 8/4/19 3:35 PM
@Desc    :
@License : This source code is licensed under the license found in the
           LICENSE file in the root directory of this source tree.
"""

import os
import numpy as np
import random
import torch
import cv2
from torch.utils import data

from utils.transforms import get_affine_transform


class CropDataSet(data.Dataset):
    def __init__(self, root, split_name, crop_size=[473, 473], scale_factor=0.25,
                 rotation_factor=30, ignore_label=255, transform=None):
        self.root = root
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
        self.crop_size = np.asarray(crop_size)
        self.ignore_label = ignore_label
        self.scale_factor = scale_factor
        self.rotation_factor = rotation_factor
        self.flip_prob = 0.5
        self.transform = transform
        self.split_name = split_name

        list_path = os.path.join(self.root, self.split_name + '.txt')
        train_list = [i_id.strip() for i_id in open(list_path)]

        self.train_list = train_list
        self.number_samples = len(self.train_list)

    def __len__(self):
        return self.number_samples

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        train_item = self.train_list[index]

        im_path = os.path.join(self.root, self.split_name + '_images', train_item + '.jpg')
        parsing_anno_path = os.path.join(self.root, self.split_name + '_segmentations', train_item + '.png')

        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        h, w, _ = im.shape
        parsing_anno = np.zeros((h, w), dtype=np.int64)  # np.long was removed in recent NumPy; use a concrete dtype

        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0

        if self.split_name != 'test':
            # Load the parsing (segmentation) annotation
            parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)

            # Random scale and rotation augmentation
            sf = self.scale_factor
            rf = self.rotation_factor
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
            r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0

            # Random horizontal flip: mirror image, label and person center,
            # and swap the left/right part label ids
            if random.random() <= self.flip_prob:
                im = im[:, ::-1, :]
                parsing_anno = parsing_anno[:, ::-1]
                person_center[0] = im.shape[1] - person_center[0] - 1
                right_idx = [15, 17, 19]
                left_idx = [14, 16, 18]
                for i in range(0, 3):
                    right_pos = np.where(parsing_anno == right_idx[i])
                    left_pos = np.where(parsing_anno == left_idx[i])
                    parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
                    parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]

        trans = get_affine_transform(person_center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        if self.transform:
            input = self.transform(input)

        meta = {
            'name': train_item,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        if self.split_name == 'val' or self.split_name == 'test':
            return input, meta
        else:
            label_parsing = cv2.warpAffine(
                parsing_anno,
                trans,
                (int(self.crop_size[1]), int(self.crop_size[0])),
                flags=cv2.INTER_NEAREST,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=(255))

            label_parsing = torch.from_numpy(label_parsing)

            return input, label_parsing, meta


class CropDataValSet(data.Dataset):
    def __init__(self, root, split_name='crop_pic', crop_size=[473, 473], transform=None, flip=False):
        self.root = root
        self.transform = transform
        self.flip = flip
        self.split_name = split_name
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
        self.crop_size = np.asarray(crop_size)

        list_path = os.path.join(self.root, self.split_name + '.txt')
        val_list = [i_id.strip() for i_id in open(list_path)]

        self.val_list = val_list
        self.number_samples = len(self.val_list)

    def __len__(self):
        return len(self.val_list)

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)

        return center, scale

    def __getitem__(self, index):
        val_item = self.val_list[index]
        # Load image
        im_path = os.path.join(self.root, self.split_name, val_item + '.jpg')
        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        h, w, _ = im.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))
        input = self.transform(input)
        if self.flip:
            # Stack the original and horizontally flipped crops for flip-averaged testing
            flip_input = input.flip(dims=[-1])
            batch_input_im = torch.stack([input, flip_input])
        else:
            batch_input_im = input

        meta = {
            'name': val_item,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return batch_input_im, meta
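
A minimal usage sketch for the loaders above (not part of the commit; the paths and normalization values are placeholders, the real scripts read mean/std from the network). It assumes `root` contains `<split_name>.txt` plus a `<split_name>/` image folder for CropDataValSet, or `<split_name>_images/` and `<split_name>_segmentations/` for CropDataSet:

# Hypothetical example: iterate the validation loader defined above.
import torchvision.transforms as transforms
from torch.utils import data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),  # placeholder stats
])
val_dataset = CropDataValSet(root='./demo/cropped_img', split_name='crop_pic',
                             crop_size=[473, 473], transform=transform, flip=False)
val_loader = data.DataLoader(val_dataset, batch_size=1, shuffle=False)
image, meta = next(iter(val_loader))
# image: [1, 3, 473, 473] tensor; meta carries name/center/scale/height/width,
# which are needed to map predictions back onto the original image.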
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  : Peike Li
@Contact : peike.li@yahoo.com
@File    : evaluate.py
@Time    : 8/4/19 3:36 PM
@Desc    :
@License : This source code is licensed under the license found in the
           LICENSE file in the root directory of this source tree.
"""

import os
import argparse
import numpy as np
import torch

from torch.utils import data
from tqdm import tqdm
from PIL import Image as PILImage
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

import networks
from utils.miou import compute_mean_ioU
from utils.transforms import BGR2RGB_transform
from utils.transforms import transform_parsing, transform_logits
from mhp_extension.global_local_parsing.global_local_datasets import CropDataValSet


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
        The parsed arguments as an argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')
    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--split-name", type=str, default='crop_pic')
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")
    # Evaluation Preference
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--save-results", action="store_true", help="whether to save the results.")
    parser.add_argument("--flip", action="store_true",
                        help="also evaluate the horizontally flipped input and average the results.")
    parser.add_argument("--multi-scales", type=str, default='1', help="multiple scales during the test")
    return parser.parse_args()


def get_palette(num_cls):
    """Returns the color map for visualizing the segmentation mask.

    Args:
        num_cls: Number of classes

    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette
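
# How the palette is built: the bits of each class index are spread across the
# R/G/B channels starting from the most significant bit, which reproduces the
# familiar PASCAL-VOC-style colormap. For example, get_palette(4)[3:] is
# [128, 0, 0, 0, 128, 0, 128, 128, 0]: class 1 is dark red, class 2 dark green,
# class 3 dark yellow, while class 0 stays black.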


def multi_scale_testing(model, batch_input_im, crop_size=[473, 473], flip=True, multi_scales=[1]):
    # Channel order that swaps the left/right paired part labels (14-19) of the flipped prediction
    flipped_idx = (15, 14, 17, 16, 19, 18)
    if len(batch_input_im.shape) > 4:
        batch_input_im = batch_input_im.squeeze()
    if len(batch_input_im.shape) == 3:
        batch_input_im = batch_input_im.unsqueeze(0)

    interp = torch.nn.Upsample(size=crop_size, mode='bilinear', align_corners=True)
    ms_outputs = []
    for s in multi_scales:
        interp_im = torch.nn.Upsample(scale_factor=s, mode='bilinear', align_corners=True)
        scaled_im = interp_im(batch_input_im)
        parsing_output = model(scaled_im)
        parsing_output = parsing_output[0][-1]
        output = parsing_output[0]
        if flip:
            # Average the prediction of the original crop with the un-flipped prediction
            # of the mirrored crop (the second element of the batch built by CropDataValSet)
            flipped_output = parsing_output[1]
            flipped_output[14:20, :, :] = flipped_output[flipped_idx, :, :]
            output += flipped_output.flip(dims=[-1])
            output *= 0.5
        output = interp(output.unsqueeze(0))
        ms_outputs.append(output[0])
    ms_fused_parsing_output = torch.stack(ms_outputs)
    ms_fused_parsing_output = ms_fused_parsing_output.mean(0)
    ms_fused_parsing_output = ms_fused_parsing_output.permute(1, 2, 0)  # CHW -> HWC
    parsing = torch.argmax(ms_fused_parsing_output, dim=2)
    parsing = parsing.data.cpu().numpy()
    ms_fused_parsing_output = ms_fused_parsing_output.data.cpu().numpy()
    return parsing, ms_fused_parsing_output
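
# Shape summary for multi_scale_testing (assuming a 473x473 crop and 20 classes):
# the input batch is [1, 3, 473, 473] (or [2, 3, 473, 473] when CropDataValSet was
# built with flip=True); the returned `parsing` is a [473, 473] label map and
# `ms_fused_parsing_output` is the [473, 473, 20] scale-averaged score map that
# main() later hands to transform_logits.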


def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()
    multi_scales = [float(i) for i in args.multi_scales.split(',')]
    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    cudnn.enabled = True

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=None)

    IMAGE_MEAN = model.mean
    IMAGE_STD = model.std
    INPUT_SPACE = model.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space: {}'.format(INPUT_SPACE))
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])
    if INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])

    # Data loader
    lip_test_dataset = CropDataValSet(args.data_dir, args.split_name, crop_size=input_size, transform=transform,
                                      flip=args.flip)
    num_samples = len(lip_test_dataset)
    print('Total testing sample numbers: {}'.format(num_samples))
    testloader = data.DataLoader(lip_test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    # Load model weight
    state_dict = torch.load(args.model_restore)
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove the `module.` prefix saved by DataParallel
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    sp_results_dir = os.path.join(args.log_dir, args.split_name + '_parsing')
    if not os.path.exists(sp_results_dir):
        os.makedirs(sp_results_dir)

    palette = get_palette(20)
    parsing_preds = []
    scales = np.zeros((num_samples, 2), dtype=np.float32)
    centers = np.zeros((num_samples, 2), dtype=np.int32)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(testloader)):
            image, meta = batch
            if len(image.shape) > 4:
                image = image.squeeze()
            im_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]
            scales[idx, :] = s
            centers[idx, :] = c
            parsing, logits = multi_scale_testing(model, image.cuda(), crop_size=input_size, flip=args.flip,
                                                  multi_scales=multi_scales)
            if args.save_results:
                # Map the parsing result back to the original image size and save it as an indexed PNG
                parsing_result = transform_parsing(parsing, c, s, w, h, input_size)
                parsing_result_path = os.path.join(sp_results_dir, im_name + '.png')
                output_im = PILImage.fromarray(np.asarray(parsing_result, dtype=np.uint8))
                output_im.putpalette(palette)
                output_im.save(parsing_result_path)
                # Save the per-pixel logits for later fusion
                logits_result = transform_logits(logits, c, s, w, h, input_size)
                logits_result_path = os.path.join(sp_results_dir, im_name + '.npy')
                np.save(logits_result_path, logits_result)
    return


if __name__ == '__main__':
    main()
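
A possible invocation of the evaluation script above, using only the flags defined in get_arguments (assuming it is saved as evaluate.py; data and checkpoint paths are placeholders):

python evaluate.py --data-dir ./demo/cropped_img --split-name crop_pic \
    --model-restore ./log/checkpoint.pth.tar --log-dir ./log \
    --input-size 473,473 --num-classes 20 --gpu 0 --save-results --flip --multi-scales 1,0.75,1.25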
@@ -0,0 +1,232 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  : Peike Li
@Contact : peike.li@yahoo.com
@File    : train.py
@Time    : 8/4/19 3:36 PM
@Desc    :
@License : This source code is licensed under the license found in the
           LICENSE file in the root directory of this source tree.
"""

import os
import json
import timeit
import argparse

import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils import data

import networks
import utils.schp as schp
from datasets.datasets import LIPDataSet
from datasets.target_generation import generate_edge_tensor
from utils.transforms import BGR2RGB_transform
from utils.criterion import CriterionAll
from utils.encoding import DataParallelModel, DataParallelCriterion
from utils.warmup_scheduler import SGDRScheduler


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
        The parsed arguments as an argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')
    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--split-name", type=str, default='crop_pic')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")
    # Training Strategy
    parser.add_argument("--learning-rate", type=float, default=7e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--gpu", type=str, default='0,1,2')
    parser.add_argument("--start-epoch", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--eval-epochs", type=int, default=10)
    parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
    parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
    parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
    parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
    parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
    parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')
    return parser.parse_args()


def main():
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    gpus = [int(i) for i in args.gpu.split(',')]
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    model = DataParallelModel(AugmentCE2P)
    model.cuda()

    IMAGE_MEAN = AugmentCE2P.mean
    IMAGE_STD = AugmentCE2P.std
    INPUT_SPACE = AugmentCE2P.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space: {}'.format(INPUT_SPACE))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    SCHP_AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(SCHP_AugmentCE2P)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model_state_dict = schp_checkpoint['state_dict']
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_model_state_dict)

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])

    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])

    train_dataset = LIPDataSet(args.data_dir, args.split_name, crop_size=input_size, transform=transform)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * len(gpus),
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds = schp_model(images)
                    soft_parsing = []
                    soft_edge = []
                    for soft_pred in soft_preds:
                        soft_parsing.append(soft_pred[0][-1])
                        soft_edge.append(soft_pred[1][-1])
                    soft_preds = torch.cat(soft_parsing, dim=0)
                    soft_edges = torch.cat(soft_edge, dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(i_iter, total_iters, lr,
                                                                             loss.data.cpu().numpy()))
        if (epoch + 1) % args.eval_epochs == 0:
            schp.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            # Aggregate the current model into the running-average (schp) model,
            # then re-estimate its BatchNorm statistics on the training data
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))


if __name__ == '__main__':
    main()
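
A possible invocation of the training script above with the flags defined in get_arguments (assuming it is saved as train.py; dataset and pretrained-weight paths are placeholders):

python train.py --data-dir ./data/CIHP --split-name crop_pic --batch-size 16 \
    --epochs 150 --schp-start 100 --cycle-epochs 10 --gpu 0,1,2 \
    --imagenet-pretrain ./pretrain_model/resnet101-imagenet.pth --log-dir ./log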
@@ -0,0 +1,13 @@
import os

DATASET = 'VIP'  # DATASET: MHPv2 or CIHP or VIP
TYPE = 'crop_pic'  # crop_pic or DemoDataset
IMG_DIR = '../demo/cropped_img/crop_pic'
SAVE_DIR = '../demo/cropped_img'

if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

# Write one image id (file name without extension) per line; this is the
# <split_name>.txt list that CropDataSet / CropDataValSet read.
with open(os.path.join(SAVE_DIR, TYPE + '.txt'), "w") as f:
    for img_name in os.listdir(IMG_DIR):
        f.write(os.path.splitext(img_name)[0] + '\n')
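
For example (hypothetical file names), if IMG_DIR contains person_0.jpg and person_1.jpg, the script writes ../demo/cropped_img/crop_pic.txt containing:

person_0
person_1

which is exactly the id list that the crop_pic split of CropDataSet / CropDataValSet expects.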