Commit

finetune hfargparser
maxzuo committed Jul 22, 2024
1 parent 6d50076 commit 664c131
Showing 1 changed file with 66 additions and 32 deletions.
98 changes: 66 additions & 32 deletions finetune.py
@@ -1,6 +1,9 @@
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
import importlib.metadata
import importlib.util
import os
from packaging import version
import sqlite3
import yaml

@@ -14,11 +17,12 @@
import bitsandbytes as bnb
from datasets import Dataset
from peft import LoraConfig, get_peft_model
import transformers
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
TrainingArguments,
HfArgumentParser,
PreTrainedTokenizer,
PreTrainedModel,
)
@@ -27,10 +31,35 @@

import llm_planner as llmp

from accelerate import Accelerator
HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")


HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
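# Make output_dir optional (default None) so HfArgumentParser does not require
# --output_dir on the command line; the value can instead come from the YAML
# training_args section.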
@dataclass
class TrainingArguments(transformers.TrainingArguments):
output_dir: str | None = field(default=None)


def is_ipex_available():
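"""Return True when intel_extension_for_pytorch is installed and its major.minor
version matches the installed torch version."""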
def get_major_and_minor_from_version(full_version):
return (
str(version.parse(full_version).major)
+ "."
+ str(version.parse(full_version).minor)
)

_torch_version = importlib.metadata.version("torch")
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
except importlib.metadata.PackageNotFoundError:
return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor:
return False
return True


def load_dataset(config: dict) -> dict[str, Dataset]:
@@ -67,11 +96,15 @@ def load_dataset(config: dict) -> dict[str, Dataset]:
for key in split_key:
split_ids = split_ids[key]

c.execute(
f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})",
split_ids,
)
queries.extend(c.fetchall())
# SQLite caps the number of bound '?' parameters per statement (999 by default in
# older builds), so fetch the ids in batches of 999.

for i in range(0, len(split_ids), 999):
batch = split_ids[i : i + 999]
c.execute(
f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(batch))})",
batch,
)
queries.extend(c.fetchall())

for domain, problem_pddl, natural_language in queries:
dataset[split]["domain"].append(domains[domain])
@@ -136,7 +169,7 @@ def preprocess(
inputs = [
strip(
tokenizer.apply_chat_template(
llmp.PlanningProblem(nl, d, p).apply_template(
llmp.PlanningProblem(nl, d, p).apply_template(
domain_prompt,
problem_prompt,
),
@@ -184,15 +217,19 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
else:
bnb_config = None

device_index = Accelerator().process_index
device_map = {"": device_index}
device_map = "auto"
if os.environ.get("LOCAL_RANK") is not None:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device_map = {"": local_rank}

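# Load the base model in bfloat16; quantization_config enables bitsandbytes 4/8-bit
# loading when configured, and trust_remote_code allows custom model code from the Hub.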
model = AutoModelForCausalLM.from_pretrained(
config["model"]["model_name"],
**config["model"].get("model_kwargs", {}),
device_map=device_map,
token=HF_USER_TOKEN,
torch_dtype=torch.bfloat16,
quantization_config=bnb_config,
device_map=device_map,
trust_remote_code=True,
)

lora_config = LoraConfig(
@@ -204,21 +241,31 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
return tokenizer, model


def main(config_path: str):
def main():
"""Train a model on a dataset using a given configuration.
Args:
config_path (str): The path to the configuration file.
"""
parser = HfArgumentParser(TrainingArguments)
parser.add_argument(
"-c",
"--config",
type=str,
# required=True,
help="Path to the configuration file.",
)

training_args, args = parser.parse_args_into_dataclasses()

# Load configuration
with open(config_path) as f:
with open(args.config) as f:
config = yaml.safe_load(f)
train_config = config["train"]

# Load dataset
dataset = load_dataset(config["dataset"])

train_config = config["train"]

# Load model
tokenizer, model = load_model(train_config)

@@ -233,11 +280,12 @@ def main(config_path: str):

# Build training arguments
args_config = train_config.get("training_args", {})
training_args = TrainingArguments(**args_config)
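# Overlay the YAML training_args section onto the CLI-parsed TrainingArguments;
# keys present in the YAML override anything given on the command line.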
training_args.__dict__.update(args_config)

# Create trainer
trainer = SFTTrainer(
model,
tokenizer=tokenizer,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
@@ -256,18 +304,4 @@ def main(config_path: str):


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser("Fine-tune a model on PDDL dataset.")
parser.add_argument(
"-c",
"--config",
type=str,
default="config.yaml",
required=True,
help="Path to the configuration file.",
)

args = parser.parse_args()

main(args.config)
main()
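
With HfArgumentParser in place, any transformers.TrainingArguments field can be passed on the command line alongside the config path. A hypothetical invocation (file names and values are placeholders, not taken from this commit):

    python finetune.py --config configs/train.yaml --learning_rate 1e-4 --bf16 True

For multi-GPU runs, the LOCAL_RANK variable set by the launcher selects the per-process device map:

    torchrun --nproc_per_node=2 finetune.py --config configs/train.yaml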
