#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
A script to benchmark builtin models.

Note: this script has an extra dependency on psutil.
"""

import itertools
import logging
import psutil
import torch
import tqdm
from fvcore.common.timer import Timer
from torch.nn.parallel import DistributedDataParallel

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    DatasetFromList,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer
from detectron2.utils import comm
from detectron2.utils.events import CommonMetricPrinter
from detectron2.utils.logger import setup_logger

logger = logging.getLogger("detectron2")


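# Build a frozen config from the config file and command-line overrides.
# The base LR is lowered up front so that benchmark training does not hit NaNs.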
def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.SOLVER.BASE_LR = 0.001  # Avoid NaNs. Not useful in this script anyway.
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    setup_logger(distributed_rank=comm.get_rank())
    return cfg


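# Benchmark the data pipeline alone: build the training loader and time how
# fast batches can be pulled from it, with no model in the loop.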
def benchmark_data(args):
    cfg = setup(args)

    timer = Timer()
    dataloader = build_detection_train_loader(cfg)
    logger.info("Initialized loader in {} seconds.".format(timer.seconds()))

    timer.reset()
    itr = iter(dataloader)
    for i in range(10):  # warmup
        next(itr)
        if i == 0:
            startup_time = timer.seconds()
    timer = Timer()
    max_iter = 1000
    for _ in tqdm.trange(max_iter):
        next(itr)
    logger.info(
        "{} iters ({} images) in {} seconds.".format(
            max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
        )
    )
    logger.info("Startup time: {} seconds".format(startup_time))
    vram = psutil.virtual_memory()
    logger.info(
        "RAM Usage: {:.2f}/{:.2f} GB".format(
            (vram.total - vram.available) / 1024 ** 3, vram.total / 1024 ** 3
        )
    )

    # test for a few more rounds
    for _ in range(10):
        timer = Timer()
        max_iter = 1000
        for _ in tqdm.trange(max_iter):
            next(itr)
        logger.info(
            "{} iters ({} images) in {} seconds.".format(
                max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
            )
        )


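# Benchmark full training iterations. 100 batches are materialized once and
# then replayed from memory, so the numbers mostly exclude data-loading cost.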
def benchmark_train(args):
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_train_loader(cfg)
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Replay the cached batches forever.
        data = DatasetFromList(dummy_data, copy=False)
        while True:
            yield from data

    max_iter = 400
    trainer = SimpleTrainer(model, f(), optimizer)
    trainer.register_hooks(
        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
    )
    trainer.train(1, max_iter)


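# Benchmark inference: replay 100 pre-loaded test inputs through the model
# (under torch.no_grad) after a brief warmup.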
@torch.no_grad()
def benchmark_eval(args):
    cfg = setup(args)
    model = build_model(cfg)
    model.eval()
    logger.info("Model:\n{}".format(model))
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Replay the cached inputs forever.
        while True:
            yield from DatasetFromList(dummy_data, copy=False)

    for _ in range(5):  # warmup
        model(dummy_data[0])

    max_iter = 400
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as pbar:
        for idx, d in enumerate(f()):
            if idx == max_iter:
                break
            model(d)
            pbar.update()
    logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))


if __name__ == "__main__":
|
|
parser = default_argument_parser()
|
|
parser.add_argument("--task", choices=["train", "eval", "data"], required=True)
|
|
args = parser.parse_args()
|
|
assert not args.eval_only
|
|
|
|
if args.task == "data":
|
|
f = benchmark_data
|
|
elif args.task == "train":
|
|
"""
|
|
Note: training speed may not be representative.
|
|
The training cost of a R-CNN model varies with the content of the data
|
|
and the quality of the model.
|
|
"""
|
|
f = benchmark_train
|
|
elif args.task == "eval":
|
|
f = benchmark_eval
|
|
# only benchmark single-GPU inference.
|
|
assert args.num_gpus == 1 and args.num_machines == 1
|
|
launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,))
|