diff --git a/generate.py b/generate.py
index 5ecf356f0..9d7e5437a 100644
--- a/generate.py
+++ b/generate.py
@@ -33,7 +33,6 @@
 import fire
 import torch
-from peft import PeftModel
 from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
 from accelerate import init_empty_weights, infer_auto_device_map
@@ -710,6 +709,7 @@ def get_model(
                 base_model,
                 **model_kwargs
             )
+            from peft import PeftModel  # loads cuda, so avoid in global scope
             model = PeftModel.from_pretrained(
                 model,
                 lora_weights,
@@ -727,6 +727,7 @@ def get_model(
                 base_model,
                 **model_kwargs
             )
+            from peft import PeftModel  # loads cuda, so avoid in global scope
             model = PeftModel.from_pretrained(
                 model,
                 lora_weights,
@@ -827,24 +828,27 @@ def get_score_model(score_model: str = None,
                     'iinput_nochat',
                     ]
 
+gen_hyper = ['temperature',
+             'top_p',
+             'top_k',
+             'num_beams',
+             'max_new_tokens',
+             'min_new_tokens',
+             'early_stopping',
+             'max_time',
+             'repetition_penalty',
+             'num_return_sequences',
+             'do_sample',
+             ]
+
 eval_func_param_names = ['instruction',
                          'iinput',
                          'context',
                          'stream_output',
                          'prompt_type',
-                         'prompt_dict',
-                         'temperature',
-                         'top_p',
-                         'top_k',
-                         'num_beams',
-                         'max_new_tokens',
-                         'min_new_tokens',
-                         'early_stopping',
-                         'max_time',
-                         'repetition_penalty',
-                         'num_return_sequences',
-                         'do_sample',
-                         'chat',
+                         'prompt_dict'] + \
+                        gen_hyper + \
+                        ['chat',
                          'instruction_nochat',
                          'iinput_nochat',
                          'langchain_mode',
@@ -1086,7 +1090,6 @@ def evaluate(
                 db=db1,
                 user_path=user_path,
                 detect_user_path_changes_every_query=detect_user_path_changes_every_query,
-                max_new_tokens=max_new_tokens,
                 cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64,  # FIXME, too arbitrary
                 use_openai_embedding=use_openai_embedding,
                 use_openai_model=use_openai_model,
@@ -1099,10 +1102,20 @@ def evaluate(
                 document_choice=document_choice,
                 db_type=db_type,
                 top_k_docs=top_k_docs,
+
+                # gen_hyper:
+                do_sample=do_sample,
                 temperature=temperature,
                 repetition_penalty=repetition_penalty,
                 top_k=top_k,
                 top_p=top_p,
+                num_beams=num_beams,
+                min_new_tokens=min_new_tokens,
+                max_new_tokens=max_new_tokens,
+                early_stopping=early_stopping,
+                max_time=max_time,
+                num_return_sequences=num_return_sequences,
+
                 prompt_type=prompt_type,
                 prompt_dict=prompt_dict,
                 n_jobs=n_jobs,
diff --git a/gpt_langchain.py b/gpt_langchain.py
index c142df4a5..facc635cd 100644
--- a/gpt_langchain.py
+++ b/gpt_langchain.py
@@ -22,6 +22,7 @@
 from tqdm import tqdm
 
 from enums import DocumentChoices
+from generate import gen_hyper
 from prompter import non_hf_types, PromptType
 from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
     get_device, ProgressParallel, remove, hash_file, clear_torch_cache
@@ -261,11 +262,17 @@ def get_answer_from_sources(chain, sources, question):
 
 def get_llm(use_openai_model=False, model_name=None, model=None,
             tokenizer=None, stream_output=False,
-            max_new_tokens=256,
+            do_sample=False,
             temperature=0.1,
-            repetition_penalty=1.0,
             top_k=40,
             top_p=0.7,
+            num_beams=1,
+            max_new_tokens=256,
+            min_new_tokens=1,
+            early_stopping=False,
+            max_time=180,
+            repetition_penalty=1.0,
+            num_return_sequences=1,
             prompt_type=None,
             prompt_dict=None,
             prompter=None,
@@ -312,10 +319,20 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
                                                      load_in_8bit=load_8bit)
 
         max_max_tokens = tokenizer.model_max_length
-        gen_kwargs = dict(max_new_tokens=max_new_tokens,
+        gen_kwargs = dict(do_sample=do_sample,
+                          temperature=temperature,
+                          top_k=top_k,
+                          top_p=top_p,
+                          num_beams=num_beams,
+                          max_new_tokens=max_new_tokens,
+                          min_new_tokens=min_new_tokens,
+                          early_stopping=early_stopping,
+                          max_time=max_time,
+                          repetition_penalty=repetition_penalty,
+                          num_return_sequences=num_return_sequences,
                           return_full_text=True,
-                          early_stopping=False,
                           handle_long_generation='hole')
+        assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
         if stream_output:
             skip_prompt = False
@@ -1235,11 +1252,17 @@ def _run_qa_db(query=None,
               show_rank=False,
               load_db_if_exists=False,
               db=None,
-              max_new_tokens=256,
+              do_sample=False,
               temperature=0.1,
-              repetition_penalty=1.0,
               top_k=40,
               top_p=0.7,
+              num_beams=1,
+              max_new_tokens=256,
+              min_new_tokens=1,
+              early_stopping=False,
+              max_time=180,
+              repetition_penalty=1.0,
+              num_return_sequences=1,
               langchain_mode=None,
               document_choice=[DocumentChoices.All_Relevant.name],
               n_jobs=-1,
@@ -1274,14 +1297,21 @@ def _run_qa_db(query=None,
             assert prompt_dict is not None  # should at least be {} or ''
     else:
         prompt_dict = ''
+    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
     llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model,
                                                          model_name=model_name,
                                                          model=model,
                                                          tokenizer=tokenizer,
                                                          stream_output=stream_output,
-                                                         max_new_tokens=max_new_tokens,
+                                                         do_sample=do_sample,
                                                          temperature=temperature,
-                                                         repetition_penalty=repetition_penalty,
                                                          top_k=top_k,
                                                          top_p=top_p,
+                                                         num_beams=num_beams,
+                                                         max_new_tokens=max_new_tokens,
+                                                         min_new_tokens=min_new_tokens,
+                                                         early_stopping=early_stopping,
+                                                         max_time=max_time,
+                                                         repetition_penalty=repetition_penalty,
+                                                         num_return_sequences=num_return_sequences,
                                                          prompt_type=prompt_type,
                                                          prompt_dict=prompt_dict,
                                                          prompter=prompter,
@@ -1609,6 +1639,7 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
         assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
         assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
 
+
 def _create_local_weaviate_client():
     WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
     WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
@@ -1629,5 +1660,6 @@ def _create_local_weaviate_client():
         print(f"Failed to create Weaviate client: {e}")
         return None
 
+
 if __name__ == '__main__':
     pass
diff --git a/utils.py b/utils.py
index 22e5e2186..f012acafe 100644
--- a/utils.py
+++ b/utils.py
@@ -14,7 +14,6 @@
 import traceback
 import zipfile
 from datetime import datetime
-from enum import Enum
 
 import filelock
 import requests, uuid
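Note on the PeftModel change in generate.py above: the import is moved from module scope into get_model() because importing peft initializes CUDA. A rough standalone illustration of that deferred-import pattern follows; the load_lora() helper is hypothetical and not part of the diff.

def load_lora(model, lora_weights, **lora_kwargs):
    # peft pulls in CUDA at import time, so import it only in the code path
    # that actually applies LoRA weights (mirrors the comment in the diff)
    from peft import PeftModel
    # wrap the already-loaded base model with the LoRA adapter weights
    return PeftModel.from_pretrained(model, lora_weights, **lora_kwargs)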
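Note on the gen_hyper change above: the generation hyperparameters are collected into a single shared list so that eval_func_param_names, the kwargs forwarded through evaluate() and _run_qa_db(), and the gen_kwargs built in get_llm() cannot silently drift apart; the two added asserts enforce that. A minimal self-contained sketch of the pattern follows; build_gen_kwargs() is a hypothetical helper and the defaults simply mirror get_llm().

gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams',
             'max_new_tokens', 'min_new_tokens', 'early_stopping',
             'max_time', 'repetition_penalty', 'num_return_sequences',
             'do_sample']


def build_gen_kwargs(**overrides):
    # defaults mirror the keyword defaults in get_llm()
    gen_kwargs = dict(do_sample=False, temperature=0.1, top_k=40, top_p=0.7,
                      num_beams=1, max_new_tokens=256, min_new_tokens=1,
                      early_stopping=False, max_time=180,
                      repetition_penalty=1.0, num_return_sequences=1)
    gen_kwargs.update(overrides)
    # same consistency check as added in gpt_langchain.py: every name in
    # gen_hyper must be covered by gen_kwargs
    assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
    return gen_kwargs


if __name__ == '__main__':
    print(build_gen_kwargs(temperature=0.3, max_new_tokens=512))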