Skip to content

Commit

Permalink
ci
Browse files Browse the repository at this point in the history
  • Loading branch information
truskovskiyk committed Sep 7, 2024
1 parent e0962bb commit bdab9c3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
5 changes: 1 addition & 4 deletions module-5/serving/pytriton_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
import numpy as np
from pytriton.client import ModelClient


# https://triton-inference-server.github.io/pytriton/latest/clients/
def main():
sequence = np.array([
["one day I will see the world"],
["I would love to learn cook the Asian street food"],
["Carnival in Rio de Janeiro"],
["William Shakespeare was a great writer"],
])
sequence = np.char.encode(sequence, "utf-8")

Expand Down
25 changes: 9 additions & 16 deletions module-5/serving/pytriton_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

from serving.predictor import Predictor

logger = logging.getLogger("server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")

CLASSIFIER = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
predictor = Predictor.default_from_model_registry()

# Labels pre-cached on server side
LABELS = [
Expand All @@ -30,17 +31,13 @@
@batch
def _infer_fn(sequence: np.ndarray):
sequence = np.char.decode(sequence.astype("bytes"), "utf-8")
sequence = sequence.tolist()
sequence = sequence.tolist()[0]

logger.info(f"sequence = {sequence}")
results = predictor.predict(text=sequence)
logger.info(f"results = {results}")

classification_result = CLASSIFIER(sequence, LABELS)
result_labels = []
for result in classification_result:
logger.debug(result)
most_probable_label = result["labels"][0]
result_labels.append([most_probable_label])

result_labels = ['travel' for _ in range(len(sequence))]
return {"label": np.char.encode(result_labels, "utf-8")}


Expand All @@ -51,13 +48,9 @@ def main():
triton.bind(
model_name="BART",
infer_func=_infer_fn,
inputs=[
Tensor(name="sequence", dtype=bytes, shape=(1,)),
],
outputs=[
Tensor(name="label", dtype=bytes, shape=(1,)),
],
config=ModelConfig(max_batch_size=4),
inputs=[Tensor(name="sequence", dtype=bytes, shape=(-1,)),],
outputs=[Tensor(name="label", dtype=bytes, shape=(1,)),],
config=ModelConfig(max_batch_size=1),
)
logger.info("Serving inference")
triton.serve()
Expand Down

0 comments on commit bdab9c3

Please sign in to comment.