Rebase, rename llama_flash_attn -> flash_attn.
arnocandel committed May 11, 2023
1 parent bd1b009 commit af4a98b
78 changes: 38 additions & 40 deletions finetune.py
@@ -154,7 +154,7 @@ def train(
lora_dropout: float = 0.05,
lora_target_modules: List[str] = None,
llama_type: bool = None,
- llama_flash_attn: bool = False,
+ flash_attn: bool = False,

# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
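
For orientation, a hypothetical direct call showing how the renamed flag is passed (argument names come from the signature above; the model id is the one suggested by the script's own assert message further down):

    from finetune import train

    # Hypothetical invocation; all values are illustrative only.
    train(
        base_model='decapoda-research/llama-7b-hf',
        flash_attn=True,   # renamed from llama_flash_attn in this commit
        lora_dropout=0.05,
        train_on_inputs=True,
    )
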
@@ -173,10 +173,9 @@ def train(
save_steps: int = None, # must be round multiple of eval_steps
save_total_limit: int = 3,
add_eos_token: bool = False,
- flash_attention: bool = False,
):

- if llama_flash_attn:
+ if flash_attn:
# Need to call this before importing transformers.
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
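
The comment above is the key constraint: the patch has to run before transformers is imported anywhere in the process. A minimal sketch of the required ordering, assuming the repo's llama_flash_attn_monkey_patch module is importable:

    # The patch rebinds LLaMA attention inside transformers, so it must be
    # applied before any other module imports transformers and captures
    # references to the unpatched classes.
    flash_attn = True  # stand-in for the CLI flag

    if flash_attn:
        from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()

    import transformers  # safe only after the patch has been applied
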
@@ -212,21 +211,16 @@ def train(
tokenizer_base_model = base_model
if llama_type is None:
llama_type = "llama" in base_model.lower()
- if llama_type and llama_flash_attn:
+ if flash_attn:
import pkg_resources
try:
pkg_resources.get_distribution('flash_attn')
can_do_flash_attn = True
log("Enabling Flash attention")
except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
can_do_flash_attn = False

if not can_do_flash_attn:
raise RuntimeError("""Flash attention not installed.
NOTE: for current pytorch 2.0, flash attention requires CUDA 11.7, available via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local (when running the installer, skip the driver, docs, and samples and install only the toolkit). Then install flash attention with:
CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
@@ -297,36 +291,40 @@ def train(
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
lora_mappings['distilgpt2'] = ["c_attn"]

if "h2ogpt" in base_model and not llama_type and flash_attention:
if not llama_type and flash_attn:
log("Enabling Flash attention")
# speed up forward prop for attention layer and reduce memory especially for long context lengths
- from flash_attn.models.gpt import GPTLMHeadModel
- from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
- from flash_attn.models.gptj import gptj_config_to_gpt2_config
-
- if "gpt-j" in base_model.lower():
- config = gptj_config_to_gpt2_config(model.config)
- else:
- config = gpt_neox_config_to_gpt2_config(model.config)
- config.use_flash_attn = True
- config.fused_bias_fc = True
- config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
- config.fused_dropout_add_ln = True
- config.residual_in_fp32 = True
- lora_target_modules = ['Wqkv']
- model = GPTLMHeadModel.from_pretrained(base_model, config, device='cuda', dtype=torch.float16)
- # for v in vars(model2.config):
- # setattr(model.config, v, getattr(model2.config, v))
- # model.transformer.config = model.config
- # model.transformer.h = model2.transformer.layers
- # model.lm_head = model2.lm_head
- ### model.transformer.wte = model2.transformer.wte
- ### model.transformer.embeddings = model2.transformer.embeddings
- print(model)
- # FIXME - don't disable LoRA
- lora_r = 0
- # FIXME - enable 8-bit
- # model = prepare_model_for_int8_training(model)
+ # from flash_attn.models.gpt import GPTLMHeadModel
+ # from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
+ # from flash_attn.models.gptj import gptj_config_to_gpt2_config
+ #
+ # if "gpt-j" in base_model.lower():
+ # config = gptj_config_to_gpt2_config(model.config)
+ # else:
+ # assert any([x in base_model.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]])
+ # config = gpt_neox_config_to_gpt2_config(model.config)
+ # config.use_flash_attn = True
+ # config.fused_bias_fc = True
+ # config.activation_function = 'gelu_fast' # GPT-NeoX-20B uses "gelu_fast"
+ # config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
+ # config.fused_dropout_add_ln = True
+ # config.residual_in_fp32 = True
+ # lora_target_modules = ['Wqkv']
+ # # model = GPTLMHeadModel.from_pretrained(base_model, config, device='cuda', dtype=torch.float16)
+ #
+ # model = GPTLMHeadModel(config, base_model, device='cuda', dtype=torch.float16)
+ # # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+ # # want extra stuff taking up more GPU memory
+ # state_dict = state_dict_from_pretrained(
+ # base_model, device='cpu', dtype=torch.float16
+ # )
+ # if base_model.startswith('EleutherAI/gpt-j-'):
+ # state_dict = remap_state_dict_hf_gptj(state_dict, config)
+ # strict = False # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint
+ # else:
+ # state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
+ # if world_size > 1:
+ # state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
+ # model.load_state_dict(state_dict, strict=True)

if lora_weights:

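For reference, a minimal sketch of the (now commented-out) non-LLaMA flash-attention path, reconstructed from the comments above; the flash_attn model API calls are the ones those comments use, and the model id is illustrative only:

    import torch
    from transformers import AutoConfig
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

    base_model = 'EleutherAI/pythia-6.9b'  # illustrative GPT-NeoX-family id
    # Translate the NeoX config into the GPT2-style config flash_attn expects.
    config = gpt_neox_config_to_gpt2_config(AutoConfig.from_pretrained(base_model))
    config.use_flash_attn = True         # flash attention kernels in the blocks
    config.fused_bias_fc = True
    config.fused_mlp = True              # GPT-NeoX-20B uses "gelu_fast"
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    model = GPTLMHeadModel.from_pretrained(base_model, config,
                                           device='cuda', dtype=torch.float16)
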
@@ -672,7 +670,7 @@ def compute_metrics(eval_preds):
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
# WIP (not generally replacing layers until pytorch 2.1)
- if not llama_flash_attn:
+ if not flash_attn:
torch.backends.cuda.enable_flash_sdp(True)

if gpus > 1 and not ddp:
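As the WIP comment notes, PyTorch 2.0 does not transparently replace attention layers; enable_flash_sdp(True) only selects the flash kernel for code that already routes through the fused SDPA op. A minimal sketch, assuming a CUDA device and fp16 tensors:

    import torch
    import torch.nn.functional as F

    torch.backends.cuda.enable_flash_sdp(True)  # prefer the flash kernel

    # Shapes: (batch, heads, seq_len, head_dim); fp16 on CUDA is required
    # for the flash backend to be eligible.
    q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
    v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)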
