81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
"""
|
|
@Author : Peike Li
|
|
@Contact : peike.li@yahoo.com
|
|
@File : schp.py
|
|
@Time : 4/8/19 2:11 PM
|
|
@Desc :
|
|
@License : This source code is licensed under the license found in the
|
|
LICENSE file in the root directory of this source tree.
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
import modules
|
|
|
|
def moving_average(net1, net2, alpha=1):
    """Blend the parameters of ``net2`` into ``net1`` in place.

    Parameters are paired positionally via ``zip`` over
    ``net1.parameters()`` and ``net2.parameters()``; each parameter of
    ``net1`` becomes ``(1 - alpha) * net1_param + alpha * net2_param``.
    ``net2`` is only read. With the default ``alpha=1`` the values of
    ``net2`` simply overwrite those of ``net1``.

    Only what ``parameters()`` yields is touched; pairing stops at the
    shorter of the two parameter sequences.
    """
    for target, source in zip(net1.parameters(), net2.parameters()):
        # Scale the existing value first, then add the weighted source,
        # operating on ``.data`` so autograd does not record the update.
        target.data.mul_(1.0 - alpha)
        target.data.add_(source.data * alpha)
|
|
|
|
|
|
def _check_bn(module, flag):
    """Record in ``flag`` whether ``module`` is an ``InPlaceABNSync`` layer.

    ``flag`` is a one-element mutable list used as an out-parameter:
    ``flag[0]`` is set to ``True`` on a match and left untouched otherwise,
    so a single list can be shared across many calls.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    flag[0] = True
|
|
|
|
|
|
def check_bn(model):
    """Return ``True`` if ``model`` contains at least one ``InPlaceABNSync``.

    Every module visited by ``model.apply`` is passed to ``_check_bn``,
    which flips a shared one-element list when it sees a matching layer.
    """
    found = [False]

    def _visit(submodule):
        _check_bn(submodule, found)

    model.apply(_visit)
    return found[0]
|
|
|
|
|
|
def reset_bn(module):
    """Reset the running statistics of an ``InPlaceABNSync`` layer.

    For a matching module, ``running_mean`` is replaced by a fresh tensor
    of zeros and ``running_var`` by a fresh tensor of ones, each created
    with ``*_like`` from the tensor it replaces. Any other module is
    ignored. Intended to be used with ``model.apply``.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    old_mean, old_var = module.running_mean, module.running_var
    module.running_mean = torch.zeros_like(old_mean)
    module.running_var = torch.ones_like(old_var)
|
|
|
|
|
|
def _get_momenta(module, momenta):
    """Store the momentum of an ``InPlaceABNSync`` layer in ``momenta``.

    ``momenta`` is a dict keyed by the module object itself; modules of
    any other type are skipped.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    momenta[module] = module.momentum
|
|
|
|
|
|
def _set_momenta(module, momenta):
    """Restore the momentum of an ``InPlaceABNSync`` layer from ``momenta``.

    ``momenta`` is a dict keyed by module object (as filled by
    ``_get_momenta``). A matching module that is missing from the dict
    raises ``KeyError``; modules of any other type are skipped.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    module.momentum = momenta[module]
|
|
|
|
|
|
def bn_re_estimate(loader, model):
    """Re-estimate the running statistics of the model's batch-norm layers.

    If ``check_bn`` finds no ``InPlaceABNSync`` layer, a message is printed
    and the function returns without touching the model. Otherwise the
    model is switched to training mode (and left there), the running
    statistics of every such layer are reset via ``reset_bn``, and one
    forward pass is run for every batch of ``loader``.

    Before each forward pass the momentum of every tracked layer is set to
    ``b / (n + b)``, where ``b`` is the size of the current batch and ``n``
    the number of samples already seen. Assuming the layer updates its
    statistics as ``(1 - momentum) * running + momentum * batch`` (TODO
    confirm against ``InPlaceABNSync``), this yields a sample-weighted
    cumulative average over the whole loader.

    Args:
        loader: iterable yielding 3-element batches whose first element is
            the image tensor fed to the model; the other two are ignored.
        model: module exposing ``apply``/``train`` and callable on a batch.

    Returns:
        None.
    """
    if not check_bn(model):
        print('No batch norm layer detected')
        return

    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))

    n = 0
    try:
        # Only the running statistics are of interest here, so skip
        # building autograd graphs for the forward passes.
        with torch.no_grad():
            for images, _, _ in loader:
                b = images.size(0)
                # float() keeps this a true division on any interpreter.
                momentum = b / float(n + b)
                for module in momenta:
                    module.momentum = momentum
                model(images)
                n += b
    finally:
        # Always put the original momenta back, even if the loader or a
        # forward pass raised, so the model is not left with a stale value.
        model.apply(lambda module: _set_momenta(module, momenta))
|
|
|
|
|
|
def _atomic_torch_save(states, path):
    """Serialize ``states`` to ``path`` without exposing a partial file."""
    tmp_path = path + '.tmp'
    try:
        torch.save(states, tmp_path)
        # os.replace overwrites an existing file in a single step, so the
        # previous checkpoint stays intact until the new one is complete.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only cleans up when saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_schp_checkpoint(states, is_best_parsing, output_dir, filename='schp_checkpoint.pth.tar'):
    """Save a checkpoint and, optionally, a copy marked as the best model.

    ``states`` is always written to ``output_dir/filename``. When
    ``is_best_parsing`` is truthy and ``states`` contains the key
    ``'state_dict'``, the same object is additionally written to
    ``output_dir/model_parsing_best.pth.tar``.

    Each file is written to a temporary sibling first and then moved into
    place, so an interrupted save never destroys the previous checkpoint.
    ``output_dir`` is created if it does not exist yet.

    Args:
        states: object to serialize with ``torch.save``.
        is_best_parsing: whether to also store the "best" copy.
        output_dir: directory receiving the checkpoint files.
        filename: file name of the regular checkpoint.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    _atomic_torch_save(states, os.path.join(output_dir, filename))

    if is_best_parsing and 'state_dict' in states:
        _atomic_torch_save(states, os.path.join(output_dir, 'model_parsing_best.pth.tar'))
|