From 854c43291f4dba31d03ecdc54b6adf5632f85ae9 Mon Sep 17 00:00:00 2001
From: Parveen Kumar <89995648+parveen232@users.noreply.github.com>
Date: Mon, 1 Apr 2024 20:24:27 +0530
Subject: [PATCH] Fix: Remove Cached Models (#16)

* fix: Remove Cached Models

* refactor: default repo id, remove_dir function
---
 core.py  | 12 ++++++++++--
 webui.py | 13 +++++++------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/core.py b/core.py
index 4bd0c8a..a057dc7 100644
--- a/core.py
+++ b/core.py
@@ -1,5 +1,9 @@
 import os, shutil
 
+default_repo_id = "openai-community/gpt2"
+default_repo_id_parts = default_repo_id.split("/")
+default_model_folder = f"models--{'--'.join(default_repo_id_parts)}"
+
 def format_model_name(directory_name):
     parts = directory_name.split("--")
     return "/".join(parts[1:])
@@ -11,7 +15,11 @@ def list_download_models(cache_dir):
 
 def remove_dir(path):
     try:
-        shutil.rmtree(os.path.join(path, "/*"))
-        print(f"Directory '{path}' successfully removed.")
+        for model in os.listdir(path):
+            if model != default_model_folder:
+                model_path = os.path.join(path, model)
+                if os.path.isdir(model_path):
+                    shutil.rmtree(model_path)
+        print("successfully removed cached models!")
     except OSError as e:
         print(f"Error: {e.strerror}")
\ No newline at end of file
diff --git a/webui.py b/webui.py
index ce23799..d44816f 100644
--- a/webui.py
+++ b/webui.py
@@ -7,10 +7,10 @@
 from langchain.llms.base import LLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig
 
-from core import list_download_models, format_model_name, remove_dir
+from core import list_download_models, format_model_name, remove_dir, default_repo_id
 
 cache_dir = os.path.join(os.getcwd(), "models")
-saved_models = list_download_models(cache_dir)
+saved_models_list = list_download_models(cache_dir)
 
 #check if cuda is available
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -59,7 +59,7 @@ def _llm_type(self) -> str:
     llm_chain = LLMChain(prompt=prompt, llm=llm)
     return llm_chain, llm
 
-model, tokenizer = initialize_model_and_tokenizer("openai-community/gpt2")
+model, tokenizer = initialize_model_and_tokenizer(default_repo_id)
 
 with gr.Blocks(fill_height=True) as demo:
     with gr.Row():
@@ -71,7 +71,7 @@ def _llm_type(self) -> str:
                     interactive=True)
                 with gr.Group():
                     repo_id = gr.Textbox(
-                        value="openai-community/gpt2",
+                        value=default_repo_id,
                         label="Hugging Face Repo",
                         info="Default: openai-community/gpt2")
                     load_model_btn = gr.Button(
@@ -88,7 +88,7 @@ def _llm_type(self) -> str:
 
                 with gr.Group():
                     saved_models = gr.Dropdown(
-                        choices=saved_models,
+                        choices=saved_models_list,
                         max_choices=5,
                         filterable=True,
                         label="Saved Models",
@@ -112,6 +112,7 @@ def user(user_message, history):
 
     def removeModelCache():
         remove_dir(cache_dir)
+        return gr.update(value=default_repo_id), gr.update(choices=[default_repo_id])
 
     def updateExecutionProvider(provider):
         if provider == "cuda":
@@ -147,7 +148,7 @@ def bot(history):
     load_model_btn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")
     execution_provider.change(fn=updateExecutionProvider, inputs=execution_provider, queue=False, show_progress="full")
     saved_models.change(loadModel, saved_models, repo_id, queue=False, show_progress="full")
-    offload_models.click(removeModelCache, None, saved_models, queue=False, show_progress="full")
+    offload_models.click(removeModelCache, None, [repo_id, saved_models], queue=False, show_progress="full")
 
 demo.queue()
 demo.launch(server_name="0.0.0.0")
\ No newline at end of file
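
Note on the core.py change: the removed call `shutil.rmtree(os.path.join(path, "/*"))` appears to have been the bug, since `os.path.join` with an absolute second component returns `"/*"` and `rmtree` does not expand glob patterns, so nothing under the cache was actually deleted. The new loop instead walks the cache directory and removes every `models--{org}--{name}` folder except the one belonging to the default repo. Below is a minimal, self-contained sketch of that logic under the same assumptions; the helper name `remove_cached_models` and the existence check on `cache_dir` are added here for illustration and are not part of the patch.

```python
import os
import shutil

# Mirrors default_repo_id / default_model_folder from core.py:
# huggingface_hub caches a repo "org/name" under a folder named "models--org--name".
default_repo_id = "openai-community/gpt2"
default_model_folder = f"models--{default_repo_id.replace('/', '--')}"

def remove_cached_models(cache_dir):
    """Delete every cached model folder except the default one."""
    if not os.path.isdir(cache_dir):
        return  # guard added for this sketch; core.py relies on try/except instead
    for entry in os.listdir(cache_dir):
        if entry == default_model_folder:
            continue  # keep the default model so the UI can still load it
        entry_path = os.path.join(cache_dir, entry)
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)

# Usage mirrors webui.py, which points the cache at ./models:
remove_cached_models(os.path.join(os.getcwd(), "models"))
```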
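Note on the webui.py change: `removeModelCache` now returns two `gr.update(...)` values, and `offload_models.click(...)` registers `[repo_id, saved_models]` as its outputs, so the repo textbox and the saved-models dropdown fall back to the default repo after the cache is cleared. A minimal sketch of that Gradio wiring follows; the button label, handler name, and the omission of the actual cache-removal call are illustration choices, not part of the patch.

```python
import gradio as gr

DEFAULT_REPO = "openai-community/gpt2"  # mirrors default_repo_id in core.py

def reset_model_widgets():
    # Return one gr.update(...) per output component wired in .click() below.
    return gr.update(value=DEFAULT_REPO), gr.update(choices=[DEFAULT_REPO])

with gr.Blocks() as demo:
    repo_id = gr.Textbox(value=DEFAULT_REPO, label="Hugging Face Repo")
    saved_models = gr.Dropdown(choices=[DEFAULT_REPO], label="Saved Models")
    offload_btn = gr.Button("Remove cached models")
    # The order of the outputs list must match the order of the returned updates.
    offload_btn.click(reset_model_widgets, None, [repo_id, saved_models])

demo.launch()
```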