This repository has been archived by the owner on Jan 14, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathapp.py
124 lines (106 loc) · 3.8 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import hmac
import json
import logging
import os
import sys
import time
from typing import Optional

from humanfriendly import format_timespan
from molten import (
    App, Route, Settings, HTTP_401, HTTPError, Header,
    annotate, ResponseRendererMiddleware, Response, JSONRenderer,
    JSONParser, MultiPartParser,
)
from molten.contrib.prometheus import expose_metrics, prometheus_middleware
from molten.contrib.request_id import RequestIdMiddleware
from molten.openapi import OpenAPIHandler, OpenAPIUIHandler, Metadata, HTTPSecurityScheme
from wsgicors import CORS

from model import ModelData, FeedbackData
import model
from logger import setup_logging
from config import CONFIG
# Application version string, reported by health() and in the OpenAPI metadata
VERSION = "v0.1.0"
def auth_middleware(handler):
    """
    Authentication middleware requiring an ``Authorization: Bearer <token>``
    header whose token equals ``CONFIG['token']``.

    Handlers annotated with ``no_auth=True`` (via ``molten.annotate``) are
    passed through without checking credentials.

    Parameters:
        handler: the next handler in the molten middleware chain.

    Returns:
        The wrapping middleware; molten injects the request's
        ``Authorization`` header as ``authorization`` (``None`` when absent).

    Raises:
        HTTPError: 401 with ``{"error": "bad credentials"}`` when the header is
        missing, does not use the Bearer scheme, or carries the wrong token.
    """
    def middleware(authorization: Optional[Header]):
        if getattr(handler, "no_auth", False):
            return handler()
        if authorization:
            # Split "<scheme> <token>" instead of blindly slicing off seven
            # characters, which accepted any 7-char prefix before the token.
            scheme, _, token = authorization.partition(" ")
            expected = CONFIG['token']
            # The scheme name is case-insensitive (RFC 7235). compare_digest
            # avoids leaking how much of the token matched via timing.
            # A non-string configured token can never match (fails closed).
            if (
                scheme.lower() == "bearer"
                and isinstance(expected, str)
                and hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
            ):
                return handler()
        raise HTTPError(HTTP_401, {"error": "bad credentials"})
    return middleware
# OpenAPI schema handler (routed at /_schema) describing this API
get_schema = OpenAPIHandler(
    metadata=Metadata(
        title="ML app",
        description="A test ML application",
        version=VERSION,
    ),
)

# Process start time; health() reports uptime relative to this
start_time = time.time()
# Swagger UI handler (routed at /_docs)
get_docs = OpenAPIUIHandler()
setup_logging()

# Mark the schema and docs handlers with no_auth=True so auth_middleware
# lets them through even when a bearer token is configured.
get_schema, get_docs = [
    annotate(no_auth=True)(doc_handler) for doc_handler in (get_schema, get_docs)
]
def health() -> str:
    """
    Root route: report the application version and how long it has been up.

    Meant as a liveness/readiness check: an orchestrator such as Kubernetes
    can treat a 200 OK from this route as "ready" and restart the service
    when it stops answering.

    Returns:
        A dict with ``version`` (the ``VERSION`` constant) and ``uptime``
        (human-readable time since ``start_time`` was recorded at import).

    NOTE(review): annotated ``-> str`` but actually returns a dict. The
    annotation is left as-is because molten may read it when generating the
    OpenAPI schema; confirm before correcting.
    NOTE(review): this handler is not annotated ``no_auth``, so when a token
    is configured auth_middleware demands it here too. Probes must then send
    the Bearer header; confirm that is intended.
    """
    seconds_up = time.time() - start_time
    status = {"version": VERSION}
    status["uptime"] = format_timespan(seconds_up)
    return status
def predict(data: ModelData) -> str:
    """
    Run the deployed model on the request body.

    The body is parsed into a ``ModelData`` instance. Its schema is defined,
    and can be customised for the deployed model, in ``model.py``.

    Parameters:
        data: the parsed request body, shaped as described by ``ModelData``.

    Returns:
        Whatever ``model.predict`` returns for ``data``, unchanged.
    """
    prediction = model.predict(data)
    return prediction
def feedback(data: FeedbackData) -> str:
    """
    Forward a feedback request to ``model.feedback``.

    The request body is parsed into a ``FeedbackData`` instance. Its schema is
    defined, and can be customised for the deployed model, in ``model.py``.
    (The previous docstring was copied from ``predict`` and wrongly said this
    handler makes predictions.)

    Parameters:
        data: the parsed request body, shaped as described by ``FeedbackData``.

    Returns:
        Whatever ``model.feedback`` returns for ``data``, unchanged.
    """
    return model.feedback(data)
# Load the pre-trained model once, at import time, before the app below is built
model.load()
# Pass the model name as a logging argument so formatting is deferred to the
# logging framework; the rendered message is identical to the f-string form.
logging.info("Loaded model: %s", CONFIG['model_name'])
# Middleware applied to every request: Prometheus metrics, request IDs, and
# rendering of handler return values into responses.
middlewares = [
    prometheus_middleware,
    RequestIdMiddleware(),
    ResponseRendererMiddleware(),
]

# Enable bearer-token auth only when a token is configured. Handlers annotated
# no_auth=True stay public. The scheme is also advertised in the OpenAPI
# schema. The explicit comparison with "" (not truthiness) means a non-string
# token such as None still enables the middleware.
if CONFIG["token"] != "":
    middlewares += [auth_middleware]
    get_schema.security_schemes = [HTTPSecurityScheme("Bearer Auth", "bearer")]
    get_schema.default_security_scheme = "Bearer Auth"
# Assemble the molten application: every route, the middleware stack built
# above, JSON and multipart request-body parsers, and a JSON renderer.
app = App(
    routes=[
        Route("/_docs", get_docs),
        Route("/_schema", get_schema),
        Route("/metrics", expose_metrics),
        Route("/", health),
        Route("/v1/predict", predict, method="POST"),
        Route("/v1/feedback", feedback, method="POST"),
    ],
    middleware=middlewares,
    parsers=[JSONParser(), MultiPartParser()],
    renderers=[JSONRenderer()],
)
# Only when ENVIRONMENT == "production": wrap the WSGI app with wsgicors,
# allowing any origin, method and header. maxage is presumably emitted as
# Access-Control-Max-Age, in seconds.
# NOTE(review): a wildcard origin lets any website call this API from a
# browser; confirm that is intended for production.
if os.environ.get("ENVIRONMENT") == "production":
    app = CORS(app, headers="*", methods="*", origin="*", maxage="86400")