Enhance serving evaluation endpoints #595

Merged: 12 commits, Mar 1, 2024
115 changes: 89 additions & 26 deletions cornac/serving/app.py
@@ -26,7 +26,7 @@
 from cornac.metrics import *
 
 try:
-    from flask import Flask, jsonify, request
+    from flask import Flask, jsonify, request, abort
 except ImportError:
     exit("Flask is required in order to serve models.\n" + "Run: pip3 install Flask")

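Note (not part of the diff): abort(400, msg) raises a werkzeug HTTPException instead of returning a (body, status) tuple, so a handler stops at the first failed check. A minimal sketch of turning those aborts into JSON error responses, reusing the module's existing app object and jsonify:

    from werkzeug.exceptions import HTTPException

    @app.errorhandler(HTTPException)
    def handle_http_error(err):
        # abort(400, "...") populates err.code and err.description
        return jsonify({"error": err.description}), err.code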
@@ -197,34 +197,11 @@
return "Unable to evaluate. 'train_set' is not provided", 400

query = request.json
validate_query(query)

query_metrics = query.get("metrics")
rating_threshold = query.get("rating_threshold", 1.0)
exclude_unknowns = (
query.get("exclude_unknowns", "true").lower() == "true"
) # exclude unknown users/items by default, otherwise specified
user_based = (
query.get("user_based", "true").lower() == "true"
) # user_based evaluation by default, otherwise specified

if query_metrics is None:
return "metrics is required", 400
elif not isinstance(query_metrics, list):
return "metrics must be an array of metrics", 400

# organize metrics
metrics = []
for metric in query_metrics:
try:
metrics.append(_safe_eval(metric))
except:
return (
f"Invalid metric initiation: {metric}.\n"
+ "Please input correct metrics (e.g., 'RMSE()', 'Recall(k=10)')",
400,
)

rating_metrics, ranking_metrics = BaseMethod.organize_metrics(metrics)

# read data
data = []
@@ -244,6 +221,85 @@
         exclude_unknowns=exclude_unknowns,
     )
 
+    return process_evaluation(test_set, query, exclude_unknowns)
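For orientation, a hypothetical client call for the refactored /evaluate endpoint (host, port, and the requests library are assumptions, not part of the diff; the keys mirror the ones read above and in process_evaluation):

    import requests

    resp = requests.post(
        "http://localhost:5000/evaluate",
        json={
            "metrics": ["RMSE()", "Recall(k=10)"],  # required, checked by validate_query
            "rating_threshold": 1.0,
            "exclude_unknowns": "true",  # string flag, parsed with .lower() == "true"
            "user_based": "true",
        },
    )
    print(resp.status_code, resp.json())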


+def validate_query(query):
+    query_metrics = query.get("metrics")
+
+    if query_metrics is None:
+        abort(400, "metrics is required")
+    elif not isinstance(query_metrics, list):
+        abort(400, "metrics must be an array of metrics")


@app.route("/evaluate-json", methods=["POST"])
def evaluate_json():
global model, train_set, metric_classnames

# Input validation
if model is None:
abort(400, "Model is not yet loaded. Please try again later.")

Check warning on line 242 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L241-L242

Added lines #L241 - L242 were not covered by tests

if train_set is None:
abort(400, "Unable to evaluate. 'train_set' is not provided")

Check warning on line 245 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L244-L245

Added lines #L244 - L245 were not covered by tests

query = request.get_json()

Check warning on line 247 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L247

Added line #L247 was not covered by tests

validate_query(query)

Check warning on line 249 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L249

Added line #L249 was not covered by tests

if "data" not in query:
abort(400, "Evaluation data is not provided. 'data' is required in the form of a list of tuples (uid, iid, rating).")

Check warning on line 252 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L251-L252

Added lines #L251 - L252 were not covered by tests

exclude_unknowns = (

Check warning on line 254 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L254

Added line #L254 was not covered by tests
query.get("exclude_unknowns", "true").lower() == "true"
) # exclude unknown users/items by default, otherwise specified

# read data
data = query.get("data")

Check warning on line 259 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L259

Added line #L259 was not covered by tests

if not len(data):
raise ValueError("No data available to evaluate the model.")

Check warning on line 262 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L261-L262

Added lines #L261 - L262 were not covered by tests

# convert rows of data to tuples
for i, row in enumerate(data):
data[i] = tuple(row)

Check warning on line 266 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L265-L266

Added lines #L265 - L266 were not covered by tests

test_set = Dataset.build(

Check warning on line 268 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L268

Added line #L268 was not covered by tests
data,
fmt="UIR",
global_uid_map=train_set.uid_map,
global_iid_map=train_set.iid_map,
exclude_unknowns=exclude_unknowns,
)

return process_evaluation(test_set, query, exclude_unknowns)

Check warning on line 276 in cornac/serving/app.py

View check run for this annotation

Codecov / codecov/patch

cornac/serving/app.py#L276

Added line #L276 was not covered by tests
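A hypothetical request for the new /evaluate-json endpoint (sample IDs, host, and port are made up; the payload shape follows the checks above):

    import requests

    payload = {
        "metrics": ["RMSE()", "Recall(k=10)"],
        "data": [  # (uid, iid, rating) triples, sent as JSON arrays
            ["user_1", "item_1", 4.0],
            ["user_2", "item_3", 2.5],
        ],
        "exclude_unknowns": "true",
    }
    resp = requests.post("http://localhost:5000/evaluate-json", json=payload)
    print(resp.json()["result"])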


+def process_evaluation(test_set, query, exclude_unknowns):
+    global model, train_set
+
+    rating_threshold = query.get("rating_threshold", 1.0)
+    user_based = (
+        query.get("user_based", "true").lower() == "true"
+    )  # user_based evaluation by default, otherwise specified
+
+    query_metrics = query.get("metrics")
+
+    # organize metrics
+    metrics = []
+    for metric in query_metrics:
+        try:
+            metrics.append(_safe_eval(metric))
+        except:
+            return (
+                f"Invalid metric initiation: {metric}.\n"
+                + "Please input correct metrics (e.g., 'RMSE()', 'Recall(k=10)')",
+                400,
+            )
+
+    rating_metrics, ranking_metrics = BaseMethod.organize_metrics(metrics)

     # evaluation
     result = BaseMethod.eval(
         model=model,
@@ -258,10 +314,17 @@
         verbose=False,
     )
 
+    # map user index back into the original user ID
+    metric_user_results = {}
+    for metric, user_results in result.metric_user_results.items():
+        metric_user_results[metric] = {
+            train_set.user_ids[int(k)]: v for k, v in user_results.items()
+        }
+
     # response
     response = {
         "result": result.metric_avg_results,
-        "query": query,
+        "user_result": metric_user_results,
    }
 
     return jsonify(response), 200
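Based on the response dict above and the updated test, a successful reply now has roughly this shape (metric values are illustrative only):

    {
        "result": {"RMSE": 0.98, "Recall@5": 0.21},
        "user_result": {
            "RMSE": {"user_1": 1.05, "user_2": 0.91},
            "Recall@5": {"user_1": 0.2, "user_2": 0.0}
        }
    }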
3 changes: 2 additions & 1 deletion tests/cornac/serving/test_app.py
@@ -96,9 +96,10 @@ def test_evaluate_json(client):
     response = client.post('/evaluate', json=json_data)
     # assert response.content_type == 'application/json'
     assert response.status_code == 200
-    assert len(response.json['query']['metrics']) == 2
     assert 'RMSE' in response.json['result']
     assert 'Recall@5' in response.json['result']
+    assert 'RMSE' in response.json['user_result']
+    assert 'Recall@5' in response.json['user_result']
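Assuming the suite runs under pytest (suggested by the client fixture), this test can be exercised on its own with:

    pytest tests/cornac/serving/test_app.py::test_evaluate_json -q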


def test_evalulate_incorrect_get(client):