Skip to content

Commit

Permalink
add option to disable swagger ui
Browse files Browse the repository at this point in the history
  • Loading branch information
liusy182 committed Nov 16, 2020
1 parent 26fc958 commit d6ee662
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 13 deletions.
27 changes: 25 additions & 2 deletions bentoml/cli/bento_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,27 @@ def open_api_spec(bento=None, yatai_url=None):
help='Remote YataiService URL. Optional. '
'Example: "--yatai-url http://localhost:50050"',
)
@click.option(
'--enable-swagger/--disable-swagger',
is_flag=True,
default=True,
help="Run API server with Swagger UI enabled",
envvar='BENTOML_ENABLE_SWAGGER',
)
def serve(
port, bento=None, enable_microbatch=False, run_with_ngrok=False, yatai_url=None
port,
bento=None,
enable_microbatch=False,
run_with_ngrok=False,
yatai_url=None,
enable_swagger=True,
):
saved_bundle_path = resolve_bundle_path(
bento, pip_installed_bundle_path, yatai_url
)
start_dev_server(saved_bundle_path, port, enable_microbatch, run_with_ngrok)
start_dev_server(
saved_bundle_path, port, enable_microbatch, run_with_ngrok, enable_swagger
)

# Example Usage:
# bentoml serve-gunicorn {BUNDLE_PATH} --port={PORT} --workers={WORKERS}
Expand Down Expand Up @@ -219,6 +233,13 @@ def serve(
help='Remote YataiService URL. Optional. '
'Example: "--yatai-url http://localhost:50050"',
)
@click.option(
'--enable-swagger/--disable-swagger',
is_flag=True,
default=True,
help="Run API server with Swagger UI enabled",
envvar='BENTOML_ENABLE_SWAGGER',
)
def serve_gunicorn(
port,
workers,
Expand All @@ -227,6 +248,7 @@ def serve_gunicorn(
enable_microbatch=False,
microbatch_workers=1,
yatai_url=None,
enable_swagger=True,
):
if not psutil.POSIX:
_echo(
Expand All @@ -246,6 +268,7 @@ def serve_gunicorn(
workers,
enable_microbatch,
microbatch_workers,
enable_swagger,
)

@bentoml_cli.command(
Expand Down
26 changes: 21 additions & 5 deletions bentoml/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def async_trace(*args, **kwargs):


def start_dev_server(
saved_bundle_path: str, port: int, enable_microbatch: bool, run_with_ngrok: bool
saved_bundle_path: str,
port: int,
enable_microbatch: bool,
run_with_ngrok: bool,
enable_swagger: bool,
):
logger.info("Starting BentoML API server in development mode..")

Expand Down Expand Up @@ -72,11 +76,15 @@ def start_dev_server(
outbound_port=api_server_port,
outbound_workers=1,
)
api_server = BentoAPIServer(bento_service, port=api_server_port)
api_server = BentoAPIServer(
bento_service, port=api_server_port, enable_swagger=enable_swagger
)
marshal_server.async_start(port=port)
api_server.start()
else:
api_server = BentoAPIServer(bento_service, port=port)
api_server = BentoAPIServer(
bento_service, port=port, enable_swagger=enable_swagger
)
api_server.start()


Expand All @@ -87,6 +95,7 @@ def start_prod_server(
workers: int,
enable_microbatch: bool,
microbatch_workers: int,
enable_swagger: bool,
):
logger.info("Starting BentoML API server in production mode..")

Expand Down Expand Up @@ -120,10 +129,17 @@ def start_prod_server(
)

gunicorn_app = GunicornBentoServer(
saved_bundle_path, api_server_port, workers, timeout, prometheus_lock,
saved_bundle_path,
api_server_port,
workers,
timeout,
prometheus_lock,
enable_swagger,
)
marshal_server.async_run()
gunicorn_app.run()
else:
gunicorn_app = GunicornBentoServer(saved_bundle_path, port, workers, timeout)
gunicorn_app = GunicornBentoServer(
saved_bundle_path, port, workers, timeout, enable_swagger=enable_swagger
)
gunicorn_app.run()
16 changes: 13 additions & 3 deletions bentoml/server/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,20 @@ class BentoAPIServer:
DEFAULT_PORT = config("apiserver").getint("default_port")
_MARSHAL_FLAG = config("marshal_server").get("marshal_request_header_flag")

def __init__(self, bento_service: BentoService, port=DEFAULT_PORT, app_name=None):
def __init__(
self,
bento_service: BentoService,
port=DEFAULT_PORT,
app_name=None,
enable_swagger=True,
):
app_name = bento_service.name if app_name is None else app_name

self.port = port
self.bento_service = bento_service
self.app = Flask(app_name, static_folder=None)
self.static_path = self.bento_service.get_web_static_content_path()
self.enable_swagger = enable_swagger

self.swagger_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'swagger_static'
Expand Down Expand Up @@ -140,11 +147,14 @@ def index_view_func(static_path):
"""
return send_from_directory(static_path, 'index.html')

@staticmethod
def swagger_ui_func():
def swagger_ui_func(self):
"""
The swagger UI route for BentoML API server
"""
if not self.enable_swagger:
return Response(
response="Swagger is disabled", status=404, mimetype="text/html"
)
return Response(
response=INDEX_HTML.format(url='docs.json'),
status=200,
Expand Down
13 changes: 11 additions & 2 deletions bentoml/server/gunicorn_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ class GunicornBentoServer(Application): # pylint: disable=abstract-method
"""

def __init__(
self, bundle_path, port=None, workers=None, timeout=None, prometheus_lock=None,
self,
bundle_path,
port=None,
workers=None,
timeout=None,
prometheus_lock=None,
enable_swagger=True,
):
self.bento_service_bundle_path = bundle_path

Expand All @@ -73,6 +79,7 @@ def __init__(
if workers:
self.options['workers'] = workers
self.prometheus_lock = prometheus_lock
self.enable_swagger = enable_swagger

super(GunicornBentoServer, self).__init__()

Expand All @@ -92,7 +99,9 @@ def load_config(self):

def load(self):
bento_service = load_from_dir(self.bento_service_bundle_path)
api_server = GunicornBentoAPIServer(bento_service, port=self.port)
api_server = GunicornBentoAPIServer(
bento_service, port=self.port, enable_swagger=self.enable_swagger
)
return api_server.app

def run(self):
Expand Down
17 changes: 17 additions & 0 deletions tests/server/test_model_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,20 @@ def test_api_function_route(bento_service, img_file):
# },
# )
# assert 200 == response.status_code


def test_api_function_route_with_disabled_swagger(bento_service):
rest_server = BentoAPIServer(bento_service, enable_swagger=False)
test_client = rest_server.app.test_client()

response = test_client.get("/")
assert 404 == response.status_code

response = test_client.get("/docs")
assert 404 == response.status_code

response = test_client.get("/healthz")
assert 200 == response.status_code

response = test_client.get("/docs.json")
assert 200 == response.status_code
4 changes: 3 additions & 1 deletion tests/utils/test_usage_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def func(yatai_service, deployment_pb):
return func


def mock_start_dev_server(bundle_path, port, enable_microbatch, run_with_ngrok):
def mock_start_dev_server(
bundle_path, port, enable_microbatch, run_with_ngrok, enable_swagger
):
raise KeyboardInterrupt()


Expand Down

0 comments on commit d6ee662

Please sign in to comment.