Ensure generate hyperparameters are passed through to h2oai_pipeline.py for generation #265

Merged · 2 commits · Jun 9, 2023 · Changes from all commits
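In short, this PR replaces ad-hoc forwarding of individual generation settings with a single canonical list, `gen_hyper`, defined in generate.py. That list is spliced into `eval_func_param_names`, and every layer that builds or forwards generation kwargs (`evaluate`, `_run_qa_db`, `get_llm`) now passes the full set, with asserts that fail fast if the list and a call site ever drift apart. In the diff below, `+`/`-` markers have been restored from each hunk's `@@` line counts.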
generate.py: 43 changes (28 additions, 15 deletions)
@@ -33,7 +33,6 @@

import fire
import torch
- from peft import PeftModel
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
from accelerate import init_empty_weights, infer_auto_device_map

@@ -710,6 +709,7 @@ def get_model(
base_model,
**model_kwargs
)
+ from peft import PeftModel # loads cuda, so avoid in global scope
model = PeftModel.from_pretrained(
model,
lora_weights,
@@ -727,6 +727,7 @@ def get_model(
base_model,
**model_kwargs
)
+ from peft import PeftModel # loads cuda, so avoid in global scope
model = PeftModel.from_pretrained(
model,
lora_weights,
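The two hunks above move the `peft` import from module scope to the point of use. A minimal sketch of this deferred-import pattern (the wrapper function here is hypothetical; the real code inlines the import inside `get_model`):

```python
def load_lora(model, lora_weights, **kwargs):
    # Importing peft triggers CUDA initialization, so defer it until a code
    # path actually needs LoRA weights instead of paying that cost in every
    # process that merely imports this module.
    from peft import PeftModel
    return PeftModel.from_pretrained(model, lora_weights, **kwargs)
```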
@@ -827,24 +828,27 @@ def get_score_model(score_model: str = None,
'iinput_nochat',
]

+ gen_hyper = ['temperature',
+ 'top_p',
+ 'top_k',
+ 'num_beams',
+ 'max_new_tokens',
+ 'min_new_tokens',
+ 'early_stopping',
+ 'max_time',
+ 'repetition_penalty',
+ 'num_return_sequences',
+ 'do_sample',
+ ]
+
eval_func_param_names = ['instruction',
'iinput',
'context',
'stream_output',
'prompt_type',
- 'prompt_dict',
- 'temperature',
- 'top_p',
- 'top_k',
- 'num_beams',
- 'max_new_tokens',
- 'min_new_tokens',
- 'early_stopping',
- 'max_time',
- 'repetition_penalty',
- 'num_return_sequences',
- 'do_sample',
- 'chat',
+ 'prompt_dict'] + \
+ gen_hyper + \
+ ['chat',
'instruction_nochat',
'iinput_nochat',
'langchain_mode',
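Splicing `gen_hyper` into `eval_func_param_names` means the UI's positional arguments and the generation kwargs are derived from one list and cannot get out of order. A small self-contained illustration with hypothetical values (the real list continues beyond `'langchain_mode'`):

```python
gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams',
             'max_new_tokens', 'min_new_tokens', 'early_stopping',
             'max_time', 'repetition_penalty', 'num_return_sequences',
             'do_sample']

eval_func_param_names = (['instruction', 'iinput', 'context', 'stream_output',
                          'prompt_type', 'prompt_dict']
                         + gen_hyper
                         + ['chat', 'instruction_nochat', 'iinput_nochat',
                            'langchain_mode'])

# Positional args from the UI can be zipped back into named kwargs:
args = ['Summarize this.', '', '', False, 'plain', '',
        0.1, 0.7, 40, 1, 256, 1, False, 180, 1.0, 1, False,
        True, '', '', 'Disabled']
kwargs = dict(zip(eval_func_param_names, args))
assert kwargs['max_new_tokens'] == 256 and kwargs['do_sample'] is False
```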
@@ -1086,7 +1090,6 @@ def evaluate(
db=db1,
user_path=user_path,
detect_user_path_changes_every_query=detect_user_path_changes_every_query,
- max_new_tokens=max_new_tokens,
cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
use_openai_embedding=use_openai_embedding,
use_openai_model=use_openai_model,
@@ -1099,10 +1102,20 @@ def evaluate(
document_choice=document_choice,
db_type=db_type,
top_k_docs=top_k_docs,
+
+ # gen_hyper:
+ do_sample=do_sample,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
+ num_beams=num_beams,
+ min_new_tokens=min_new_tokens,
+ max_new_tokens=max_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ num_return_sequences=num_return_sequences,
+
prompt_type=prompt_type,
prompt_dict=prompt_dict,
n_jobs=n_jobs,
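With `evaluate` now forwarding the full hyperparameter set into the langchain path, the innermost generation step can consume them directly. A hypothetical sketch of that final hop, assuming standard `transformers` `GenerationConfig` fields (the real consumer is h2oai_pipeline.py, which this diff does not show):

```python
from transformers import GenerationConfig

gen_hyper = ['temperature', 'top_p', 'top_k', 'num_beams',
             'max_new_tokens', 'min_new_tokens', 'early_stopping',
             'max_time', 'repetition_penalty', 'num_return_sequences',
             'do_sample']

def make_generation_config(**gen_kwargs):
    # Keep only canonical generation hyperparameters; pipeline-level options
    # such as return_full_text are filtered out here.
    return GenerationConfig(**{k: v for k, v in gen_kwargs.items() if k in gen_hyper})

cfg = make_generation_config(do_sample=True, temperature=0.3, top_p=0.9,
                             max_new_tokens=128, return_full_text=True)
```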
gpt_langchain.py: 48 changes (40 additions, 8 deletions)
@@ -22,6 +22,7 @@
from tqdm import tqdm

from enums import DocumentChoices
+ from generate import gen_hyper
from prompter import non_hf_types, PromptType
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
get_device, ProgressParallel, remove, hash_file, clear_torch_cache
@@ -261,11 +262,17 @@ def get_answer_from_sources(chain, sources, question):

def get_llm(use_openai_model=False, model_name=None, model=None,
tokenizer=None, stream_output=False,
- max_new_tokens=256,
+ do_sample=False,
temperature=0.1,
- repetition_penalty=1.0,
top_k=40,
top_p=0.7,
+ num_beams=1,
+ max_new_tokens=256,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
prompt_type=None,
prompt_dict=None,
prompter=None,
@@ -312,10 +319,20 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
load_in_8bit=load_8bit)

max_max_tokens = tokenizer.model_max_length
- gen_kwargs = dict(max_new_tokens=max_new_tokens,
+ gen_kwargs = dict(do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
return_full_text=True,
- early_stopping=False,
handle_long_generation='hole')
+ assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0

if stream_output:
skip_prompt = False
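The added assert is the drift guard: `set(gen_hyper).difference(gen_kwargs.keys())` is empty only when every canonical hyperparameter made it into `gen_kwargs`. A reduced example of what it catches:

```python
gen_hyper = ['temperature', 'top_p', 'do_sample']   # abbreviated for the example
gen_kwargs = dict(temperature=0.1, top_p=0.7)       # 'do_sample' was dropped
missing = set(gen_hyper).difference(gen_kwargs.keys())
assert len(missing) == 0, f"generation hyperparameters not forwarded: {missing}"
# AssertionError: generation hyperparameters not forwarded: {'do_sample'}
```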
@@ -1235,11 +1252,17 @@ def _run_qa_db(query=None,
show_rank=False,
load_db_if_exists=False,
db=None,
- max_new_tokens=256,
+ do_sample=False,
temperature=0.1,
- repetition_penalty=1.0,
top_k=40,
top_p=0.7,
+ num_beams=1,
+ max_new_tokens=256,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
langchain_mode=None,
document_choice=[DocumentChoices.All_Relevant.name],
n_jobs=-1,
@@ -1274,14 +1297,21 @@ def _run_qa_db(query=None,
assert prompt_dict is not None # should at least be {} or ''
else:
prompt_dict = ''
+ assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
model=model, tokenizer=tokenizer,
stream_output=stream_output,
- max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
temperature=temperature,
- repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
prompter=prompter,
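The companion assert before this call validates against `get_llm`'s signature rather than a dict, so a renamed or missing keyword parameter fails at the call site instead of deep inside generation. `inspect.signature(...).parameters` maps parameter names to `Parameter` objects, which makes the same set-difference check work (abbreviated sketch):

```python
import inspect

def get_llm(temperature=0.1, top_p=0.7, do_sample=False, **other):
    ...

gen_hyper = ['temperature', 'top_p', 'do_sample']   # abbreviated
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
```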
@@ -1609,6 +1639,7 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected

+
def _create_local_weaviate_client():
WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
@@ -1629,5 +1660,6 @@ def _create_local_weaviate_client():
print(f"Failed to create Weaviate client: {e}")
return None

+
if __name__ == '__main__':
pass
utils.py: 1 change (0 additions, 1 deletion)
@@ -14,7 +14,6 @@
import traceback
import zipfile
from datetime import datetime
- from enum import Enum

import filelock
import requests, uuid