
Commit 0217dd2
init
goliaro committed Sep 26, 2024
1 parent 70e47b2 commit 0217dd2
Showing 2 changed files with 84 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tests/inference/huggingface_inference_simple.py
@@ -0,0 +1,51 @@
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
GenerationConfig,
)

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
do_sample = False
max_length = 128

model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, device_map="auto"
)
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Force deterministic greedy decoding: no sampling, no beam search, and the
# sampling-only knobs cleared so generate() does not warn about unused settings.
generation_config = GenerationConfig.from_pretrained(model_name)
print(generation_config.do_sample)
generation_config.do_sample = do_sample
generation_config.num_beams = 1
generation_config.temperature = None
generation_config.top_p = None


def run_text_completion():
    prompt = "Help me plan a 1-week trip to Dubai"
    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    generated = model.generate(
        batch["input_ids"],
        max_new_tokens=max_length,
        generation_config=generation_config,
    )
    out = tokenizer.decode(generated[0])
    print(out)

def run_chat_completion():
    messages = [
        {"role": "system", "content": "You are a helpful and honest programming assistant."},
        {"role": "user", "content": "Is Rust better than Python?"},
    ]
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    batch = tokenizer(tokenized_chat, return_tensors="pt")

    generated = model.generate(
        batch["input_ids"],
        max_new_tokens=max_length,
        generation_config=generation_config,
    )
    out = tokenizer.decode(
        generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # Strip the echoed prompt so only the newly generated completion is printed.
    prompt_length = len(
        tokenizer.decode(
            batch["input_ids"][0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    )
    completion = out[prompt_length:]
    print(completion)


run_chat_completion()
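
For comparison, a minimal, untested sketch (not part of this commit) of the same chat completion using apply_chat_template's direct tokenization, which skips the intermediate prompt string and the decoded-length arithmetic; the model name and greedy settings here mirror the script above:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful and honest programming assistant."},
    {"role": "user", "content": "Is Rust better than Python?"},
]
# tokenize=True (the default) with return_tensors="pt" yields input IDs directly.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
generated = model.generate(input_ids, max_new_tokens=128, do_sample=False)
# Slice off the prompt tokens instead of comparing decoded string lengths.
completion = tokenizer.decode(generated[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(completion)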
33 changes: 33 additions & 0 deletions tests/inference/huggingface_pipeline.py
@@ -0,0 +1,33 @@
import transformers
from transformers import GenerationConfig

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
do_sample = False

generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.do_sample = do_sample
generation_config.num_beams = 1
# generation_config.max_length = 128
generation_config.temperature = None
generation_config.top_p = None
print(generation_config)

pipeline = transformers.pipeline(
"text-generation",
model=model_id,
# model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful and honest programming assistant."},
    {"role": "user", "content": "Is Rust better than Python?"},
]

# messages="Help me plan a 1-week trip to Dubai"
outputs = pipeline(
    messages,
    max_new_tokens=128,
    generation_config=generation_config,
)
# With chat messages as input, generated_text holds the full message list;
# the last entry is the assistant's reply.
print(outputs[0]["generated_text"][-1]["content"])
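
A side note on the commented-out string prompt above: when the pipeline receives a plain string instead of chat messages, generated_text comes back as a string, and by default it echoes the prompt. A hedged sketch of that variant (not part of the commit), using return_full_text=False to keep only the completion:

import transformers

pipe = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device_map="auto",
)
outputs = pipe(
    "Help me plan a 1-week trip to Dubai",
    max_new_tokens=128,
    do_sample=False,
    return_full_text=False,  # drop the echoed prompt; keep only the completion
)
print(outputs[0]["generated_text"])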
