diff --git a/deepface/api/src/modules/core/routes.py b/deepface/api/src/modules/core/routes.py
index 4830bec2..9cb2e747 100644
--- a/deepface/api/src/modules/core/routes.py
+++ b/deepface/api/src/modules/core/routes.py
@@ -1,31 +1,86 @@
+# built-in dependencies
+from typing import Union
+
+# 3rd party dependencies
from flask import Blueprint, request
+import numpy as np
+
+# project dependencies
from deepface import DeepFace
from deepface.api.src.modules.core import service
+from deepface.commons import image_utils
from deepface.commons.logger import Logger
logger = Logger()
blueprint = Blueprint("routes", __name__)
+# pylint: disable=no-else-return, broad-except
+
@blueprint.route("/")
def home():
return f"
Welcome to DeepFace API v{DeepFace.__version__}!
"
+def extract_image_from_request(img_key: str) -> Union[str, np.ndarray]:
+ """
+ Extracts an image from the request either from json or a multipart/form-data file.
+
+ Args:
+ img_key (str): The key used to retrieve the image data
+ from the request (e.g., 'img1').
+
+ Returns:
+ img (str or np.ndarray): Given image detail (base64 encoded string, image path or url)
+ or the decoded image as a numpy array.
+ """
+
+ # Check if the request is multipart/form-data (file input)
+ if request.files:
+ # request.files is instance of werkzeug.datastructures.ImmutableMultiDict
+ # file is instance of werkzeug.datastructures.FileStorage
+ file = request.files.get(img_key)
+
+ if file is None:
+ raise ValueError(f"Request form data doesn't have {img_key}")
+
+ if file.filename == "":
+ raise ValueError(f"No file uploaded for '{img_key}'")
+
+ img = image_utils.load_image_from_file_storage(file)
+
+ return img
+ # Check if the request is coming as base64, file path or url from json or form data
+ elif request.is_json or request.form:
+ input_args = request.get_json() or request.form.to_dict()
+
+ if input_args is None:
+ raise ValueError("empty input set passed")
+
+ # this can be base64 encoded image, and image path or url
+ img = input_args.get(img_key)
+
+ if not img:
+ raise ValueError(f"'{img_key}' not found in either json or form data request")
+
+ return img
+
+ # If neither JSON nor file input is present
+ raise ValueError(f"'{img_key}' not found in request in either json or form data")
+
+
@blueprint.route("/represent", methods=["POST"])
def represent():
- input_args = request.get_json()
+ input_args = request.get_json() or request.form.to_dict()
- if input_args is None:
- return {"message": "empty input set passed"}
-
- img_path = input_args.get("img") or input_args.get("img_path")
- if img_path is None:
- return {"message": "you must pass img_path input"}
+ try:
+ img = extract_image_from_request("img")
+ except Exception as err:
+ return {"exception": str(err)}, 400
obj = service.represent(
- img_path=img_path,
+ img_path=img,
model_name=input_args.get("model_name", "VGG-Face"),
detector_backend=input_args.get("detector_backend", "opencv"),
enforce_detection=input_args.get("enforce_detection", True),
@@ -41,23 +96,21 @@ def represent():
@blueprint.route("/verify", methods=["POST"])
def verify():
- input_args = request.get_json()
-
- if input_args is None:
- return {"message": "empty input set passed"}
-
- img1_path = input_args.get("img1") or input_args.get("img1_path")
- img2_path = input_args.get("img2") or input_args.get("img2_path")
+ input_args = request.get_json() or request.form.to_dict()
- if img1_path is None:
- return {"message": "you must pass img1_path input"}
+ try:
+ img1 = extract_image_from_request("img1")
+ except Exception as err:
+ return {"exception": str(err)}, 400
- if img2_path is None:
- return {"message": "you must pass img2_path input"}
+ try:
+ img2 = extract_image_from_request("img2")
+ except Exception as err:
+ return {"exception": str(err)}, 400
verification = service.verify(
- img1_path=img1_path,
- img2_path=img2_path,
+ img1_path=img1,
+ img2_path=img2,
model_name=input_args.get("model_name", "VGG-Face"),
detector_backend=input_args.get("detector_backend", "opencv"),
distance_metric=input_args.get("distance_metric", "cosine"),
@@ -73,18 +126,31 @@ def verify():
@blueprint.route("/analyze", methods=["POST"])
def analyze():
- input_args = request.get_json()
-
- if input_args is None:
- return {"message": "empty input set passed"}
-
- img_path = input_args.get("img") or input_args.get("img_path")
- if img_path is None:
- return {"message": "you must pass img_path input"}
+ input_args = request.get_json() or request.form.to_dict()
+
+ try:
+ img = extract_image_from_request("img")
+ except Exception as err:
+ return {"exception": str(err)}, 400
+
+ actions = input_args.get("actions", ["age", "gender", "emotion", "race"])
+ # actions is the only argument instance of list or tuple
+ # if request is form data, input args can either be text or file
+ if isinstance(actions, str):
+ actions = (
+ actions.replace("[", "")
+ .replace("]", "")
+ .replace("(", "")
+ .replace(")", "")
+ .replace('"', "")
+ .replace("'", "")
+ .replace(" ", "")
+ .split(",")
+ )
demographies = service.analyze(
- img_path=img_path,
- actions=input_args.get("actions", ["age", "gender", "emotion", "race"]),
+ img_path=img,
+ actions=actions,
detector_backend=input_args.get("detector_backend", "opencv"),
enforce_detection=input_args.get("enforce_detection", True),
align=input_args.get("align", True),