Skip to content

Commit

Permalink
load image either from json or form data
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Nov 10, 2024
1 parent 8b2475a commit 60067e2
Showing 1 changed file with 97 additions and 31 deletions.
128 changes: 97 additions & 31 deletions deepface/api/src/modules/core/routes.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,86 @@
# built-in dependencies
from typing import Union

# 3rd party dependencies
from flask import Blueprint, request
import numpy as np

# project dependencies
from deepface import DeepFace
from deepface.api.src.modules.core import service
from deepface.commons import image_utils
from deepface.commons.logger import Logger

logger = Logger()

blueprint = Blueprint("routes", __name__)

# pylint: disable=no-else-return, broad-except


@blueprint.route("/")
def home():
return f"<h1>Welcome to DeepFace API v{DeepFace.__version__}!</h1>"


def extract_image_from_request(img_key: str) -> Union[str, np.ndarray]:
"""
Extracts an image from the request either from json or a multipart/form-data file.
Args:
img_key (str): The key used to retrieve the image data
from the request (e.g., 'img1').
Returns:
img (str or np.ndarray): Given image detail (base64 encoded string, image path or url)
or the decoded image as a numpy array.
"""

# Check if the request is multipart/form-data (file input)
if request.files:
# request.files is instance of werkzeug.datastructures.ImmutableMultiDict
# file is instance of werkzeug.datastructures.FileStorage
file = request.files.get(img_key)

if file is None:
raise ValueError(f"Request form data doesn't have {img_key}")

if file.filename == "":
raise ValueError(f"No file uploaded for '{img_key}'")

img = image_utils.load_image_from_file_storage(file)

return img
# Check if the request is coming as base64, file path or url from json or form data
elif request.is_json or request.form:
input_args = request.get_json() or request.form.to_dict()

if input_args is None:
raise ValueError("empty input set passed")

# this can be base64 encoded image, and image path or url
img = input_args.get(img_key)

if not img:
raise ValueError(f"'{img_key}' not found in either json or form data request")

return img

# If neither JSON nor file input is present
raise ValueError(f"'{img_key}' not found in request in either json or form data")


@blueprint.route("/represent", methods=["POST"])
def represent():
input_args = request.get_json()
input_args = request.get_json() or request.form.to_dict()

if input_args is None:
return {"message": "empty input set passed"}

img_path = input_args.get("img") or input_args.get("img_path")
if img_path is None:
return {"message": "you must pass img_path input"}
try:
img = extract_image_from_request("img")
except Exception as err:
return {"exception": str(err)}, 400

obj = service.represent(
img_path=img_path,
img_path=img,
model_name=input_args.get("model_name", "VGG-Face"),
detector_backend=input_args.get("detector_backend", "opencv"),
enforce_detection=input_args.get("enforce_detection", True),
Expand All @@ -41,23 +96,21 @@ def represent():

@blueprint.route("/verify", methods=["POST"])
def verify():
input_args = request.get_json()

if input_args is None:
return {"message": "empty input set passed"}

img1_path = input_args.get("img1") or input_args.get("img1_path")
img2_path = input_args.get("img2") or input_args.get("img2_path")
input_args = request.get_json() or request.form.to_dict()

if img1_path is None:
return {"message": "you must pass img1_path input"}
try:
img1 = extract_image_from_request("img1")
except Exception as err:
return {"exception": str(err)}, 400

if img2_path is None:
return {"message": "you must pass img2_path input"}
try:
img2 = extract_image_from_request("img2")
except Exception as err:
return {"exception": str(err)}, 400

verification = service.verify(
img1_path=img1_path,
img2_path=img2_path,
img1_path=img1,
img2_path=img2,
model_name=input_args.get("model_name", "VGG-Face"),
detector_backend=input_args.get("detector_backend", "opencv"),
distance_metric=input_args.get("distance_metric", "cosine"),
Expand All @@ -73,18 +126,31 @@ def verify():

@blueprint.route("/analyze", methods=["POST"])
def analyze():
input_args = request.get_json()

if input_args is None:
return {"message": "empty input set passed"}

img_path = input_args.get("img") or input_args.get("img_path")
if img_path is None:
return {"message": "you must pass img_path input"}
input_args = request.get_json() or request.form.to_dict()

try:
img = extract_image_from_request("img")
except Exception as err:
return {"exception": str(err)}, 400

actions = input_args.get("actions", ["age", "gender", "emotion", "race"])
# actions is the only argument instance of list or tuple
# if request is form data, input args can either be text or file
if isinstance(actions, str):
actions = (
actions.replace("[", "")
.replace("]", "")
.replace("(", "")
.replace(")", "")
.replace('"', "")
.replace("'", "")
.replace(" ", "")
.split(",")
)

demographies = service.analyze(
img_path=img_path,
actions=input_args.get("actions", ["age", "gender", "emotion", "race"]),
img_path=img,
actions=actions,
detector_backend=input_args.get("detector_backend", "opencv"),
enforce_detection=input_args.get("enforce_detection", True),
align=input_args.get("align", True),
Expand Down

0 comments on commit 60067e2

Please sign in to comment.