Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
misc tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Oct 23, 2023
1 parent 02dc713 commit 7719205
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 57 deletions.
9 changes: 5 additions & 4 deletions cht-petals/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import os

import uvicorn
from dotenv import load_dotenv
Expand All @@ -9,12 +10,12 @@

load_dotenv()

MODEL_PATH = "./models"
DHT_PREFIX = "StableBeluga2"
MODEL_PATH = os.getenv("MODEL_PATH", "./models")
DHT_PREFIX = os.getenv("DHT_PREFIX", "StableBeluga2")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", help="Path to Model files directory", default=MODEL_PATH)
parser.add_argument("--dht_prefix", help="DHT prefix to use")
parser.add_argument("--model-path", help="Path to Model files directory", default=MODEL_PATH)
parser.add_argument("--dht-prefix", help="DHT prefix to use")
parser.add_argument("--port", help="Port to run model server on", type=int)
args = parser.parse_args()
MODEL_PATH = args.model_path
Expand Down
6 changes: 3 additions & 3 deletions cht-petals/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def generate(
return [cls.tokenizer.decode(outputs[0])]

@classmethod
def get_model(cls, model_path: str, dht_prefix: str):
def get_model(cls, model_path: str = "./models", dht_prefix: str = "StableBeluga2"):
if cls.model is None:
Tokenizer = LlamaTokenizer if "llama" in os.getenv("MODEL_ID", model_path).lower() else AutoTokenizer
Tokenizer = LlamaTokenizer if "llama" in model_path.lower() else AutoTokenizer
cls.tokenizer = Tokenizer.from_pretrained(os.getenv("MODEL_ID", model_path))
cls.model = AutoDistributedModelForCausalLM.from_pretrained(
os.getenv("MODEL_ID", model_path),
torch_dtype=torch.float32,
dht_prefix=os.getenv("DHT_PREFIX", dht_prefix),
dht_prefix=dht_prefix,
)
return cls.model
61 changes: 11 additions & 50 deletions cht-petals/setup-petals.sh
Original file line number Diff line number Diff line change
@@ -1,54 +1,15 @@
#!/bin/bash
# ./setup-petals.sh --model_path ./models/models--petals-team--StableBeluga2 --dht_prefix StableBeluga2-hf --port 8794
# ./setup-petals.sh --model-path ./models/models--petals-team--StableBeluga2 --dht-prefix StableBeluga2-hf --port 8794

# Check if the required environment variables are set
if [ -z "$PREM_PYTHON" ]; then
echo "Please set the required PREM_PYTHON environment variable."
echo "Example: export PREM_PYTHON=appdir/envs/prem_env/bin/python"
exit 1
fi
tmpdir="$(mktemp)"

# Define the paths based on environment variables
python_exec="${PREM_PYTHON:-python}"
# clone source
git clone -n --depth=1 --filter=tree:0 https://github.com/premAI-io/prem-services.git "$tmpdir"
git -C "$tmpdir" sparse-checkout set --no-cone cht-petals
git -C "$tmpdir" checkout
# install deps
"${PREM_PYTHON:-python}" -m pip install -r "$tmpdir/cht-petals/requirements.txt"
# run server
PYTHONPATH="$tmpdir/cht-petals" "${PREM_PYTHON:-python}" "$tmpdir/cht-petals/main.py" "$@"

# Clone the GitHub repository if not already present
if [ ! -d "prem-services" ]; then
# only clone the required directory - https://stackoverflow.com/a/52269934
git clone -n --depth=1 --filter=tree:0 https://github.com/premAI-io/prem-services.git
git -C prem-services sparse-checkout set --no-cone cht-petals
git -C prem-services checkout
else
echo "Using the existing 'prem-services' directory."
fi

# Install requirements using the specified Python binary
"$python_exec" -m pip install -r prem-services/cht-petals/requirements.txt

# Check for the --model_path, --dht_prefix, and --port parameters and run main.py
while [[ $# -gt 0 ]]; do
case "$1" in
--model_path)
model_path="$2"
shift 2
;;
--dht_prefix)
dht_prefix="$2"
shift 2
;;
--port)
port="$2"
shift 2
;;
*)
shift
;;
esac
done

if [ -n "$model_path" ] && [ -n "$dht_prefix" ] && [ -n "$port" ]; then
export PYTHONPATH="$(pwd)/prem-services"
"$python_exec" prem-services/cht-petals/main.py --model_path "$model_path" --dht_prefix "$dht_prefix" --port $port
else
echo "Please provide the --model_path parameter with the path to the model directory, the --dht_prefix parameter for the DHT prefix, and the --port parameter for the port number."
exit 1
fi
rm -r "$tmpdir"

0 comments on commit 7719205

Please sign in to comment.