From bdab9c349e04f73d960e6b90f124dd69f5a5d048 Mon Sep 17 00:00:00 2001
From: truskovskiyk
Date: Sat, 7 Sep 2024 00:03:56 -0400
Subject: [PATCH] ci

---
 module-5/serving/pytriton_client.py  |  5 +----
 module-5/serving/pytriton_serving.py | 25 +++++++++----------------
 2 files changed, 10 insertions(+), 20 deletions(-)

diff --git a/module-5/serving/pytriton_client.py b/module-5/serving/pytriton_client.py
index 8784be6..c591836 100644
--- a/module-5/serving/pytriton_client.py
+++ b/module-5/serving/pytriton_client.py
@@ -2,13 +2,10 @@
 import numpy as np
 from pytriton.client import ModelClient
 
-
+# https://triton-inference-server.github.io/pytriton/latest/clients/
 def main():
     sequence = np.array([
         ["one day I will see the world"],
-        ["I would love to learn cook the Asian street food"],
-        ["Carnival in Rio de Janeiro"],
-        ["William Shakespeare was a great writer"],
     ])
     sequence = np.char.encode(sequence, "utf-8")
 
diff --git a/module-5/serving/pytriton_serving.py b/module-5/serving/pytriton_serving.py
index f962f46..f49b62f 100644
--- a/module-5/serving/pytriton_serving.py
+++ b/module-5/serving/pytriton_serving.py
@@ -7,11 +7,12 @@
 from pytriton.model_config import ModelConfig, Tensor
 from pytriton.triton import Triton
 
+from serving.predictor import Predictor
 
 logger = logging.getLogger("server")
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
 
-CLASSIFIER = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+predictor = Predictor.default_from_model_registry()
 
 # Labels pre-cached on server side
 LABELS = [
@@ -30,17 +31,13 @@
 
 @batch
 def _infer_fn(sequence: np.ndarray):
     sequence = np.char.decode(sequence.astype("bytes"), "utf-8")
-    sequence = sequence.tolist()
+    sequence = sequence.tolist()[0]
     logger.info(f"sequence = {sequence}")
 
+    results = predictor.predict(text=sequence)
+    logger.info(f"results = {results}")
 
-    classification_result = CLASSIFIER(sequence, LABELS)
-    result_labels = []
-    for result in classification_result:
-        logger.debug(result)
-        most_probable_label = result["labels"][0]
-        result_labels.append([most_probable_label])
-
+    result_labels = ['travel' for _ in range(len(sequence))]
     return {"label": np.char.encode(result_labels, "utf-8")}
 
@@ -51,13 +48,9 @@ def main():
         triton.bind(
             model_name="BART",
             infer_func=_infer_fn,
-            inputs=[
-                Tensor(name="sequence", dtype=bytes, shape=(1,)),
-            ],
-            outputs=[
-                Tensor(name="label", dtype=bytes, shape=(1,)),
-            ],
-            config=ModelConfig(max_batch_size=4),
+            inputs=[Tensor(name="sequence", dtype=bytes, shape=(-1,)),],
+            outputs=[Tensor(name="label", dtype=bytes, shape=(1,)),],
+            config=ModelConfig(max_batch_size=1),
         )
         logger.info("Serving inference")
         triton.serve()
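
Note (outside the patch above): a minimal client sketch, assuming the server from
module-5/serving/pytriton_serving.py is running locally on PyTriton's default HTTP
port. The model name "BART", the input name "sequence", and the output name "label"
come from the diff; the host string and the rest of this snippet are illustrative
only, not part of the change.

    import numpy as np
    from pytriton.client import ModelClient

    # Encode a single prompt the same way the trimmed pytriton_client.py does.
    sequence = np.array([["one day I will see the world"]])
    sequence = np.char.encode(sequence, "utf-8")

    # Send it to the "BART" model bound in pytriton_serving.py and print the label.
    with ModelClient("localhost", "BART") as client:
        result = client.infer_batch(sequence=sequence)
        print(result["label"])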