Add at new repo again
This commit is contained in:
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
from detectron2.config import get_cfg
|
||||
from detectron2.engine import default_setup
|
||||
from detectron2.modeling import build_model
|
||||
|
||||
from densepose import add_dataset_category_config, add_densepose_config
|
||||
|
||||
_BASE_CONFIG_DIR = "configs"
|
||||
_EVOLUTION_CONFIG_SUB_DIR = "evolution"
|
||||
_QUICK_SCHEDULES_CONFIG_SUB_DIR = "quick_schedules"
|
||||
_BASE_CONFIG_FILE_PREFIX = "Base-"
|
||||
_CONFIG_FILE_EXT = ".yaml"
|
||||
|
||||
|
||||
def _get_base_config_dir():
|
||||
"""
|
||||
Return the base directory for configurations
|
||||
"""
|
||||
return os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", _BASE_CONFIG_DIR)
|
||||
|
||||
|
||||
def _get_evolution_config_dir():
|
||||
"""
|
||||
Return the base directory for evolution configurations
|
||||
"""
|
||||
return os.path.join(_get_base_config_dir(), _EVOLUTION_CONFIG_SUB_DIR)
|
||||
|
||||
|
||||
def _get_quick_schedules_config_dir():
|
||||
"""
|
||||
Return the base directory for quick schedules configurations
|
||||
"""
|
||||
return os.path.join(_get_base_config_dir(), _QUICK_SCHEDULES_CONFIG_SUB_DIR)
|
||||
|
||||
|
||||
def _collect_config_files(config_dir):
|
||||
"""
|
||||
Collect all configuration files (i.e. densepose_*.yaml) directly in the specified directory
|
||||
"""
|
||||
start = _get_base_config_dir()
|
||||
results = []
|
||||
for entry in os.listdir(config_dir):
|
||||
path = os.path.join(config_dir, entry)
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
_, ext = os.path.splitext(entry)
|
||||
if ext != _CONFIG_FILE_EXT:
|
||||
continue
|
||||
if entry.startswith(_BASE_CONFIG_FILE_PREFIX):
|
||||
continue
|
||||
config_file = os.path.relpath(path, start)
|
||||
results.append(config_file)
|
||||
return results
|
||||
|
||||
|
||||
def get_config_files():
|
||||
"""
|
||||
Get all the configuration files (relative to the base configuration directory)
|
||||
"""
|
||||
return _collect_config_files(_get_base_config_dir())
|
||||
|
||||
|
||||
def get_evolution_config_files():
|
||||
"""
|
||||
Get all the evolution configuration files (relative to the base configuration directory)
|
||||
"""
|
||||
return _collect_config_files(_get_evolution_config_dir())
|
||||
|
||||
|
||||
def get_quick_schedules_config_files():
|
||||
"""
|
||||
Get all the quick schedules configuration files (relative to the base configuration directory)
|
||||
"""
|
||||
return _collect_config_files(_get_quick_schedules_config_dir())
|
||||
|
||||
|
||||
def _get_model_config(config_file):
|
||||
"""
|
||||
Load and return the configuration from the specified file (relative to the base configuration
|
||||
directory)
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_dataset_category_config(cfg)
|
||||
add_densepose_config(cfg)
|
||||
path = os.path.join(_get_base_config_dir(), config_file)
|
||||
cfg.merge_from_file(path)
|
||||
if not torch.cuda.is_available():
|
||||
cfg.MODEL_DEVICE = "cpu"
|
||||
return cfg
|
||||
|
||||
|
||||
def get_model(config_file):
|
||||
"""
|
||||
Get the model from the specified file (relative to the base configuration directory)
|
||||
"""
|
||||
cfg = _get_model_config(config_file)
|
||||
return build_model(cfg)
|
||||
|
||||
|
||||
def setup(config_file):
|
||||
"""
|
||||
Setup the configuration from the specified file (relative to the base configuration directory)
|
||||
"""
|
||||
cfg = _get_model_config(config_file)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, {})
|
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from detectron2.structures import BitMasks, Boxes, Instances
|
||||
|
||||
from .common import get_model
|
||||
|
||||
|
||||
# TODO(plabatut): Modularize detectron2 tests and re-use
|
||||
def make_model_inputs(image, instances=None):
|
||||
if instances is None:
|
||||
return {"image": image}
|
||||
|
||||
return {"image": image, "instances": instances}
|
||||
|
||||
|
||||
def make_empty_instances(h, w):
|
||||
instances = Instances((h, w))
|
||||
instances.gt_boxes = Boxes(torch.rand(0, 4))
|
||||
instances.gt_classes = torch.tensor([]).to(dtype=torch.int64)
|
||||
instances.gt_masks = BitMasks(torch.rand(0, h, w))
|
||||
return instances
|
||||
|
||||
|
||||
class ModelE2ETest(unittest.TestCase):
|
||||
CONFIG_PATH = ""
|
||||
|
||||
def setUp(self):
|
||||
self.model = get_model(self.CONFIG_PATH)
|
||||
|
||||
def _test_eval(self, sizes):
|
||||
inputs = [make_model_inputs(torch.rand(3, size[0], size[1])) for size in sizes]
|
||||
self.model.eval()
|
||||
self.model(inputs)
|
||||
|
||||
|
||||
class DensePoseRCNNE2ETest(ModelE2ETest):
|
||||
CONFIG_PATH = "densepose_rcnn_R_101_FPN_s1x.yaml"
|
||||
|
||||
def test_empty_data(self):
|
||||
self._test_eval([(200, 250), (200, 249)])
|
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
from .common import (
|
||||
get_config_files,
|
||||
get_evolution_config_files,
|
||||
get_quick_schedules_config_files,
|
||||
setup,
|
||||
)
|
||||
|
||||
|
||||
class TestSetup(unittest.TestCase):
|
||||
def _test_setup(self, config_file):
|
||||
setup(config_file)
|
||||
|
||||
def test_setup_configs(self):
|
||||
config_files = get_config_files()
|
||||
for config_file in config_files:
|
||||
self._test_setup(config_file)
|
||||
|
||||
def test_setup_evolution_configs(self):
|
||||
config_files = get_evolution_config_files()
|
||||
for config_file in config_files:
|
||||
self._test_setup(config_file)
|
||||
|
||||
def test_setup_quick_schedules_configs(self):
|
||||
config_files = get_quick_schedules_config_files()
|
||||
for config_file in config_files:
|
||||
self._test_setup(config_file)
|
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
from densepose.data.structures import normalized_coords_transform
|
||||
|
||||
|
||||
class TestStructures(unittest.TestCase):
|
||||
def test_normalized_coords_transform(self):
|
||||
bbox = (32, 24, 288, 216)
|
||||
x0, y0, w, h = bbox
|
||||
xmin, ymin, xmax, ymax = x0, y0, x0 + w, y0 + h
|
||||
f = normalized_coords_transform(*bbox)
|
||||
# Top-left
|
||||
expected_p, actual_p = (-1, -1), f((xmin, ymin))
|
||||
self.assertEqual(expected_p, actual_p)
|
||||
# Top-right
|
||||
expected_p, actual_p = (1, -1), f((xmax, ymin))
|
||||
self.assertEqual(expected_p, actual_p)
|
||||
# Bottom-left
|
||||
expected_p, actual_p = (-1, 1), f((xmin, ymax))
|
||||
self.assertEqual(expected_p, actual_p)
|
||||
# Bottom-right
|
||||
expected_p, actual_p = (1, 1), f((xmax, ymax))
|
||||
self.assertEqual(expected_p, actual_p)
|
Reference in New Issue
Block a user