
Commit

add mistral
iankur committed Sep 1, 2024
1 parent f03f65a commit e297f87
Showing 7 changed files with 269 additions and 14 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -44,6 +44,8 @@ bash recipes/run_vq_type_ablation.sh
### Notes
- EMA embedding sum and cluster size parameters are kept in full precision. The rest of the model, however, can be in lower precision, so `model.to(new_dtype)` should be handled carefully.
- The similarity computation and EMA update happen in full precision even for low-precision inputs. Accordingly, we accumulate all residual commitment losses in full precision and cast them to the original input precision before returning.
- Currently, k-means-based initialization runs on CPU, since GPU-based implementations may OOM for ~100K samples. There is only a minor performance difference between using, say, ~10K samples versus ~100K samples for the initialization.
- Although the seed is set, there seems to be some residual randomness in the current implementation, and the same setting can lead to small differences in final performance across multiple runs.
- The torchtune text-completion dataset returns identical input and label sequences; the shift happens in the recipe. We modify it to return the actual input and target sequences. We also use the packed dataset, which uses different padding values for the input and label sequences. The padding value for the label sequence is set to `CROSS_ENTROPY_IGNORE_IDX`, and we use this value to build a VQ mask so that the corresponding token embeddings are ignored when updating the codebook (see the sketch below).
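
The snippet below is a minimal sketch (not the recipe code) of how such a mask can be derived, assuming the label sequence is padded with torchtune's `CROSS_ENTROPY_IGNORE_IDX` (assumed to be -100 here):

```python
import torch

# torchtune's ignore index (assumed -100, the PyTorch cross-entropy default)
CROSS_ENTROPY_IGNORE_IDX = -100

# Hypothetical packed batch: label positions without a training target are padded
# with the ignore index.
labels = torch.tensor([[42, 7, 99, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX]])

# Boolean VQ mask: True at positions that contribute to the codebook update,
# False at ignored positions.
vq_mask = labels != CROSS_ENTROPY_IGNORE_IDX  # shape [batch, seq_len], dtype bool
```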

### Acknowledgements
106 changes: 106 additions & 0 deletions recipes/config/mistral_7b_single_device.yaml
@@ -0,0 +1,106 @@
# Copied from the config for single-device full finetuning in full_finetune_single_device.py
#
# This config assumes that you've run the following command before launching
# this run:
# tune download mistralai/Mistral-7B-v0.1 --output-dir <OUTPUT_DIR> --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run recipes/full_finetune_single_device.py --config recipes/config/mistral_7b_single_device.yaml
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training,
# you can run:
# tune run recipes/full_finetune_single_device.py --config recipes/config/mistral_7b_single_device.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.


# Tokenizer
tokenizer:
_component_: vqllm.models.mistral_tokenizer
path: /home/ubuntu/vqllm/recipes/ckpts/Mistral-7B-v0.1/tokenizer.model

# Dataset
dataset:
_component_: vqllm.utils.data.text_completion_dataset
source: DKYoon/SlimPajama-6B
split: train
column: text
# train_on_input: True
max_seq_len: 8192
packed: True
num_random_samples: 10000

seed: 1234
shuffle: True

# Model arguments
model:
_component_: vqllm.models.mistral_7b
vq_attn_key: False
vq_attn_value: False
vq_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
num_codebooks: 1
num_codebook_entries: null
codebook_entry_dim: null
num_residual_codebooks: null
num_residual_steps: 1
ema_decay: 0.99
use_fast_quantizer: False
vq_attn_key_reorder_channel: True

disable_gradient: True # toggle gradient computation
freeze_checkpoint_params: True
trainable_param_keys: []
wandb_watch_layers: []

checkpointer:
_component_: vqllm.utils.checkpointer.FullModelHFCheckpointer
checkpoint_dir: /home/ubuntu/vqllm/recipes/ckpts/Mistral-7B-v0.1
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
output_dir:
model_type: MISTRAL
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
max_steps_per_epoch: 200
gradient_accumulation_steps: 8

optimizer:
_component_: torch.optim.AdamW
lr: 1e-5

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
vq_loss_scale: 0.25

optimizer_in_bwd: False
compile: False

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: False

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.WandBLogger
project: vqllm
group:
name:
log_every_n_steps: 1
log_peak_memory_stats: False
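
# To enable VQ from the command line, override the relevant `model.*` fields above.
# The values below mirror recipes/run_vq_model_ablation.sh and are shown only as an
# example, not as defaults:
#
# tune run recipes/full_finetune_single_device.py \
#   --config recipes/config/mistral_7b_single_device.yaml \
#   model.vq_attn_key=True model.vq_attn_value=True model.use_fast_quantizer=True \
#   model.num_residual_codebooks=8 model.num_codebook_entries=2048 \
#   model.codebook_entry_dim=32 model.vq_attn_key_reorder_channel=True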
99 changes: 99 additions & 0 deletions recipes/run_vq_model_ablation.sh
@@ -0,0 +1,99 @@
# model args
MODEL="model._component_"
MODEL_TYPE="checkpointer.model_type"
CHECKPOINTER="checkpointer._component_"
TOKENIZER="tokenizer._component_"
TOKENIZER_PATH="tokenizer.path"

# vq args
VQ_KEY="model.vq_attn_key"
VQ_VALUE="model.vq_attn_value"
FAST_QUANTIZER="model.use_fast_quantizer"
RESIDUAL_CODEBOOKS="model.num_residual_codebooks"
CODES="model.num_codebook_entries"
CODE_DIM="model.codebook_entry_dim"
REORDER_CHANNEL="model.vq_attn_key_reorder_channel"

# metric logger is set to wandb
WANDB_PROJECT="metric_logger.project=vqllm"
WANDB_GROUP="metric_logger.group=vq_model_ablation"
WANDB_NAME="metric_logger.name"

dhat=32
C=2048
K=8

# eval meta/llama3-8b
LLAMA_TOKENIZER=torchtune.models.llama3.llama3_tokenizer
LLAMA_TOKENIZER_PATH=/home/ubuntu/vqllm/recipes/ckpts/llama3_8b/original/tokenizer.model
LLAMA_CHECKPOINTER=vqllm.utils.checkpointer.FullModelMetaCheckpointer

tune run recipes/eleuther_eval.py \
--config recipes/config/eleuther_evaluation.yaml \
$MODEL=vqllm.models.llama3_8b $MODEL_TYPE=LLAMA3 \
$CHECKPOINTER=$LLAMA_CHECKPOINTER $TOKENIZER=$LLAMA_TOKENIZER \
$TOKENIZER_PATH=$LLAMA_TOKENIZER_PATH \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="llama3_8b"

# train and eval vqllm/llama3-8b
CKPT_DIR="/home/ubuntu/vqllm/recipes/ckpts/vq_llama3_8b"

# train
tune run recipes/full_finetune_single_device.py \
--config recipes/config/llama3_8b_single_device.yaml \
$MODEL=vqllm.models.llama3_8b $MODEL_TYPE=LLAMA3 \
$CHECKPOINTER=$LLAMA_CHECKPOINTER $TOKENIZER=$LLAMA_TOKENIZER \
$TOKENIZER_PATH=$LLAMA_TOKENIZER_PATH \
$VQ_KEY=True $VQ_VALUE=True $FAST_QUANTIZER=True \
$RESIDUAL_CODEBOOKS=$K $CODES=$C $CODE_DIM=$dhat \
$REORDER_CHANNEL=True checkpointer.output_dir=$CKPT_DIR \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="train_vq_llama3_8b"

# eval
tune run recipes/eleuther_eval.py \
--config recipes/config/eleuther_evaluation.yaml \
$MODEL=vqllm.models.llama3_8b $MODEL_TYPE=LLAMA3 \
$CHECKPOINTER=$LLAMA_CHECKPOINTER $TOKENIZER=$LLAMA_TOKENIZER \
$TOKENIZER_PATH=$LLAMA_TOKENIZER_PATH \
$VQ_KEY=True $VQ_VALUE=True $FAST_QUANTIZER=True \
$RESIDUAL_CODEBOOKS=$K $CODES=$C $CODE_DIM=$dhat \
$REORDER_CHANNEL=True checkpointer.checkpoint_dir=$CKPT_DIR \
checkpointer.checkpoint_files=['meta_model_0.pt'] \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="eval_vq_llama3_8b"

# eval mistral-7b
MISTRAL_TOKENIZER=torchtune.models.mistral_tokenizer
MISTRAL_TOKENIZER_PATH=/home/ubuntu/vqllm/recipes/ckpts/Mistral-7B-v0.1/tokenizer.model
MISTRAL_CHECKPOINTER=vqllm.utils.checkpointer.FullModelHFCheckpointer
MISTRAL_CKPT="[pytorch_model-00001-of-00002.bin,pytorch_model-00002-of-00002.bin]"

tune run recipes/eleuther_eval.py \
--config recipes/config/eleuther_evaluation.yaml \
$MODEL=vqllm.models.mistral_7b $MODEL_TYPE=MISTRAL \
$CHECKPOINTER=$MISTRAL_CHECKPOINTER $TOKENIZER=$MISTRAL_TOKENIZER \
$TOKENIZER_PATH=$MISTRAL_TOKENIZER_PATH \
checkpointer.checkpoint_files=$MISTRAL_CKPT \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="mistral_7b"

# train and eval vqllm/mistral-7b
CKPT_DIR="/home/ubuntu/vqllm/recipes/ckpts/vq_mistral_7b"

# train
tune run recipes/full_finetune_single_device.py \
--config recipes/config/mistral_7b_single_device.yaml \
$VQ_KEY=True $VQ_VALUE=True $FAST_QUANTIZER=True \
$RESIDUAL_CODEBOOKS=$K $CODES=$C $CODE_DIM=$dhat \
$REORDER_CHANNEL=True checkpointer.output_dir=$CKPT_DIR \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="train_vq_mistral_7b"

# eval
tune run recipes/eleuther_eval.py \
--config recipes/config/eleuther_evaluation.yaml \
$MODEL=vqllm.models.mistral_7b $MODEL_TYPE=MISTRAL \
$CHECKPOINTER=$MISTRAL_CHECKPOINTER $TOKENIZER=$MISTRAL_TOKENIZER \
$TOKENIZER_PATH=$MISTRAL_TOKENIZER_PATH \
$VQ_KEY=True $VQ_VALUE=True $FAST_QUANTIZER=True \
$RESIDUAL_CODEBOOKS=$K $CODES=$C $CODE_DIM=$dhat \
$REORDER_CHANNEL=True checkpointer.checkpoint_dir=$CKPT_DIR \
checkpointer.checkpoint_files=['hf_model_0.pt'] \
$WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="eval_vq_mistral_7b"
16 changes: 5 additions & 11 deletions recipes/run_vq_type_ablation.sh
@@ -1,14 +1,3 @@
# metric logger is set to wandb
WANDB_PROJECT="metric_logger.project=vqllm"
WANDB_GROUP="metric_logger.group=vq_type_ablation"
WANDB_NAME="metric_logger.name"

# eval meta/llama3-8b
tune run recipes/eleuther_eval.py \
--config recipes/config/eleuther_evaluation.yaml \
$WANDB_PROJECT $WANDB_NAME="llama3_8b"


# train and eval vqllm/llama3-8b
VQ_KEY="model.vq_attn_key"
VQ_VALUE="model.vq_attn_value"
@@ -18,6 +7,11 @@ CODES="model.num_codebook_entries"
CODE_DIM="model.codebook_entry_dim"
REORDER_CHANNEL="model.vq_attn_key_reorder_channel"

# metric logger is set to wandb
WANDB_PROJECT="metric_logger.project=vqllm"
WANDB_GROUP="metric_logger.group=vq_type_ablation"
WANDB_NAME="metric_logger.name"

dhat=32
C=2048
K=8
41 changes: 39 additions & 2 deletions vqllm/models/_model_builders.py
@@ -67,7 +67,19 @@ def llama3_8b(
vq_attn_key_reorder_channel=vq_attn_key_reorder_channel,
)

def mistral_7b() -> TransformerDecoder:
def mistral_7b(
vq_attn_key=False,
vq_attn_value=False,
vq_layers=[],
num_codebooks=None,
num_codebook_entries=None,
codebook_entry_dim=None,
num_residual_codebooks=1,
num_residual_steps=1,
ema_decay=0.0,
use_fast_quantizer=False,
vq_attn_key_reorder_channel=False,
) -> TransformerDecoder:
"""
Builder for creating a Mistral 7B model initialized w/ the default 7b parameter values
from https://mistral.ai/news/announcing-mistral-7b/
@@ -76,16 +88,41 @@ def mistral_7b() -> TransformerDecoder:
Returns:
TransformerDecoder: Instantiation of Mistral 7B model
"""
embed_dim = 4096

# When only num_codebooks is specified, both the number of codebook entries and the
# entry dimension default to embed_dim // num_codebooks.
if num_codebook_entries is None and num_codebooks is not None:
num_codebook_entries = embed_dim // num_codebooks
if codebook_entry_dim is None and num_codebooks is not None:
codebook_entry_dim = embed_dim // num_codebooks

if (vq_attn_key or vq_attn_value) and (
(num_codebooks != 1 and num_codebooks * codebook_entry_dim != embed_dim)
or (embed_dim % codebook_entry_dim != 0)
):
raise ValueError(
"codebook_entry_dim must divide embed_dim and, when num_codebooks > 1, "
"num_codebooks * codebook_entry_dim must equal embed_dim"
)

return mistral(
vocab_size=32_000,
num_layers=32,
num_heads=32,
num_kv_heads=8,
embed_dim=4096,
embed_dim=embed_dim,
intermediate_dim=14336,
max_seq_len=32768,
attn_dropout=0.0,
norm_eps=1e-5,
# vq parameters
vq_attn_key=vq_attn_key,
vq_attn_value=vq_attn_value,
vq_layers=vq_layers,
num_codebooks=num_codebooks,
num_codebook_entries=num_codebook_entries,
codebook_entry_dim=codebook_entry_dim,
num_residual_codebooks=num_residual_codebooks,
num_residual_steps=num_residual_steps,
ema_decay=ema_decay,
use_fast_quantizer=use_fast_quantizer,
vq_attn_key_reorder_channel=vq_attn_key_reorder_channel,
)
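
# A hypothetical direct call using the VQ settings from recipes/run_vq_model_ablation.sh
# (recipes normally instantiate this builder through the config's `model._component_` field):
#
#     model = mistral_7b(
#         vq_attn_key=True,
#         vq_attn_value=True,
#         vq_layers=list(range(32)),
#         num_codebooks=1,
#         num_codebook_entries=2048,
#         codebook_entry_dim=32,
#         num_residual_codebooks=8,
#         use_fast_quantizer=True,
#         vq_attn_key_reorder_channel=True,
#     )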

def mistral_tokenizer(path: str) -> SentencePieceTokenizer:
1 change: 0 additions & 1 deletion vqllm/quantizer/utils.py
@@ -8,7 +8,6 @@ def quantize(
mask,
quantizers,
training=False,
scale=None,
cosine_distance=False,
):
"""
18 changes: 18 additions & 0 deletions vqllm/utils/checkpointer.py
@@ -8,6 +8,22 @@
logger = get_logger("DEBUG")


_FROM_HF = convert_weights._FROM_HF
_FROM_HF.update(
{
# vq_attn_key
"model.layers.{}.attention.quantizer.key.{}.codebook.weight": "layers.{}.attn.quantizer.key.{}.codebook.weight",
"model.layers.{}.attention.quantizer.key.{}.codebook.cluster_size": "layers.{}.attn.quantizer.key.{}.codebook.cluster_size",
"model.layers.{}.attention.quantizer.key.{}.codebook.embed_avg": "layers.{}.attn.quantizer.key.{}.codebook.embed_avg",
"model.layers.{}.attention.quantizer.key.{}.data_initialized": "layers.{}.attn.quantizer.key.{}.data_initialized",
# vq_attn_value
"model.layers.{}.attention.quantizer.value.{}.codebook.weight": "layers.{}.attn.quantizer.value.{}.codebook.weight",
"model.layers.{}.attention.quantizer.value.{}.codebook.cluster_size": "layers.{}.attn.quantizer.value.{}.codebook.cluster_size", # noqa: B950
"model.layers.{}.attention.quantizer.value.{}.codebook.embed_avg": "layers.{}.attn.quantizer.value.{}.codebook.embed_avg",
"model.layers.{}.attention.quantizer.value.{}.data_initialized": "layers.{}.attn.quantizer.value.{}.data_initialized",
}
)

_FROM_META = convert_weights._FROM_META
_FROM_META.update(
{
@@ -47,6 +63,8 @@ def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
return new_key


# Monkey-patch torchtune's weight-conversion mappings and key-mapping helper so that the
# stock HF/Meta checkpointers can load and save the quantizer buffers registered above.
torchtune.models.convert_weights._FROM_HF = _FROM_HF
torchtune.models.convert_weights._FROM_META = _FROM_META
torchtune.models.convert_weights.get_mapped_key = get_mapped_key
FullModelHFCheckpointer = torchtune.utils._checkpointing.FullModelHFCheckpointer
FullModelMetaCheckpointer = torchtune.utils._checkpointing.FullModelMetaCheckpointer
