Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[LLM Runtime] Add GGUF API UT (#1160)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenzhong1 authored Jan 19, 2024
1 parent ea58cd5 commit 1383c76
Showing 1 changed file with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_llm_runtime(self):
print(config_type, cmpData(pt_logits.detach().numpy().flatten(), itrex_logits.flatten()))


def test_gguf_api(self):
    """Smoke-test GGUF model loading and generation via the AutoModel API.

    Loads the 4-bit (Q4_0) GGUF build of Mistral-7B from the HF hub,
    tokenizes a fixed prompt with the local full-precision tokenizer,
    greedily generates 10 new tokens, and compares the produced token
    ids against a pre-recorded reference sequence.
    """
    model_name = "TheBloke/Mistral-7B-v0.1-GGUF"  # HF hub repo holding the GGUF files
    model_file = "mistral-7b-v0.1.Q4_0.gguf"  # 4-bit quantized weight file within the repo
    tokenizer_name = "/tf_dataset2/models/pytorch/Mistral-7B-v0.1"  # local tokenizer path

    prompt = "Once upon a time"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    streamer = TextStreamer(tokenizer)

    model = AutoModelForCausalLM.from_pretrained(model_name, model_file=model_file)
    output = model.generate(inputs, streamer=streamer, max_new_tokens=10)
    print("output = ", output)
    # Use the unittest assertion instead of a bare `assert`: it is not
    # stripped under `python -O` and reports a readable diff on mismatch.
    self.assertEqual(
        output,
        [[1, 5713, 3714, 264, 727, 28725, 736, 403, 264, 1628, 2746, 693,
          6045, 298, 1220, 28723, 985]],
    )


def test_beam_search(self):
model_name = "/tf_dataset2/models/pytorch/gpt-j-6B" # or local path to model
prompts = [
Expand Down

0 comments on commit 1383c76

Please sign in to comment.