multi-modality model construction support (pytorch#1068)
This PR adds support for defining and constructing multi-modality models in torchchat. To demonstrate this capability, we integrate the Flamingo components into the system.
Note that this provides only bare-minimum support for model definition. See the openai_api_multimodal branch for the end-to-end flow, and pytorch#1123 (comment) for an improved structure and Llama 3.1 support.
Gasoonjia authored Sep 11, 2024
1 parent c272df4 commit 964d437
Showing 9 changed files with 167 additions and 35 deletions.
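
Orientation for the per-file diffs below: the single ModelArgs.text_transformer_args field is replaced by a transformer_args dictionary keyed by module name ("text", "encoder", "decoder", ...), alongside a new model_type tag. The following editor's sketch (not part of the diff) illustrates the access pattern that call sites are migrated to; the import path torchchat.model and the model name "stories15M" are assumptions.

# Editor's sketch only, assuming torchchat.model exposes ModelArgs/ModelType as introduced in this PR.
from torchchat.model import ModelArgs, ModelType

model_args = ModelArgs.from_name("stories15M")  # hypothetical model name

if model_args.model_type == ModelType.TextOnly:
    # Text-only configs now live under the "text" key, replacing the old
    # model_args.text_transformer_args attribute.
    text_config = model_args.transformer_args["text"]
    print(text_config.max_seq_length)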
2 changes: 1 addition & 1 deletion dist_run.py
@@ -122,7 +122,7 @@ def main():
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
logger.info(f"Chat Model Config: {config}")

tokenizer = _build_chat_tokenizer()
2 changes: 1 addition & 1 deletion distributed/parallelize_llama.py
@@ -61,7 +61,7 @@ def apply_tp(
# after we apply TP to the model. Because we don't want to change model code
# when applying TP. We need to have change to ensure KVCache has the correct
# size as k and v.
model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size()
model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
4 changes: 2 additions & 2 deletions torchchat/cli/builder.py
@@ -230,7 +230,7 @@ def validate_model(

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
use_tiktoken = model.config.text_transformer_args.use_tiktoken
use_tiktoken = model.config.transformer_args["text"].use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
@@ -534,7 +534,7 @@ def _initialize_model(
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length
max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length
)

model.to(dtype=builder_args.precision)
2 changes: 1 addition & 1 deletion torchchat/cli/convert_hf_checkpoint.py
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
if model_name is None:
model_name = model_dir.name

config = ModelArgs.from_name(model_name).text_transformer_args
config = ModelArgs.from_name(model_name).transformer_args['text']
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
4 changes: 2 additions & 2 deletions torchchat/export.py
@@ -54,9 +54,9 @@ def export_for_server(
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)

seq = Dim("seq", min=1, max=model.config.text_transformer_args.max_seq_length)
seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length)
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
else:
input = (
torch.tensor([[1]], dtype=torch.int, device=device),
4 changes: 2 additions & 2 deletions torchchat/generate.py
@@ -687,7 +687,7 @@ def chat(
self.system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:
max_seq_length = self.model.config.text_transformer_args.max_seq_length
max_seq_length = self.model.config.transformer_args["text"].max_seq_length
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
@@ -700,7 +700,7 @@
else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
self.model.config.text_transformer_args.block_size,
self.model.config.transformer_args["text"].block_size,
)

max_seq_length = (
178 changes: 155 additions & 23 deletions torchchat/model.py
@@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree.
import json
import os
import warnings

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
@@ -26,8 +28,72 @@

from torchchat.utils.build_utils import find_multiple, get_precision

# bypass the import issue, if any
# TODO: remove this once the torchao is ready on macos
try:
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.modules.model_fusion import DeepFusionModel
except:
pass

config_path = Path(f"{str(Path(__file__).parent)}/model_params")

class ModelType(Enum):
TextOnly = "text_only"
Flamingo = "flamingo"

# Type for objects that can generate nn.Module instance
ModuleLike = Union[nn.Module, Callable[..., nn.Module]]

@dataclass
class ModelRecipe:
"""
The class describes and contains all supported model structures in torchchat.
ModelRecipe represents a model as a collection of Transformer modules and a fusion module,
providing a standardized and centralized way to define and build models in torchchat.
Attributes:
model_type (ModelType):
The type of the model.
modules (Dict[str, ModuleLike]):
A dictionary of ModuleLike modules, where each key is the module name and each
value is a ModuleLike object that generates the transformer.
The names of the Transformer modules should match the corresponding names in the
fusion class and the JSON file holding model hyperparameters.
fusion_class (ModuleLike):
A ModuleLike object that generates a fusion module by taking the constructed modules above.
"""

model_type: ModelType
modules: Dict[str, ModuleLike]
fusion_class: ModuleLike

@classmethod
def _text_only(cls):
return cls(
model_type=ModelType.TextOnly,
modules={'text_transformer': Transformer},
fusion_class=nn.Identity,
)
@classmethod
def _flamingo(cls):
return cls(
model_type=ModelType.Flamingo,
modules={
'encoder': flamingo_vision_encoder,
'decoder': flamingo_decoder
},
fusion_class=DeepFusionModel,
)

@classmethod
def get_recipe(cls, model_type):
if model_type == ModelType.TextOnly:
return cls._text_only()
elif model_type == ModelType.Flamingo:
return cls._flamingo()
else:
raise ValueError(f"Can not find the model recipe for {model_type}")

@dataclass
class TransformerArgs:
@@ -77,13 +143,33 @@ def from_params(cls, params):
params[_to] = params.pop(_from)
return cls(**params)


@dataclass
class ModelArgs:
text_transformer_args: TransformerArgs
model_type: ModelType
transformer_args: Dict[str, Union[Dict, TransformerArgs]]

def __post_init__(self):
assert self.text_transformer_args is not None
assert type(self.text_transformer_args) == TransformerArgs
def __init__(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
model_type: ModelType = ModelType.TextOnly,
) -> None:
self._sanity_check(transformer_args, model_type)

self.model_type = model_type
if isinstance(transformer_args, TransformerArgs):
assert model_type == ModelType.TextOnly
self.transformer_args = {"text": transformer_args}
else:
self.transformer_args = transformer_args

def _sanity_check(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
model_type: ModelType,
) -> None:
assert isinstance(model_type, ModelType)
assert isinstance(transformer_args, (TransformerArgs, dict))

@classmethod
def from_params(cls, params_path):
@@ -92,18 +178,18 @@ def from_params(cls, params_path):

try:
# try to interpret as a single transformer config
text_transformer_args = TransformerArgs.from_params(
loaded_params
)
transformer_args: Dict[str, TransformerArgs] = {}
transformer_args["text"] = TransformerArgs.from_params(loaded_params)
model_type = ModelType.TextOnly
except TypeError:
# try to interpret as a dict of transformer configs
for name, params in loaded_params.items():
if name == "text":
text_transformer_args = TransformerArgs.from_params(params)
else:
raise ValueError(f"Unknown transformer name {name}")
model_type = ModelType(loaded_params["model_type"])

return cls(text_transformer_args)
# Currently only supporting flamingo model
assert model_type == ModelType.Flamingo
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}

return cls(transformer_args, model_type)

@classmethod
def from_table(cls, name: str):
@@ -181,16 +267,61 @@ def update(self, input_pos, k_val, v_val):


class Model(nn.Module):
"""
The entry point for model construction in torchchat.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.text_transformer = Transformer(config.text_transformer_args)

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.text_transformer(idx, input_pos)
# TODO: unify the model init logic
if config.model_type == ModelType.TextOnly:
self.text_transformer = Transformer(config.transformer_args["text"])
else:
self.model = self.build_model()

def setup_caches(self, max_batch_size, max_seq_length):
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
def build_model(self) -> nn.Module:
"""
Builds a model based on the provided configuration.
This method retrieves a ModelRecipe instance corresponding to the specified model type,
constructs the required Transformer modules, and combines them using the fusion class.
Returns:
The constructed model instance.
"""
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
modules[name] = module_class(**self.config.transformer_args[name])

return recipe.fusion_class(**modules)

def forward(self,
tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None) -> Tensor:

if self.config.model_type == ModelType.TextOnly:
return self.text_transformer(tokens, input_pos)
else:
assert self.config.model_type == ModelType.Flamingo
if input_pos:
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.")
if encoder_input is None:
return self.model(tokens, encoder_mask = encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask)

def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None):
if self.config.model_type == ModelType.TextOnly:
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
else:
assert self.config.model_type == ModelType.Flamingo
if max_seq_length is not None:
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.")
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
assert self.config.model_type == ModelType.Flamingo
self.model.reset_caches()

@classmethod
def from_name(cls, name: str):
@@ -564,11 +695,11 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# ExecuTorch model components
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

try:
try:
from executorch.extension.pybindings import portable_lib as exec_lib

# ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa

class PTEModel(nn.Module):
def __init__(self, config, path) -> None:
Expand All @@ -589,5 +720,6 @@ def forward(self, x, input_pos):

def setup_caches(self, max_batch_size, max_seq_length):
pass

except:
pass
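
To make the new construction path in torchchat/model.py concrete, here is an editor's sketch (not part of the diff) of building a Flamingo model through ModelRecipe and Model. The params file name and its keys are assumptions about the torchchat/model_params layout.

# Editor's sketch only, assuming a flamingo params JSON with a "model_type" key and
# one sub-dict of kwargs per module ("encoder", "decoder").
from pathlib import Path
from torchchat.model import Model, ModelArgs, ModelType

params_path = Path("torchchat/model_params/flamingo.json")  # hypothetical file
model_args = ModelArgs.from_params(params_path)
assert model_args.model_type == ModelType.Flamingo

# Model.build_model looks up the Flamingo recipe, instantiates flamingo_vision_encoder
# and flamingo_decoder from transformer_args["encoder"] / ["decoder"], and combines
# them with torchtune's DeepFusionModel.
model = Model(model_args)

# Flamingo forward takes tokens plus optional encoder_input / encoder_mask;
# input_pos is ignored with a warning.
# logits = model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)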
2 changes: 1 addition & 1 deletion torchchat/usages/eval.py
@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
T = prompt.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.text_transformer_args.block_size)
max_seq_length = min(T_new, model.config.transformer_args["text"].block_size)

device, dtype = prompt.device, prompt.dtype
# create an empty tensor of the expected final shape and
4 changes: 2 additions & 2 deletions torchchat/usages/openai_api.py
@@ -233,11 +233,11 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)
self.max_seq_length = (
self.model.config.text_transformer_args.max_seq_length
self.model.config.transformer_args["text"].max_seq_length
+ self.speculative_builder_args.speculate_k
+ 1
if self.draft_model is not None
else self.model.config.text_transformer_args.max_seq_length
else self.model.config.transformer_args["text"].max_seq_length
)
# The System fingerprint is a unique identifier for the model and its configuration.
self.system_fingerprint = (
