
Commit

Implement config.ini for enhanced persistence of settings (#19)
* Support for config.ini for persistent settings

* feat: Implement dynamic config.ini updates and save model information and config

* Remove duplicates and logging

---------

Co-authored-by: ashish <ashish@aesthisia.com>
Subhanshu0027 and ashish-aesthisia authored Apr 3, 2024
1 parent 9f76567 commit 59b8940
Showing 4 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
models/models--*
models/.locks
models/tmp*

configs/config.ini

#compiled files
*.pyc
4 changes: 4 additions & 0 deletions configs/config.ini
@@ -0,0 +1,4 @@
[Settings]
execution_provider =
repo_id =

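The file ships with both keys empty; the application fills them in on first launch. Once populated, configs/config.ini would look roughly like this (the values below are illustrative; they depend on the detected device and the model last selected):

[Settings]
execution_provider = cuda
repo_id = stabilityai/stable-code-instruct-3b
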
18 changes: 17 additions & 1 deletion core.py
@@ -1,6 +1,8 @@
import os, shutil
from configparser import ConfigParser

default_repo_id = "stabilityai/stable-code-instruct-3b"
config_path = "configs/config.ini"
default_repo_id_parts = default_repo_id.split("/")
default_model_folder = f"models--{'--'.join(default_repo_id_parts)}"

@@ -22,4 +24,18 @@ def remove_dir(path):
shutil.rmtree(model_path)
print("successfully removed cached models!")
except OSError as e:
print(f"Error: {e.strerror}")
print(f"Error: {e.strerror}")

def read_config():
config = ConfigParser()
config.read(config_path)
if config.get('Settings', 'repo_id') == "" and config.get('Settings', 'execution_provider') == "":
return None, config
else:
return config, config

def update_config(config, **kwargs):
for key, value in kwargs.items():
config.set('Settings', key, value)
with open(config_path, 'w') as configfile:
config.write(configfile)
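
Together, the two helpers form the persistence round trip: read_config() returns (None, config) while both settings are still empty, otherwise (config, config) with the saved values, and update_config() writes any keyword arguments into the [Settings] section and saves the file. A minimal usage sketch, assuming core.py is importable and configs/config.ini exists relative to the working directory:

from core import read_config, update_config, default_repo_id

state, config = read_config()
if state is None:
    # First run: both settings are empty, so persist the defaults.
    update_config(config, repo_id=default_repo_id, execution_provider="cpu")
else:
    # Later runs: restore the previously saved settings.
    repo_id = config.get('Settings', 'repo_id')
    device = config.get('Settings', 'execution_provider')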
17 changes: 16 additions & 1 deletion webui.py
@@ -7,13 +7,23 @@
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig

from core import list_download_models, format_model_name, remove_dir, default_repo_id
from core import list_download_models, remove_dir, default_repo_id, read_config, update_config

cache_dir = os.path.join(os.getcwd(), "models")
saved_models_list = list_download_models(cache_dir)

#check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state, config = read_config()
if state == None:
config.set('Settings', 'execution_provider', device)
config.set('Settings', 'repo_id', default_repo_id)

update_config(config)
else:
default_repo_id = config.get('Settings', 'repo_id')
device = config.get('Settings', 'execution_provider')


def initialize_model_and_tokenizer(model_name):
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
@@ -117,19 +127,24 @@ def removeModelCache():
def updateExecutionProvider(provider):
if provider == "cuda":
if torch.cuda.is_available():
device = "cuda"
model.cuda()
print("Model loaded in cuda", model)
else:
raise gr.Error("Torch not compiled with CUDA enabled. Please make sure cuda is installed.")

else:
device = "cpu"
model.cpu()

update_config(config, execution_provider=provider)

def loadModel(repo_id):
global llm_chain, llm
if repo_id:
model, tokenizer = initialize_model_and_tokenizer(repo_id)
llm_chain, llm = init_chain(model, tokenizer)
update_config(config, repo_id=repo_id)
return gr.update(value=repo_id)
else:
raise gr.Error("Repo can not be empty!")
