nsfw-api/nsfw_detector/predict.py
2024-07-19 23:10:58 +03:00

120 lines
4.2 KiB
Python

import argparse
import json
from os import listdir
from os.path import isfile, join, exists, isdir, abspath
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
# Default edge length that images are resized to (square: IMAGE_DIM x IMAGE_DIM).
# Used as the default for `classify(image_dim=...)` and the `--image_dim` CLI option.
IMAGE_DIM = 224
def load_images(image_paths, image_size, verbose=True):
    """
    Load images into a numpy array suitable for passing to model.predict.

    inputs:
        image_paths: a directory of images, a single image path, or an
            iterable (list, tuple, ...) of image paths
        image_size: size passed to keras `load_img` as `target_size`
        verbose: print each image path and the target size as it is loaded
    outputs:
        loaded_images: numpy array of the decoded images, each divided by 255
        loaded_image_paths: paths of the images that were loaded successfully

    Images that fail to load are reported on stdout and skipped, so the
    returned collections may be shorter than the input.
    """
    # Local import: the module only pulls `listdir` from `os`.
    from os import PathLike

    # Only a single path-like value is probed on the filesystem. Lists and
    # other iterables are used as-is; calling isdir() on them raises TypeError.
    if isinstance(image_paths, (str, bytes, PathLike)):
        if isdir(image_paths):
            parent = abspath(image_paths)
            candidates = (join(parent, name) for name in listdir(image_paths))
            image_paths = [path for path in candidates if isfile(path)]
        else:
            # A file, or a path that does not exist: treat it as one image so
            # a bad path yields one load-failure message instead of being
            # iterated character by character.
            image_paths = [image_paths]

    loaded_images = []
    loaded_image_paths = []
    for img_path in image_paths:
        try:
            if verbose:
                print(img_path, "size:", image_size)
            image = keras.preprocessing.image.load_img(
                img_path, target_size=image_size)
            image = keras.preprocessing.image.img_to_array(image)
            image /= 255
        except Exception as ex:
            # Deliberately best-effort: one unreadable file must not abort
            # the whole batch.
            print("Image Load Failure: ", img_path, ex)
        else:
            loaded_images.append(image)
            loaded_image_paths.append(img_path)

    return np.asarray(loaded_images), loaded_image_paths
def load_model(model_path):
    """
    Load a saved Keras model from `model_path`.

    `hub.KerasLayer` is registered under the name 'KerasLayer' via
    `custom_objects` when the model is loaded.

    Raises:
        ValueError: if `model_path` is None or does not exist on disk.
    """
    path_is_usable = model_path is not None and exists(model_path)
    if not path_is_usable:
        raise ValueError(
            "saved_model_path must be the valid directory of a saved model to load.")

    custom_layers = {'KerasLayer': hub.KerasLayer}
    return tf.keras.models.load_model(model_path, custom_objects=custom_layers)
def classify(model, input_paths, image_dim=IMAGE_DIM):
    """
    Classify the image(s) at `input_paths` with `model`.

    `input_paths` is forwarded to `load_images` together with a square
    target size of (image_dim, image_dim); the loaded array is then passed
    to `classify_nd`.

    Returns a dict that maps the key 'data' to the first prediction, or an
    empty dict when there are no predictions.
    """
    target_size = (image_dim, image_dim)
    images, _loaded_paths = load_images(input_paths, target_size)
    predictions = classify_nd(model, images)

    # NOTE: zip() stops at the shorter input, so with the single key 'data'
    # only the first prediction is kept, even when several images were loaded.
    result = {}
    for key, prediction in zip(['data'], predictions):
        result[key] = prediction
    return result
def classify_nd(model, nd_images):
    """
    Run `model.predict` on `nd_images` and label each prediction row.

    Each row of the prediction output becomes a dict that maps a category
    name to `round(score, 6) * 100`. Scores are matched to categories by
    position, in the order: drawings, hentai, neutral, porn, sexy.

    Returns a list with one dict per prediction row.

    Raises:
        IndexError: if a prediction row has more entries than there are
            category names.
    """
    labels = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
    raw_predictions = model.predict(nd_images)
    return [
        {labels[idx]: round(float(score), 6) * 100
         for idx, score in enumerate(row)}
        for row in raw_predictions
    ]
def main(args=None):
    """
    Command-line entry point: classify an image or a directory of images.

    Parses the CLI options, checks that the image source exists, loads the
    model with `load_model`, runs `classify`, and prints the result as
    indented JSON.

    Args:
        args: optional list of argument strings; when None, argparse reads
            the process command line.

    Raises:
        ValueError: if `--image_source` does not exist on disk.
    """
    parser = argparse.ArgumentParser(
        description="""A script to perform NFSW classification of images""",
        epilog=(
            "\n"
            "Launch with default model and a test image\n"
            "python nsfw_detector/predict.py --saved_model_path mobilenet_v2_140_224 --image_source test.jpg\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    group = parser.add_argument_group(
        'main execution and evaluation functionality')
    group.add_argument(
        '--image_source', dest='image_source', type=str, required=True,
        help='A directory of images or a single image to classify')
    group.add_argument(
        '--saved_model_path', dest='saved_model_path', type=str, required=True,
        help='The model to load')
    group.add_argument(
        '--image_dim', dest='image_dim', type=int, default=IMAGE_DIM,
        help="The square dimension of the model's input shape")

    # parse_args(None) falls back to sys.argv[1:], so a single call covers
    # both the programmatic and the command-line case.
    config = vars(parser.parse_args(args))

    source = config['image_source']
    source_exists = source is not None and exists(source)
    if not source_exists:
        raise ValueError(
            "image_source must be a valid directory with images or a single image to classify.")

    model = load_model(config['saved_model_path'])
    image_preds = classify(model, source, config['image_dim'])
    print(json.dumps(image_preds, indent=2), '\n')


if __name__ == "__main__":
    main()