Add at new repo again
demo/README.md
@@ -0,0 +1,8 @@
## Detectron2 Demo

We provide a command line tool to run a simple demo of builtin models.
The usage is explained in [GETTING_STARTED.md](../GETTING_STARTED.md).

See our [blog post](https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-)
for a high-quality demo generated with this tool.
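For reference, a typical invocation looks like this (a sketch: the config path, image
names, and weights below are placeholders; see GETTING_STARTED.md for exact commands):

```sh
cd demo/
python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
  --input input1.jpg input2.jpg \
  --opts MODEL.WEIGHTS /path/to/model_checkpoint.pkl
```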

demo/demo.py
@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import glob
import multiprocessing as mp
import os
import time

import cv2
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

from predictor import VisualizationDemo

# constants
WINDOW_NAME = "COCO detections"

def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.freeze()
    return cfg

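
# For reference (a sketch, not part of the original file): the `--opts` overrides
# are plain 'KEY VALUE' pairs that yacs merges into the config, e.g.
#
#   cfg = get_cfg()
#   cfg.merge_from_list(["MODEL.WEIGHTS", "/path/to/model.pkl"])  # placeholder path
#   # cfg.MODEL.WEIGHTS is now "/path/to/model.pkl"
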
def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin models")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space-separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) were not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )
            if args.output:
                if os.path.isdir(args.output):
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)

        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + ".mkv"
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installations of OpenCV may not support x264 (due to its license);
                # you can try another format (e.g. MPEG) instead
                fourcc=cv2.VideoWriter_fourcc(*"x264"),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
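
# Example invocations of the three input modes above (a sketch; the config and
# file paths are placeholders):
#   python demo.py --config-file cfg.yaml --input 'images/*.jpg' --output out_dir/
#   python demo.py --config-file cfg.yaml --video-input video.mp4 --output out_dir/
#   python demo.py --config-file cfg.yaml --webcam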

demo/predictor.py
@@ -0,0 +1,220 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import atexit
import bisect
import multiprocessing as mp
from collections import deque

import cv2
import torch

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer

class VisualizationDemo(object):
    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode):
            instance_mode (ColorMode):
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output
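
    # Usage sketch (hypothetical, not part of the original file); `demo` is a
    # VisualizationDemo and read_image comes from detectron2.data.detection_utils:
    #
    #   img = read_image("input.jpg", format="BGR")
    #   predictions, vis = demo.run_on_image(img)
    #   cv2.imwrite("out.jpg", vis.get_image()[:, :, ::-1])  # RGB -> BGR for OpenCV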

    def _frame_from_video(self, video):
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.

        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.

        Yields:
            ndarray: BGR visualizations of each video frame.
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            # OpenCV frames are BGR; the visualizer expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                predictions = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
            elif "sem_seg" in predictions:
                vis_frame = video_visualizer.draw_sem_seg(
                    frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )

            # Converts Matplotlib RGB format back to OpenCV BGR format
            vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
            return vis_frame

        frame_gen = self._frame_from_video(video)
        if self.parallel:
            buffer_size = self.predictor.default_buffer_size

            frame_data = deque()

            for cnt, frame in enumerate(frame_gen):
                frame_data.append(frame)
                self.predictor.put(frame)

                if cnt >= buffer_size:
                    frame = frame_data.popleft()
                    predictions = self.predictor.get()
                    yield process_predictions(frame, predictions)

            while len(frame_data):
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)
        else:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))
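
# Usage sketch (hypothetical, not part of the original file): run the model in
# worker processes so inference and visualization overlap.
#
#   demo = VisualizationDemo(cfg, parallel=True)  # one worker per visible GPU
#   for vis in demo.run_on_video(cv2.VideoCapture("video.mp4")):
#       ...  # vis is a BGR ndarray, ready for cv2.imshow or cv2.VideoWriter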

class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on more than one GPU.
    Because rendering the visualization takes a considerable amount of time,
    this helps improve throughput when rendering videos.
    """

    class _StopToken:
        pass

    class _PredictWorker(mp.Process):
        def __init__(self, cfg, task_queue, result_queue):
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue
            super().__init__()

        def run(self):
            predictor = DefaultPredictor(self.cfg)

            while True:
                task = self.task_queue.get()
                if isinstance(task, AsyncPredictor._StopToken):
                    break
                idx, data = task
                result = predictor(data)
                self.result_queue.put((idx, result))

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        self.task_queue = mp.Queue(maxsize=num_workers * 3)
        self.result_queue = mp.Queue(maxsize=num_workers * 3)
        self.procs = []
        for gpuid in range(num_workers):
            cfg = cfg.clone()
            cfg.defrost()
            cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
            self.procs.append(
                AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
            )

        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for p in self.procs:
            p.start()
        atexit.register(self.shutdown)

    def put(self, image):
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        self.get_idx += 1  # the index needed for this request
        if len(self.result_rank) and self.result_rank[0] == self.get_idx:
            res = self.result_data[0]
            del self.result_data[0], self.result_rank[0]
            return res

        while True:
            # Results may arrive out of order; buffer early ones (kept sorted by
            # index via bisect) until the one matching get_idx shows up.
            idx, res = self.result_queue.get()
            if idx == self.get_idx:
                return res
            insert = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(insert, idx)
            self.result_data.insert(insert, res)

    def __len__(self):
        return self.put_idx - self.get_idx

    def __call__(self, image):
        self.put(image)
        return self.get()

    def shutdown(self):
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        return len(self.procs) * 5
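
# Ordered async inference, as a sketch (hypothetical, not part of the original file):
#
#   predictor = AsyncPredictor(cfg, num_gpus=2)   # 2 workers pulling from one task queue
#   for img in images:
#       predictor.put(img)
#   results = [predictor.get() for _ in images]   # returned in submission order
#
# put()/get() preserve submission order even when workers finish out of order,
# because get() holds back any result whose index is ahead of the next expected one.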