Consolidate the stateless llama logic #729

Draft

gpetters-amd wants to merge 1 commit into base: main
Conversation

gpetters-amd (Contributor)

It's producing vmfbs now, just needs some more cleanup and vmfb runner logic if we want to do that here.
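For context, a minimal sketch of what the vmfb runner side could look like, using iree.runtime directly; the driver, file path, module name, and the run_initialize entry-point name are assumptions for illustration, not taken from this PR:

```python
# Minimal runner sketch, assuming a module compiled for the local-task (CPU)
# driver; the file path, module name, and entry-point name are placeholders.
import numpy as np
import iree.runtime as ireert

config = ireert.Config("local-task")  # assumed driver
vm_module = ireert.VmModule.mmap(config.vm_instance, "stateless_llama.vmfb")  # assumed path
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)

module = ctx.modules.module                    # "module" is IREE's default module name
input_ids = np.zeros((1, 16), dtype=np.int64)  # placeholder token ids
device_input = ireert.asdevicearray(config.device, input_ids)
result = module.run_initialize(device_input)   # assumed entry point
print(result.to_host())
```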

@dan-garvey (Member)

Why are we pulling this into @monorimet's branch? This should be fine standalone; @monorimet can rebase after it merges. That way we get good test coverage.

models/turbine_models/custom_models/stateless_llama.py (outdated review comment, resolved)
models/turbine_models/custom_models/stateless_llama.py (outdated review comment, resolved)
device_inputs = [
    ireert.asdevicearray(self.device, input_tensor)
]
if self.first_input:  # or not self.streaming_llm:
Member

Commented-out streaming LLM code? Is this from the original?

Contributor Author (gpetters-amd)

I redid some of the logic to remove non-streaming since I thought our plan was to only support streaming, but I think that's not actually the case. I'll add the support back in.
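For what it's worth, a rough sketch of how the branching could keep both paths rather than dropping one; self.streaming_llm and self.first_input come from the snippet quoted above, while self.model, the class, and the entry-point names are placeholders:

```python
# Sketch: keep both the streaming and non-streaming paths.
# "self.model" stands in for whatever object exposes the compiled module's
# entry points; run_initialize / run_cached_initialize are placeholder names.
import iree.runtime as ireert


class LlamaRunnerSketch:
    """Placeholder class; only the branching logic below is the point."""

    def generate_token(self, input_tensor):
        device_inputs = [ireert.asdevicearray(self.device, input_tensor)]
        if not self.streaming_llm:
            # Non-streaming: run the full prefill on the whole prompt every call.
            token = self.model.run_initialize(*device_inputs)
        elif self.first_input:
            # Streaming: prefill once, then reuse the KV cache afterwards.
            token = self.model.run_initialize(*device_inputs)
            self.first_input = False
        else:
            # Streaming decode step against the cached KV state.
            token = self.model.run_cached_initialize(*device_inputs)
        return token
```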

@@ -3,6 +3,7 @@
import re
import json
from turbine_models.turbine_tank import turbine_tank
from pathlib import Path
Contributor

unused?

@@ -489,26 +491,362 @@ def evict_kvcache_space(self):
return blob_name, tokenizer


llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {
Contributor

This isn't supporting the larger models we care about, like 13B and 70B.

@@ -489,26 +491,362 @@ def evict_kvcache_space(self):
return blob_name, tokenizer


llm_model_map = {
Contributor

This might belong in a separate config file to reduce clutter. It's also more limiting to have this without some default setup.
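One possible shape for that, as a sketch; the JSON file name, the 13B/70B entries, and the per-model fields are illustrative rather than taken from this PR:

```python
# Sketch: move llm_model_map into a JSON config next to the module, with a
# default entry so callers that pass nothing still get a working setup.
#
# llm_configs/llm_model_map.json (illustrative contents):
# {
#   "meta-llama/Llama-2-7b-chat-hf":  {"model_name": "llama_2_7b_chat"},
#   "meta-llama/Llama-2-13b-chat-hf": {"model_name": "llama_2_13b_chat"},
#   "meta-llama/Llama-2-70b-chat-hf": {"model_name": "llama_2_70b_chat"}
# }
import json
from pathlib import Path

_DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
_CONFIG_PATH = Path(__file__).parent / "llm_configs" / "llm_model_map.json"


def load_llm_model_map(config_path: Path = _CONFIG_PATH) -> dict:
    """Load the model map from disk so the Python module stays uncluttered."""
    with open(config_path) as f:
        return json.load(f)


def get_model_config(hf_model_name: str | None = None) -> dict:
    """Return the entry for hf_model_name, falling back to the default setup."""
    model_map = load_llm_model_map()
    return model_map.get(hf_model_name or _DEFAULT_MODEL, model_map[_DEFAULT_MODEL])
```

This would also address the 13B/70B comment above, since adding a model becomes a config edit rather than a code change.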

    pipeline_dir: str | Path = "./shark_vmfbs",
    external_weights_dir: str | Path = "./shark_weights",
    external_weights: str = "safetensors",
    custom_vae: str = None,
Contributor

Remove the VAE and other unnecessary flags that look to come from SD (scheduler, etc.).
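As a sketch, the constructor could shrink to roughly the LLM-relevant knobs; the exact parameter set here is a guess at what Llama needs, not a definitive list:

```python
# Sketch: trimmed constructor with the SD-only flags (custom_vae, scheduler,
# image sizes, ...) dropped; the parameter choices are illustrative.
from pathlib import Path


class StatelessLlamaPipeline:
    def __init__(
        self,
        hf_model_name: str,
        device: str,
        precision: str = "fp16",
        pipeline_dir: str | Path = "./shark_vmfbs",
        external_weights: str = "safetensors",
        external_weight_path: str | Path | None = None,
        streaming_llm: bool = True,
    ):
        self.hf_model_name = hf_model_name
        self.device = device
        self.precision = precision
        self.pipeline_dir = Path(pipeline_dir)
        self.external_weights = external_weights
        self.external_weight_path = external_weight_path
        self.streaming_llm = streaming_llm
```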

}


class StatelessLlamaPipeline:
Contributor

Llama doesn't really have a pipeline; we might want to remove the pipeline references.


    # FILE MANAGEMENT AND PIPELINE SETUP

    def check_prepared(
Contributor

The file management looks to be copied from the SD code; can we just combine them to reduce repeated code?
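One possible shape for the shared piece, as a sketch; the helper name and location are made up, and the SD pipeline's actual helpers may differ:

```python
# Sketch: a shared "what still needs to be built?" helper that both the SD
# and Llama pipelines could import instead of each carrying its own copy.
import os


def check_prepared(
    pipeline_dir: str,
    required_vmfbs: list[str],
    required_weights: list[str],
) -> tuple[list[str], list[str]]:
    """Return the vmfb names and weight paths that still need to be produced."""
    missing_vmfbs = [
        name
        for name in required_vmfbs
        if not os.path.exists(os.path.join(pipeline_dir, name + ".vmfb"))
    ]
    missing_weights = [
        path for path in required_weights if not os.path.exists(path)
    ]
    return missing_vmfbs, missing_weights
```

Llama would then call it with one vmfb and one weight file, and SD with its longer lists.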


    # RUN

    def chat(self, prompt):
Contributor

Not used? This looks like it should be part of llm_runner.py; stateless_llama.py should be just for tracing, generating IR, and/or compiling vmfbs.
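A sketch of how the interactive loop might live in llm_runner.py instead, leaving stateless_llama.py to export and compilation only; the callable-based shape and the Llama-2 [INST] prompt formatting are assumptions, not the file's actual API:

```python
# Sketch for llm_runner.py: a simple chat REPL that drives an already-compiled
# vmfb through a generate callable supplied by the runner object.
from typing import Callable


def chat(generate_fn: Callable[[str], str]) -> None:
    """generate_fn maps the running prompt history to the model's next reply."""
    history = ""
    while True:
        prompt = input("User: ")
        if prompt.strip().lower() in ("exit", "quit"):
            break
        history += f"[INST] {prompt} [/INST]"  # Llama-2 chat formatting
        reply = generate_fn(history)
        history += reply
        print(f"Assistant: {reply}\n")
```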

gpetters-amd changed the base branch from ean-unify-sd to main on July 2, 2024 at 23:05.
pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
Contributor

Any SDXL-related changes should be moved to a different PR.

@@ -0,0 +1,169 @@
// Copyright 2024 The IREE Authors
Contributor

This is only for argmax, right? We should also pull in all of the transform spec changes from the SDXL spec file that apply here.

##############################################################################

p.add_argument(
    "--seed", type=float, default=0, help="Seed for random number/latents generation."
Contributor

Llama doesn't need a seed, does it?

help="Path to location of vmfb files.",
)

p.add_argument(
Contributor

These are unnecessary weight flags for Llama. We are only using one external weight file, so we could remove external_weights_dir, and I don't think we need external_weight_file below either.
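For example, the weight flags could collapse to a single path argument (the flag name is only a suggestion):

```python
# Sketch: one flag for the single external weight file instead of separate
# external_weights_dir / external_weight_file flags.
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--external_weight_path",
    type=str,
    default="",
    help="Path to the single external weight file (e.g. a .safetensors file).",
)
```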

@@ -8,7 +8,7 @@
from iree.compiler.ir import Context
Contributor

Let's keep the separate model updates in separate patches. That makes it easier to track and revert patches if ever needed.

@@ -90,20 +90,42 @@ def test_vmfb_comparison(self):

upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")

blob_name = llama.export_transformer_model(
# blob_name = llama.export_transformer_model(
Contributor

commented code?

(And accidentally undo some cleanup, oops)