Lora ckpt in HF format for NeMo AutoModel (#11712)
* Save lora ckpt in safetensor and a config

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* remove hf variable from peft

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* vllm with automodel peft working

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>

* revert changes

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* update examples

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>

* removed unused import

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* enable ckpt saving

Signed-off-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>

* remove unused import

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>

* fix minor bug

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

---------

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>
Signed-off-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>
oyilmaz-nvidia authored Jan 10, 2025
1 parent 500c827 commit 9799051
Showing 10 changed files with 246 additions and 31 deletions.
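In practical terms, this change saves the LoRA adapter in the standard Hugging Face PEFT layout (adapter_model.safetensors plus adapter_config.json), so the result can be loaded back with ordinary PEFT tooling. A minimal sketch, assuming hypothetical paths for the base model and the checkpoint directory written by the run:

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical paths -- substitute the base model and the directory that
# contains adapter_model.safetensors / adapter_config.json from the run.
base = AutoModelForCausalLM.from_pretrained("/path/to/base-model")
model = PeftModel.from_pretrained(base, "/path/to/ckpt/weights")
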
13 changes: 7 additions & 6 deletions .github/workflows/cicd-main.yml
@@ -3622,7 +3622,7 @@ jobs:
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
- TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt
+ TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft_hf.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3
AFTER_SCRIPT: |
rm -rf nemo_experiments
@@ -3633,16 +3633,17 @@
with:
RUNNER: self-hosted-azure
SCRIPT: |
- TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt --strategy fsdp --devices 2
+ TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft_hf.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --strategy fsdp --devices 2
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_VLM_HF_Transformer_PEFT_4bit:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_VLM_HF_Transformer_PEFT_4bit') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
- TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --disable-ckpt --use-4bit
+ TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft_hf.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --use-4bit
AFTER_SCRIPT: |
rm -rf nemo_experiments
@@ -3653,7 +3654,7 @@
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
- TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --disable-ckpt
+ TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft_hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10
AFTER_SCRIPT: |
rm -rf nemo_experiments
@@ -3675,7 +3676,7 @@
with:
RUNNER: self-hosted-azure
SCRIPT: |
- TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp --disable-ckpt
+ TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft_hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp --disable-ckpt
AFTER_SCRIPT: |
rm -rf nemo_experiments
14 changes: 12 additions & 2 deletions examples/llm/peft/hf.py
File mode changed: 100644 → 100755
@@ -14,8 +14,10 @@

import fiddle as fdl
from lightning.pytorch.loggers import WandbLogger

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import NeMoLogger
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


@@ -69,6 +71,7 @@ def main():
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
parser.add_argument('--ckpt-folder', type=str, default=None)
args = parser.parse_args()

wandb = None
@@ -84,6 +87,13 @@
# https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False

import tempfile

if args.ckpt_folder is None:
args.ckpt_folder = tempfile.TemporaryDirectory().name
print("Temp directory created for base model: ", args.ckpt_folder)

tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)

callbacks = []
@@ -110,10 +120,10 @@
precision="bf16",
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
- log=None,
+ log=NeMoLogger(log_dir=args.ckpt_folder, use_datetime_version=False),
peft=llm.peft.LoRA(
target_modules=['*_proj'],
- dim=32,
+ dim=8,
),
)

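After such a run finishes, the adapter files land under the NeMoLogger directory passed via --ckpt-folder (or under the temporary directory created above when the flag is omitted). A quick way to locate them, sketched with a hypothetical path:

import os

ckpt_root = "/tmp/nemo_lora_ckpt"  # hypothetical value passed as --ckpt-folder
for root, _, files in os.walk(ckpt_root):
    if "adapter_model.safetensors" in files:
        print("LoRA adapter found in:", root)
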
42 changes: 42 additions & 0 deletions examples/llm/peft/hf_vllm.py
@@ -0,0 +1,42 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
from nemo.export.vllm_hf_exporter import vLLMHFExporter
except Exception:
raise Exception(
"vLLM should be installed in the environment or import "
"the vLLM environment in the NeMo FW container using "
"source /opt/venv/bin/activate command"
)


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str, help="Local path of the base model")
parser.add_argument('--lora-model', required=True, type=str, help="Local path of the lora model")
# parser.add_argument('--triton-model-name', required=True, type=str, help="Name for the service")
args = parser.parse_args()

lora_model_name = "lora_model"

exporter = vLLMHFExporter()
exporter.export(model=args.model, enable_lora=True)
exporter.add_lora_models(lora_model_name=lora_model_name, lora_model=args.lora_model)

print(
"------------- Output: ", exporter.forward(input_texts=["How are you doing?"], lora_model_name=lora_model_name)
)
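Because the exporter keeps a name-to-path mapping of adapters, several LoRA models could be registered against the same base model and selected per request. A sketch with hypothetical paths (only the single-adapter case above is exercised by this example script):

from nemo.export.vllm_hf_exporter import vLLMHFExporter

exporter = vLLMHFExporter()
exporter.export(model="/path/to/base-model", enable_lora=True)
exporter.add_lora_models(lora_model_name="adapter_a", lora_model="/path/to/lora_a")
exporter.add_lora_models(lora_model_name="adapter_b", lora_model="/path/to/lora_b")

# Pick the adapter per call; omitting lora_model_name falls back to the base model.
print(exporter.forward(input_texts=["How are you doing?"], lora_model_name="adapter_b"))
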
7 changes: 7 additions & 0 deletions nemo/export/__init__.py
@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


use_tensorrt = True
try:
from nemo.export.tensorrt_lazy_compiler import trt_compile
except Exception as e:
use_tensorrt = False
24 changes: 20 additions & 4 deletions nemo/export/vllm_hf_exporter.py
@@ -19,6 +19,7 @@
from pytriton.decorators import batch
from pytriton.model_config import Tensor
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from nemo.deploy import ITritonDeployable
from nemo.deploy.utils import cast_output, str_ndarray2list
@@ -48,14 +49,20 @@ class vLLMHFExporter(ITritonDeployable):

def __init__(self):
self.model = None
self.lora_models = None

- def export(self, model):
+ def export(self, model, enable_lora: bool = False):
"""
Exports the HF checkpoint to vLLM and initializes the engine.
Args:
model (str): model name or the path
"""
- self.model = LLM(model=model)
+ self.model = LLM(model=model, enable_lora=enable_lora)

def add_lora_models(self, lora_model_name, lora_model):
if self.lora_models is None:
self.lora_models = {}
self.lora_models[lora_model_name] = lora_model

@property
def get_triton_input(self):
@@ -99,15 +106,24 @@ def forward(
input_texts: List[str],
max_output_len: int = 64,
top_k: int = 1,
- top_p: float = 0.0,
+ top_p: float = 0.1,
temperature: float = 1.0,
lora_model_name: str = None,
):
assert self.model is not None, "Model is not initialized."

lora_request = None
if lora_model_name is not None:
if self.lora_models is None:
raise Exception("No lora models are available.")
assert lora_model_name in self.lora_models.keys(), "Lora model was not added before"
lora_request = LoRARequest(lora_model_name, 1, self.lora_models[lora_model_name])

sampling_params = SamplingParams(
max_tokens=max_output_len, temperature=temperature, top_k=int(top_k), top_p=top_p
)
- request_output = self.model.generate(input_texts, sampling_params)

+ request_output = self.model.generate(input_texts, sampling_params, lora_request=lora_request)
output = []
for o in request_output:
output.append(o.outputs[0].text)
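Under the hood this maps onto vLLM's multi-LoRA API: the engine is built with enable_lora=True and each generate call may carry a LoRARequest. A minimal sketch of the raw calls the exporter wraps, with hypothetical paths:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="/path/to/base-model", enable_lora=True)
params = SamplingParams(max_tokens=64, temperature=1.0, top_k=1, top_p=0.1)
out = llm.generate(
    ["How are you doing?"],
    params,
    lora_request=LoRARequest("lora_model", 1, "/path/to/lora-adapter"),
)
print(out[0].outputs[0].text)
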
97 changes: 97 additions & 0 deletions nemo/lightning/io/pl.py
@@ -331,6 +331,103 @@ def should_remove_missing_sharded_base(x: Any):
return sharded_state_dict


class HuggingFaceCheckpointIO(AsyncCompatibleCheckpointIO, IOMixin):
"""CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively,
common for most use cases.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
"""

def __init__(self, hf_model=None, lora=False):
self.hf_model = hf_model
self.lora = lora

@override
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
path: write-target path
storage_options: not used in ``TorchCheckpointIO.save_checkpoint``
Raises
------
TypeError:
If ``storage_options`` arg is passed in
"""

if self.lora:
from safetensors.torch import save_file

state_dict = {}
for module_name, module_weight in checkpoint["state_dict"].items():
new_module_name = module_name.replace("model.model", "base_model.model")
new_module_name = new_module_name.replace("lora_a", "lora_A.weight").replace("lora_b", "lora_B.weight")
state_dict[new_module_name] = module_weight

checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True)
fs = get_filesystem(checkpoint_dir)
fs.makedirs(checkpoint_dir, exist_ok=True)
save_file(state_dict, checkpoint_dir / "adapter_model.safetensors")

@override
def load_checkpoint(
self,
path: _PATH,
sharded_state_dict=None,
map_location: Optional[Callable] = None,
strict: Optional['StrictHandling'] | bool = None,
) -> Dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.
Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
Returns: The loaded checkpoint.
Raises
------
FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
"""

# Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
fs = get_filesystem(path)
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")
if not fs.isdir(path):
raise ValueError(f"Checkpoints should be a directory. Found: {path}.")

state_dict = None
if (path / "adaptor_config.json").exists():
from safetensors import safe_open

state_dict = {}
with safe_open("adapter_model.safetensors", framework="pt", device=0) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)

return {'state_dict': state_dict}

@override
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove checkpoint file from the filesystem.
Args:
path: Path to checkpoint
"""
fs = get_filesystem(path)
if fs.exists(path):
fs.rm(path, recursive=True)
log.debug(f"Removed checkpoint: {path}")


def _fix_tensors_device(ckpt: Dict) -> Dict:
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
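The key renaming in save_checkpoint above (model.model → base_model.model, lora_a/lora_b → lora_A.weight/lora_B.weight) is what makes the saved file readable as a PEFT adapter. One way to sanity-check the output, assuming a hypothetical checkpoint directory:

from safetensors import safe_open

# Hypothetical path to the weights subdirectory written by save_checkpoint.
with safe_open("/path/to/ckpt/weights/adapter_model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        # Expect names like base_model.model...lora_A.weight / ...lora_B.weight
        print(key, tuple(f.get_tensor(key).shape))
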
64 changes: 48 additions & 16 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -142,20 +142,27 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
trainer.strategy.trainer = trainer
wrapped_io = partial(WrappedAdapterIO, peft=self)

- ckpt_io_kwarg_names = [
- "save_ckpt_format",
- "async_save",
- "torch_dist_multiproc",
- "assume_constant_structure",
- "parallel_save",
- "parallel_save_within_dp",
- "parallel_load",
- "load_directly_on_device",
- ]
- ckpt_io_kwargs = {
- arg: getattr(trainer.strategy, arg)
- for arg in filter(lambda x: hasattr(trainer.strategy, x), ckpt_io_kwarg_names)
- }
+ is_hf_model = getattr(trainer.model, "is_hf_model", False)
+ if not type(is_hf_model) == type(True):
+ is_hf_model = False
+
+ if is_hf_model:
+ ckpt_io_kwargs = {"model_library": "huggingface", "lora": True}
+ else:
+ ckpt_io_kwarg_names = [
+ "save_ckpt_format",
+ "async_save",
+ "torch_dist_multiproc",
+ "assume_constant_structure",
+ "parallel_save",
+ "parallel_save_within_dp",
+ "parallel_load",
+ "load_directly_on_device",
+ ]
+ ckpt_io_kwargs = {
+ arg: getattr(trainer.strategy, arg)
+ for arg in filter(lambda x: hasattr(trainer.strategy, x), ckpt_io_kwarg_names)
+ }
trainer.strategy._checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, **ckpt_io_kwargs)
self.wrapped_io = (
trainer.strategy._checkpoint_io._checkpoint_io
@@ -401,14 +408,39 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
from nemo.utils.get_rank import is_global_rank_zero

if is_global_rank_zero():
metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
base_dir = ckpt_to_weights_subdir(path, is_saving=True)
base_dir.mkdir(parents=True, exist_ok=True)
- adapter_meta_path = base_dir / ADAPTER_META_FILENAME

+ from nemo.lightning.io.pl import HuggingFaceCheckpointIO
+
+ if isinstance(self.checkpoint_io, HuggingFaceCheckpointIO):
+ metadata = self._create_lora_hf_config()
+ adapter_meta_path = base_dir / "adapter_config.json"
+ else:
+ metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
+ adapter_meta_path = base_dir / ADAPTER_META_FILENAME

with open(adapter_meta_path, "w") as f:
json.dump(metadata, f)
return request

def _create_lora_hf_config(self):
from peft import LoraConfig
from nemo.collections.llm.peft import DoRA

lora_config = LoraConfig(
r=self.peft.dim,
target_modules=self.peft.target_modules,
lora_alpha=self.peft.alpha,
lora_dropout=self.peft.dropout,
use_dora=isinstance(self.peft, DoRA),
)
lora_config = lora_config.to_dict()
lora_config["peft_type"] = "LORA"
lora_config["megatron_core"] = None
lora_config["target_modules"] = self.peft.target_modules
return lora_config

@override
def load_checkpoint(
self,
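For reference, the dictionary returned by _create_lora_hf_config and written out as adapter_config.json looks roughly like the sketch below for the dim=8, '*_proj' configuration in examples/llm/peft/hf.py. The alpha and dropout values are assumed NeMo LoRA defaults, and LoraConfig.to_dict() adds further default fields omitted here:

# Illustrative contents of adapter_config.json (values are assumptions,
# not copied from an actual run).
adapter_config = {
    "r": 8,
    "target_modules": ["*_proj"],
    "lora_alpha": 32,      # assumed default alpha
    "lora_dropout": 0.0,   # assumed default dropout
    "use_dora": False,
    "peft_type": "LORA",
    "megatron_core": None,
}
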
