Skip to content

Commit

Permalink
Multi-process RESTful API (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Nov 28, 2023
1 parent f34b772 commit a5d2f26
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 11 deletions.
1 change: 1 addition & 0 deletions mii/backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
f"--deployment-name {mii_config.deployment_name}",
f"--load-balancer-port {mii_config.port_number}",
f"--restful-gateway-port {mii_config.restful_api_port}",
f"--restful-gateway-procs {mii_config.restful_processes}"
]

host_gpus = defaultdict(list)
Expand Down
5 changes: 5 additions & 0 deletions mii/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ class MIIConfig(DeepSpeedConfigModel):
Port number to use for the RESTful API.
"""

restful_processes: int = Field(32, ge=1)
"""
Number of processes to use for the RESTful API.
"""

hostfile: str = DLTS_HOSTFILE
"""
DeepSpeed hostfile. Will be autogenerated if None is provided.
Expand Down
20 changes: 10 additions & 10 deletions mii/grpc_related/restful_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import json
import threading
import time

from flask import Flask, request
from flask import Flask, request, jsonify
from flask_restful import Resource, Api
from werkzeug.serving import make_server

Expand All @@ -20,18 +19,15 @@ def shutdown(thread):


def createRestfulGatewayApp(deployment_name, server_thread):
# client must be thread-safe
client = mii.client(deployment_name)

class RestfulGatewayService(Resource):
def __init__(self):
super().__init__()
self.client = mii.client(deployment_name)

def post(self):
data = request.get_json()
result = client.generate(**data)
result_json = json.dumps([r.to_msg_dict() for r in result])
return result_json
result = self.client.generate(**data)
return jsonify([r.to_msg_dict() for r in result])

app = Flask("RestfulGateway")

Expand All @@ -49,11 +45,15 @@ def terminate():


class RestfulGatewayThread(threading.Thread):
def __init__(self, deployment_name, rest_port):
def __init__(self, deployment_name, rest_port, rest_procs):
threading.Thread.__init__(self)

app = createRestfulGatewayApp(deployment_name, self)
self.server = make_server("127.0.0.1", rest_port, app)
self.server = make_server("127.0.0.1",
rest_port,
app,
threaded=False,
processes=rest_procs)
self.ctx = app.app_context()
self.ctx.push()

Expand Down
5 changes: 5 additions & 0 deletions mii/launch/multi_gpu_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def main() -> None:
default=0,
help="Port to use for restful gateway.",
)
parser.add_argument("--restful-gateway-procs",
type=int,
default=32,
help="Number of processes to use for restful gateway.")
args = parser.parse_args()
assert not (
args.load_balancer and args.restful_gateway
Expand All @@ -70,6 +74,7 @@ def main() -> None:
gateway_thread = RestfulGatewayThread(
deployment_name=args.deployment_name,
rest_port=args.restful_gateway_port,
rest_procs=args.restful_gateway_procs,
)
stop_event = gateway_thread.get_stop_event()
gateway_thread.start()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ def test_restful_api(deployment, query, deployment_name, restful_api_port):
data=json_params,
headers={"Content-Type": "application/json"})
assert result.status_code == 200
assert "generated_text" in result.json()
assert "generated_text" in result.json()[0]

0 comments on commit a5d2f26

Please sign in to comment.