diff --git a/llm.py b/llm.py
index 2b8c8e0..4e25637 100644
--- a/llm.py
+++ b/llm.py
@@ -7,7 +7,7 @@
 from langchain.llms.base import LLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-def initialize_model_and_tokenizer(model_name="stabilityai/stable-code-instruct-3b"):
+def initialize_model_and_tokenizer(model_name):
     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
     model.eval()
     #model.cuda() #uncomment for cuda
@@ -42,7 +42,7 @@ def _llm_type(self) -> str:
     llm_chain = LLMChain(prompt=prompt, llm=llm)
     return llm_chain, llm
 
-model, tokenizer = initialize_model_and_tokenizer()
+model, tokenizer = initialize_model_and_tokenizer("stabilityai/stable-code-instruct-3b")
 
 with gr.Blocks(fill_height=True) as demo:
     with gr.Row():
@@ -53,7 +53,7 @@ def _llm_type(self) -> str:
             repo_id = gr.Textbox(
                 label="Hugging Face Repo",
                 info="Default: stabilityai/stable-code-instruct-3b")
-            loadModelBtn = gr.Button(
+            load_model_btn = gr.Button(
                 value="Load Model",
                 variant="secondary",
                 interactive=True,)
@@ -73,9 +73,11 @@ def user(user_message, history):
         return "", history + [[user_message, None]]
 
     def loadModel(repo_id):
+        global llm_chain, llm
         if repo_id:
             model, tokenizer = initialize_model_and_tokenizer(repo_id)
             llm_chain, llm = init_chain(model, tokenizer)
+            return gr.update(value=repo_id)
         else:
             raise gr.Error("Repo can not be empty!")
 
@@ -91,7 +93,7 @@ def bot(history):
 
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
     stop.click(lambda: None, None, chatbot, queue=False)
-    loadModelBtn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")
+    load_model_btn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")
 
 demo.queue()
 demo.launch(server_name="0.0.0.0")
\ No newline at end of file