diff --git a/ebd-all-minilm/build-aarch64-apple-darwin.sh b/ebd-all-minilm/build-aarch64-apple-darwin.sh index 072befa..10ae98a 100755 --- a/ebd-all-minilm/build-aarch64-apple-darwin.sh +++ b/ebd-all-minilm/build-aarch64-apple-darwin.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export VERSION=1.0.4 +export VERSION=1.0.5 test -f venv/bin/activate || python -m venv venv source venv/bin/activate diff --git a/ebd-all-minilm/build.sh b/ebd-all-minilm/build.sh index c112287..4c84146 100755 --- a/ebd-all-minilm/build.sh +++ b/ebd-all-minilm/build.sh @@ -1,6 +1,6 @@ #!/bin/bash set -e -export VERSION=1.0.4 +export VERSION=1.0.5 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh" build_cpu ghcr.io/premai-io/embeddings-all-minilm-l6-v2-cpu all-MiniLM-L6-v2 ${@:1} diff --git a/ebd-all-minilm/main.py b/ebd-all-minilm/main.py index a624f51..71bbead 100644 --- a/ebd-all-minilm/main.py +++ b/ebd-all-minilm/main.py @@ -1,4 +1,6 @@ +import argparse import logging +import os import uvicorn from dotenv import load_dotenv @@ -9,6 +11,15 @@ load_dotenv() +MODEL_DIR = os.getenv("MODEL_ID", "all-MiniLM-L6-v2") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", help="Port to run model server on", type=int, default=8444) + parser.add_argument("--model-dir", help="Path to model dir", default=MODEL_DIR) + args = parser.parse_args() + MODEL_DIR = args.model_dir + logging.basicConfig( format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, @@ -18,7 +29,7 @@ def create_start_app_handler(app: FastAPI): def start_app() -> None: - SentenceTransformerBasedModel.get_model() + SentenceTransformerBasedModel.get_model(MODEL_DIR) return start_app @@ -41,4 +52,4 @@ def get_application() -> FastAPI: if __name__ == "__main__": - uvicorn.run("main:app", host="0.0.0.0", port=8000) + uvicorn.run("main:app", host="0.0.0.0", port=args.port) diff --git a/ebd-all-minilm/models.py b/ebd-all-minilm/models.py index 7ed2ed9..afade91 100644 --- a/ebd-all-minilm/models.py +++ b/ebd-all-minilm/models.py @@ -12,10 +12,7 @@ def embeddings(cls, texts): return values.tolist() @classmethod - def get_model(cls): + def get_model(cls, model_path): if cls.model is None: - cls.model = SentenceTransformer( - os.getenv("MODEL_ID", "all-MiniLM-L6-v2"), - device=os.getenv("DEVICE", "cpu"), - ) + cls.model = SentenceTransformer(model_path, device=os.getenv("DEVICE", "cpu")) return cls.model