feat: dynamic model (#2)
parveen232 authored Mar 29, 2024
1 parent 9dd4829 commit 3c87287
Showing 1 changed file with 6 additions and 4 deletions.
llm.py: 10 changes (6 additions, 4 deletions)
@@ -7,7 +7,7 @@
 from langchain.llms.base import LLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

-def initialize_model_and_tokenizer(model_name="stabilityai/stable-code-instruct-3b"):
+def initialize_model_and_tokenizer(model_name):
     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
     model.eval()
     #model.cuda() #uncomment for cuda
@@ -42,7 +42,7 @@ def _llm_type(self) -> str:
     llm_chain = LLMChain(prompt=prompt, llm=llm)
     return llm_chain, llm

-model, tokenizer = initialize_model_and_tokenizer()
+model, tokenizer = initialize_model_and_tokenizer("stabilityai/stable-code-instruct-3b")

 with gr.Blocks(fill_height=True) as demo:
     with gr.Row():
@@ -53,7 +53,7 @@ def _llm_type(self) -> str:
             repo_id = gr.Textbox(
                 label="Hugging Face Repo",
                 info="Default: stabilityai/stable-code-instruct-3b")
-            loadModelBtn = gr.Button(
+            load_model_btn = gr.Button(
                 value="Load Model",
                 variant="secondary",
                 interactive=True,)
@@ -73,9 +73,11 @@ def user(user_message, history):
         return "", history + [[user_message, None]]

     def loadModel(repo_id):
+        global llm_chain, llm
         if repo_id:
+            model, tokenizer = initialize_model_and_tokenizer(repo_id)
             llm_chain, llm = init_chain(model, tokenizer)
             return gr.update(value=repo_id)
         else:
             raise gr.Error("Repo can not be empty!")

@@ -91,7 +93,7 @@ def bot(history):

     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
     stop.click(lambda: None, None, chatbot, queue=False)
-    loadModelBtn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")
+    load_model_btn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")

 demo.queue()
 demo.launch(server_name="0.0.0.0")
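
Taken together, the diff makes the model a runtime choice: initialize_model_and_tokenizer no longer carries a hard-coded default, the initial model is passed explicitly, and loadModel rebuilds the model and chain for whatever repo id the user types before clicking Load Model. The following is a minimal standalone sketch of that pattern, not the repository's full llm.py: it omits the LangChain chain and the chat UI, and the AutoTokenizer call is an assumption, since tokenizer loading is not visible in these hunks.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

def initialize_model_and_tokenizer(model_name):
    # As in the commit: the caller supplies the repo id, nothing is hard-coded.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumed; not shown in the hunks
    return model, tokenizer

model, tokenizer = initialize_model_and_tokenizer("stabilityai/stable-code-instruct-3b")

def load_model(repo_id):
    # Mirrors loadModel in the diff: rebind the module-level objects so that
    # callbacks defined earlier pick up the newly chosen model.
    global model, tokenizer
    if not repo_id:
        raise gr.Error("Repo can not be empty!")
    model, tokenizer = initialize_model_and_tokenizer(repo_id)
    return gr.update(value=repo_id)

with gr.Blocks() as demo:
    repo_id = gr.Textbox(label="Hugging Face Repo",
                         info="Default: stabilityai/stable-code-instruct-3b")
    load_model_btn = gr.Button(value="Load Model")
    load_model_btn.click(load_model, repo_id, repo_id, show_progress="full")

demo.launch()

Clicking Load Model rebinds the module-level model and tokenizer (the sketch's counterpart to the global llm_chain, llm rebinding added in the diff), so callbacks wired up earlier use the newly selected model on their next invocation.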
