* stateful inference-core layer
* add grpc layer
* add google rpc submodule
* fmt
* update sequence batch img
* update sequence batch img
* fmt
* delete unused file
* fmt
* fix log and update doc
* update log
* fmt
* make BatchAggregator the base class
* fix conflict
* fix conflict
* add SequenceBatchAggregator
* update ci for submodule
* refactor
* fmt
* fmt
* fix lint
* code refactor
* update readme
* update readme
* fmt
* fmt
* test workflow
* revert test
* revert test response
* fmt
* fmt
* update readme
* allow number of jobGroups to be larger than batchSize
* fmt
* fix typo
* add stateful test data
* fmt
* fmt
* fmt
* fmt
* set default maxNumSequence
* fmt
* fmt
* revert back config.properties
* fmt
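The commit list above is terse, so here is a conceptual sketch of what a sequence-aware batch aggregator does; it is not the commit's actual implementation, and all class, method, and field names below are illustrative. Plain batching groups any pending requests together, while sequence batching must route every request that shares a sequence id to the same job group so per-sequence state survives across requests, and the number of live job groups is allowed to exceed the batch size.

from collections import defaultdict


class SequenceBatchAggregator:
    """Illustrative only: route requests by sequence id so that each
    stateful sequence is always handled by the same job group."""

    def __init__(self, max_num_sequence=4):
        # Upper bound on concurrently tracked sequences (cf. the
        # "set default maxNumSequence" entry in the commit log).
        self.max_num_sequence = max_num_sequence
        self.job_groups = defaultdict(list)

    def add_request(self, sequence_id, request):
        if (
            sequence_id not in self.job_groups
            and len(self.job_groups) >= self.max_num_sequence
        ):
            raise RuntimeError("too many live sequences")
        # Requests from the same sequence stay in one job group.
        self.job_groups[sequence_id].append(request)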
Showing 53 changed files with 1,747 additions and 270 deletions.
.gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/google/rpc"]
	path = third_party/google/rpc
	url = https://github.com/googleapis/googleapis.git
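After cloning the repository, the new submodule can be fetched with the standard git command below (adding --recursive would also pull any nested submodules); this is presumably what the "update ci for submodule" change wires into the workflow:

git submodule update --init third_party/google/rpc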
(Binary file not shown.)
examples/large_models/Huggingface_accelerate/llama2/custom_handler_code.py (140 additions)
@@ -0,0 +1,140 @@
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with Llama 2.
    """

    def __init__(self):
        super().__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        """In this initialize function, the HF large model is loaded in
        8-bit precision and spread across the available devices via
        HuggingFace accelerate's device_map.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        logger.info("Model %s loading tokenizer", ctx.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="balanced",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True,
        )
        if ctx.model_yaml_config["handler"]["fast_kernels"]:
            from optimum.bettertransformer import BetterTransformer

            try:
                self.model = BetterTransformer.transform(self.model)
            except RuntimeError as error:
                logger.warning(
                    "HuggingFace Optimum does not support this model: %s. "
                    "For the list of supported models, please refer to "
                    "https://huggingface.co/docs/optimum/bettertransformer/overview",
                    error,
                )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        logger.info("Model %s loaded successfully", ctx.model_name)
        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
            containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
            attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        # Move both tensors to the same device as the model so that
        # generate() sees consistent placement.
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(
            self.model.device
        )
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=False,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text continuations for the received batch using the loaded
        transformers checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
            of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the predicted values for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        # The config value counts new tokens to generate, so pass it as
        # max_new_tokens rather than max_length.
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post-process function that converts the predicted response into a
        TorchServe-readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions.
        """
        return inference_output
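For reference, the handler above reads its settings from the handler section of the model's YAML config. The sketch below is illustrative rather than taken from this commit: the handler keys match exactly what initialize() reads, while the worker settings and the model name/path values are assumptions.

# model-config.yaml (hypothetical values)
minWorkers: 1
maxWorkers: 1
responseTimeout: 1200
handler:
    model_name: meta-llama/Llama-2-7b-hf
    model_path: model/llama-2-7b-hf
    max_length: 256
    max_new_tokens: 50
    manual_seed: 40
    fast_kernels: false

Once the model is registered, a minimal client call might look like the following; the endpoint name llama2 is an assumption, and TorchServe's inference API listens on port 8080 by default.

import requests

# Send a raw-text prediction request to a locally running TorchServe;
# the handler's preprocess() accepts the body as bytes and decodes it.
response = requests.post(
    "http://localhost:8080/predictions/llama2",
    data="what is the recipe of mayonnaise?".encode("utf-8"),
)
print(response.text)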