Fix: Remove Cached Models (#16)
* fix: Remove Cached Models

* refactor: default repo id, remove_dir function
parveen232 authored Apr 1, 2024
1 parent 83beb27 commit 854c432
Showing 2 changed files with 17 additions and 8 deletions.
12 changes: 10 additions & 2 deletions core.py
@@ -1,5 +1,9 @@
 import os, shutil
 
+default_repo_id = "openai-community/gpt2"
+default_repo_id_parts = default_repo_id.split("/")
+default_model_folder = f"models--{'--'.join(default_repo_id_parts)}"
+
 def format_model_name(directory_name):
     parts = directory_name.split("--")
     return "/".join(parts[1:])
@@ -11,7 +15,11 @@ def list_download_models(cache_dir):
 
 def remove_dir(path):
     try:
-        shutil.rmtree(os.path.join(path, "/*"))
-        print(f"Directory '{path}' successfully removed.")
+        for model in os.listdir(path):
+            if model != default_model_folder:
+                model_path = os.path.join(path, model)
+                if os.path.isdir(model_path):
+                    shutil.rmtree(model_path)
+        print("successfully removed cached models!")
     except OSError as e:
         print(f"Error: {e.strerror}")
13 changes: 7 additions & 6 deletions webui.py
@@ -7,10 +7,10 @@
 from langchain.llms.base import LLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig
 
-from core import list_download_models, format_model_name, remove_dir
+from core import list_download_models, format_model_name, remove_dir, default_repo_id
 
 cache_dir = os.path.join(os.getcwd(), "models")
-saved_models = list_download_models(cache_dir)
+saved_models_list = list_download_models(cache_dir)
 
 #check if cuda is available
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -59,7 +59,7 @@ def _llm_type(self) -> str:
     llm_chain = LLMChain(prompt=prompt, llm=llm)
     return llm_chain, llm
 
-model, tokenizer = initialize_model_and_tokenizer("openai-community/gpt2")
+model, tokenizer = initialize_model_and_tokenizer(default_repo_id)
 
 with gr.Blocks(fill_height=True) as demo:
     with gr.Row():
@@ -71,7 +71,7 @@ def _llm_type(self) -> str:
                 interactive=True)
             with gr.Group():
                 repo_id = gr.Textbox(
-                    value="openai-community/gpt2",
+                    value=default_repo_id,
                     label="Hugging Face Repo",
                     info="Default: openai-community/gpt2")
                 load_model_btn = gr.Button(
@@ -88,7 +88,7 @@ def _llm_type(self) -> str:
 
             with gr.Group():
                 saved_models = gr.Dropdown(
-                    choices=saved_models,
+                    choices=saved_models_list,
                     max_choices=5,
                     filterable=True,
                     label="Saved Models",
@@ -112,6 +112,7 @@ def user(user_message, history):
 
     def removeModelCache():
        remove_dir(cache_dir)
+       return gr.update(value=default_repo_id), gr.update(choices=[default_repo_id])
 
    def updateExecutionProvider(provider):
        if provider == "cuda":
@@ -147,7 +148,7 @@ def bot(history):
    load_model_btn.click(loadModel, repo_id, repo_id, queue=False, show_progress="full")
    execution_provider.change(fn=updateExecutionProvider, inputs=execution_provider, queue=False, show_progress="full")
    saved_models.change(loadModel, saved_models, repo_id, queue=False, show_progress="full")
-   offload_models.click(removeModelCache, None, saved_models, queue=False, show_progress="full")
+   offload_models.click(removeModelCache, None, [repo_id, saved_models], queue=False, show_progress="full")
 
 demo.queue()
 demo.launch(server_name="0.0.0.0")
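
Two changes work together here: the module-level list of cached models is renamed to saved_models_list, presumably so it is no longer shadowed by the saved_models Dropdown component defined later, and removeModelCache now returns one gr.update(...) value per output component, with offload_models.click listing both components as outputs. Clearing the cache therefore also resets the repo textbox and the dropdown to the default model. A minimal sketch of the same Gradio pattern, with hypothetical labels and choices:

import gradio as gr

default_repo_id = "openai-community/gpt2"

with gr.Blocks() as demo:
    repo_id = gr.Textbox(value=default_repo_id, label="Hugging Face Repo")
    saved_models = gr.Dropdown(choices=[default_repo_id, "org/other-model"],
                               label="Saved Models")
    offload_models = gr.Button("Remove cached models")

    def remove_model_cache():
        # One return value per output component, in the same order
        # as the outputs list passed to click() below.
        return gr.update(value=default_repo_id), gr.update(choices=[default_repo_id])

    offload_models.click(remove_model_cache, None, [repo_id, saved_models], queue=False)

demo.launch()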
