fix grpc
Mayank Mishra authored and Mayank Mishra committed Sep 1, 2022
1 parent f9402d0 commit 1b78ef0
Showing 1 changed file with 25 additions and 30 deletions.
55 changes: 25 additions & 30 deletions scripts/bloom-inference-server/ds_inference/grpc_server.py
@@ -6,59 +6,57 @@
 from transformers import AutoTokenizer
 
 import mii
-from utils import GenerateRequest, GenerateResponse, Model, get_filter_dict, get_str_dtype, print_rank_n
+from utils import (
+    GenerateRequest,
+    GenerateResponse,
+    Model,
+    get_downloaded_model_path,
+    get_filter_dict,
+    get_str_dtype,
+    print_rank_n
+)
 
 
 class DSInferenceGRPCServer(Model):
     def __init__(self, args: argparse.Namespace) -> None:
         self.deployment_name = "ds_inference_grpc_server"
 
-        files = os.listdir(args.save_mp_checkpoint_path)
-        for file in files:
-            if (file.endswith(".json")):
-                checkpoints_json = json.load(
-                    open(os.path.join(args.save_mp_checkpoint_path, file), "r"))
-                break
+        downloaded_model_path = get_downloaded_model_path(args.model_name)
 
-        if ("base_dir" in checkpoints_json):
-            del checkpoints_json["base_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(downloaded_model_path)
+        self.pad = self.tokenizer.pad_token_id
 
-        if (args.dtype == torch.float16):
+        if (args.dtype in [torch.float16, torch.int8]):
+            checkpoints_json = os.path.join(
+                downloaded_model_path, "BLOOM_ds-inference_config.json")
+
             mii.deploy(
                 task="text-generation",
                 model=args.model_name,
                 deployment_name=self.deployment_name,
+                model_path=downloaded_model_path,
                 mii_config={
                     "dtype": get_str_dtype(args.dtype),
                     "tensor_parallel": 8,
                     "port_number": 50950,
-                    "checkpoint_dict": checkpoints_json
-                },
-                model_path=args.save_mp_checkpoint_path
+                    "checkpoint_dict": json.load(open(checkpoints_json, "r"))
+                }
             )
-        else:
-            raise NotImplementedError("This is not yet supported")
+        elif (args.dtype == torch.bfloat16):
+            raise NotImplementedError("bfloat16 is not yet supported")
 
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-        self.pad = self.tokenizer.pad_token_id
         self.model = mii.mii_query_handle(self.deployment_name)
 
     def generate(self, request: GenerateRequest) -> GenerateResponse:
-        text = request.text
-
-        return_type = type(text)
-        if (return_type == str):
-            text = [text]
-
         output_text = self.model.query(
-            {"query": text},
+            {"query": request.text},
             **get_filter_dict(request)
         ).response
 
         output_text = [_ for _ in output_text]
 
         # Remove input from output
-        input_tokens = self.tokenizer(text).input_ids
+        input_tokens = self.tokenizer(request.text).input_ids
         output_tokens = self.tokenizer(output_text).input_ids
 
         input_token_lengths = [len(x) for x in input_tokens]
@@ -72,10 +70,6 @@ def generate(self, request: GenerateRequest) -> GenerateResponse:
         output_text = self.tokenizer.batch_decode(
             output_tokens, skip_special_tokens=True)
 
-        if (return_type == str):
-            output_text = output_text[0]
-            num_generated_tokens = num_generated_tokens[0]
-
         return GenerateResponse(
             text=output_text,
             num_generated_tokens=num_generated_tokens
@@ -87,4 +81,5 @@ def shutdown(self) -> None:
         try:
             mii.terminate(self.deployment_name)
         except Exception:
-            exit()
+            pass
+        exit()
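
For context, here is a minimal sketch of how this class could be driven after the change. The argparse fields, the GenerateRequest constructor arguments, and the import paths shown below are assumptions inferred from what the diff references, not confirmed by this commit:

import argparse

import torch

from ds_inference.grpc_server import DSInferenceGRPCServer
from utils import GenerateRequest

# Hypothetical args; the diff only shows that model_name and dtype are read.
args = argparse.Namespace(
    model_name="bigscience/bloom",
    dtype=torch.float16,  # fp16/int8 deploy via MII; bf16 raises NotImplementedError
)

model = DSInferenceGRPCServer(args)

# After this commit, generate() forwards request.text to the MII query handle
# unchanged, so text is expected to be a list of prompts; the old
# str -> [str] wrapping (and unwrapping of the response) was removed.
request = GenerateRequest(text=["DeepSpeed is"], max_new_tokens=40)  # fields assumed
response = model.generate(request)
print(response.text, response.num_generated_tokens)

# Note: after this commit, shutdown() always calls exit() once mii.terminate()
# has been attempted, so it ends the calling process.
model.shutdown()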
