This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[Cpp Graph] Beam Search Pybind (model archs: gptj and gptneox) #449

Merged
12 commits merged from polyglot into main on Oct 17, 2023

Conversation

zhentaoyu
Contributor

@zhentaoyu zhentaoyu commented Oct 12, 2023

Type of Change

Support polyglot-5.8b C++ inference and its beam search.
Support gpt-neox beam search and a beam_search pybind API.
No API change.

Description

JIRA ticket: 920

TODO
- [ ] cpp tokenizer (only for polyglot)
- graph convert-quant-load-inference
- cpp beam search (polyglot & gpt-neox)
- beam search pybind
- [ ] python ut

Polyglot-related tasks will be reopened in the next PR due to the JIRA priority.

Expected Behavior & Potential Risk

Add a Python UT. Verify its results against the Python API with the transformers tokenizer class first (send token IDs in, receive token IDs out).
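As a minimal sketch, the "send ints and receive ints" verification amounts to a token-level comparison like the following (the helper name `ids_match` is hypothetical; the real check lives in the Python UT):

```python
def ids_match(reference_ids, candidate_ids):
    """Compare two generated token-ID sequences element-wise.

    Verifying at the ID level avoids tokenizer-decoding differences
    between the C++ graph and the transformers reference.
    """
    if len(reference_ids) != len(candidate_ids):
        return False
    return all(r == c for r, c in zip(reference_ids, candidate_ids))

# reference_ids would come from transformers model.generate(...),
# candidate_ids from the pybind beam_search; dummy values shown here.
assert ids_match([1, 2, 3], [1, 2, 3])
assert not ids_match([1, 2, 3], [1, 2, 4])
```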

How has this PR been tested?

Python UT.
Golden result: transformers inference.

from transformers import pipeline, set_seed, AutoModelForCausalLM, AutoTokenizer

model_dir = "polyglot-ko-5.8b" # "gptneox-20b"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.eval()

prompt = "she opened the door and see"   #"What is the meaning of life?"
inputs = tokenizer(prompt, return_tensors="pt")
print("inputs", inputs)

out = model(input_ids = inputs.input_ids)
print("first generated-token logits (last prompt position):")
print(out['logits'][0][-1][:32])

# beam search
generate_ids = model.generate(inputs.input_ids, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True)
ans = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(ans)

Pybind example:
Naive version (without the transformers model wrapper; only the tokenizer comes from transformers):

from transformers import AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph import Model
model_name = "gpt-neox-20b"
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
model = Model()
# fp32 or int4
model.init_from_bin("gptneox", "fp32.bin", num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True)
outputs = model.generate(inputs, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True)
ans = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(ans)

high-level version (with transformers, updated in python_api_example.py)

from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "gpt-neox-20b"
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
streamer = TextStreamer(tokenizer)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
# top_k_top_p sample or greedy_search
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300)
# beam search
outputs = model.generate(inputs, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True)
ans = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(ans)
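For readers unfamiliar with the generation strategy being bound here, the core of beam search can be sketched in a few lines of plain Python (a toy next-token model; the actual implementation lives in the C++ graph and is not shown here):

```python
import math

def beam_search(next_logprobs, num_beams=2, max_new_tokens=3):
    """Tiny beam search: next_logprobs maps a token sequence to a
    {token: log-probability} dict for the next step; the num_beams
    highest-scoring hypotheses are kept at every step."""
    beams = [([], 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            for tok, lp in next_logprobs(seq).items():
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams[0][0]  # best hypothesis

# toy model that always prefers token 1 over token 0
toy = lambda seq: {0: math.log(0.4), 1: math.log(0.6)}
assert beam_search(toy, num_beams=2, max_new_tokens=3) == [1, 1, 1]
```

The real post-processing additionally handles length penalties, EOS, and `early_stopping`, but the keep-top-k-hypotheses loop above is the essence.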

Dependency Change?

None

@zhentaoyu zhentaoyu requested a review from airMeng as a code owner October 12, 2023 06:10
@zhentaoyu zhentaoyu changed the title Polyglot [Cpp Graph] Polyglot and Beam Search Pybind Oct 12, 2023
@zhentaoyu zhentaoyu marked this pull request as draft October 12, 2023 06:11
@zhentaoyu
Contributor Author

GPT-NeoX-20B FP32 beam search comparison:
prompt = "she opened the door and see"

transformers and the cpp graph give the same outputs:

she opened the door and see who it was.

"Oh, it's you," she said.

"Yes, it's me."

"What do you want?"

"I want to talk to you."

"What about?"

"You know what about."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"

@zhentaoyu zhentaoyu force-pushed the polyglot branch 2 times, most recently from a692b94 to 25fcac3 Compare October 16, 2023 02:57
@zhentaoyu zhentaoyu changed the title [Cpp Graph] Polyglot and Beam Search Pybind [Cpp Graph] Beam Search Pybind (model archs: gptj and gptneox) Oct 16, 2023
Contributor

@a32543254 a32543254 left a comment


LGTM

@a32543254 a32543254 marked this pull request as ready for review October 16, 2023 05:24
Signed-off-by: Yu, Zhentao <zhentao.yu@intel.com>
@zhentaoyu
Contributor Author

zhentaoyu commented Oct 16, 2023

gpt-neox-20b pybind beam_search outputs:

fp32:

she opened the door and see who it was.

"Oh, it's you," she said.

"Yes, it's me."

"What do you want?"

"I want to talk to you."

"What about?"

"You know what about."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"No, I don't."

"Yes, you do."

"

int4 (q4_0):

she opened the door and see what was going on.

"What's going on?" she asked.

"I don't know," I said.

"What do you mean, you don't know?" she asked.

"I don't know what's going on," I said.

"What do you mean, you don't know what's going on?" she asked.

"I don't know what's going on," I said.

"What do you mean, you don't know what's going on?" she asked.

"I don't know what's going on," I said.

Signed-off-by: Yu, Zhentao <zhentao.yu@intel.com>
@zhentaoyu zhentaoyu requested a review from DDEle October 16, 2023 06:05
Signed-off-by: Yu, Zhentao <zhentao.yu@intel.com>
@zhentaoyu zhentaoyu removed the draft label Oct 16, 2023
@zhentaoyu
Contributor Author

gpt-j-6b pybind beam_search outputs:

fp32 (same as transformers):

she opened the door and see me standing there.

"What are you doing here?" she asked.

"I came to see you," I said.

"I don't want to see you," she said.

"Why not?" I asked.

"Because I don't want to see you," she said.

"Why not?" I asked.

"Because I don't want to see you," she said.

"Why not?" I asked.

"Because I don't want to see you," she said.

"Why not?" I asked.

"Because I

int4 (q4_j_b128):

she opened the door and see me standing there.

"What are you doing here?" she asked.

"I came to see you," I said.

"I don't want to see you," she said.

"Why not?" I asked.

"I don't want to see you," she said.

"Why not?" I asked.

"I don't want to see you," she said.

"Why not?" I asked.

"I don't want to see you," she said.

"Why not?" I asked.

"I don't want to

Signed-off-by: Yu, Zhentao <zhentao.yu@intel.com>
@airMeng
Contributor

airMeng commented Oct 16, 2023

Looking forward to more optimization of the post-processing.

Contributor

@DDEle DDEle left a comment


Looking forward to having beam-search in main_run.cpp

Comment on lines +306 to +310
logits_out.resize(n_vocab * batch_size);
for (int i = 0; i < batch_size; ++i) {
memcpy(logits_out.data() + (i * n_vocab), (float*)ne_get_data(inpL) + (i * bs_stride) + (n_vocab * (N - 1)),
sizeof(float) * n_vocab);
}
Contributor


BTW, if only the logits for the last token are required, why don't we slice them out earlier (up to the norm in L259)?

Contributor Author


It only happens for the first (prompt) tokens. Maybe we could add a slice kernel before or after the LayerNorm. However, it may not bring much acceleration for a "small" prompt (it only saves the lm_head GEMM). But we can try it. cc @a32543254, since you asked the same question. We can consider it.

Contributor


I think ne_view_1d/2d/3d/4d should be able to work as your "slice kernel".

Contributor Author


You are right! Keeping it here pending an optimization PR.
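The idea discussed in this thread — taking a view of only the last token's hidden state before the lm_head projection, instead of projecting all N positions and then copying out the last row — can be sketched with NumPy (toy sizes; the names here are illustrative, not the actual graph API):

```python
import numpy as np

batch_size, N, n_embd, n_vocab = 2, 5, 8, 16
hidden = np.random.rand(batch_size, N, n_embd).astype(np.float32)
lm_head = np.random.rand(n_embd, n_vocab).astype(np.float32)

# Slicing the last position *before* the lm_head GEMM (analogous to
# using ne_view_* in the C++ graph) shrinks the projection from N rows
# to a single row per batch entry.
last_hidden = hidden[:, -1, :]        # a view, no copy: (batch, n_embd)
logits = last_hidden @ lm_head        # (batch, n_vocab)

# Identical result to projecting every position and taking the last row.
full = (hidden.reshape(-1, n_embd) @ lm_head).reshape(batch_size, N, n_vocab)
assert np.allclose(logits, full[:, -1, :], atol=1e-5)
```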

@VincyZhang VincyZhang merged commit 958d048 into main Oct 17, 2023
@VincyZhang VincyZhang deleted the polyglot branch October 17, 2023 01:29
6 participants