add gradient accumulation in HF Trainer integration
sangkeun00 committed May 31, 2024
1 parent 0f465d6 commit 6c9ccd1
Showing 4 changed files with 21 additions and 33 deletions.
2 changes: 2 additions & 0 deletions examples/huggingface/bert_influence.py
@@ -17,6 +17,7 @@ def main():
     parser.add_argument("--config_path", type=str, default="./config.yaml")
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--data_name", type=str, default="sst2")
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     args = parser.parse_args()
 
     set_seed(0)
@@ -41,6 +42,7 @@ def main():
         output_dir="./output",
         num_train_epochs=1,
         per_device_train_batch_size=args.batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         report_to="none",
     )
 
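With gradient accumulation enabled, the Hugging Face Trainer spreads each optimizer step across several micro-batches, so the effective batch size becomes per_device_train_batch_size × gradient_accumulation_steps (times the number of devices). A minimal sketch of that relationship, using the flag names from this example; world_size is a hypothetical stand-in for the device count:

# Illustrative arithmetic only; world_size is a hypothetical stand-in.
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
world_size = 1
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * world_size
)
assert effective_batch_size == 32  # same samples per optimizer step as --batch_size 32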
2 changes: 2 additions & 0 deletions examples/huggingface/gpt_influence.py
@@ -17,6 +17,7 @@ def main():
     parser.add_argument("--config_path", type=str, default="./config.yaml")
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--data_name", type=str, default="sst2")
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     args = parser.parse_args()
 
     set_seed(0)
@@ -41,6 +42,7 @@ def main():
         num_train_epochs=1,
         per_device_train_batch_size=args.batch_size,
         report_to="none",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
     )
 
     LogIXTrainer = patch_trainer(Trainer)
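In both examples, patch_trainer(Trainer) wraps the stock transformers.Trainer so that LogIX can hook into its callback events. A rough usage sketch, assuming the patched class keeps the standard Trainer constructor; model, train_dataset, and data_collator are placeholders, and calling train() stands in for whichever entry point the example actually uses:

# Hypothetical sketch; not taken from the repository.
trainer = LogIXTrainer(
    model=model,                  # placeholder: any transformers model
    args=training_args,           # carries gradient_accumulation_steps
    train_dataset=train_dataset,  # placeholder dataset
    data_collator=data_collator,  # placeholder collator
)
trainer.train()  # the LogIX callback fires on on_substep_end / on_step_end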
35 changes: 17 additions & 18 deletions logix/huggingface/callback.py
@@ -3,6 +3,7 @@
 from transformers.trainer import TrainerCallback
 
 from logix import LogIX, LogIXScheduler
+from logix.utils import merge_logs
 from logix.huggingface.arguments import LogIXArguments
 
 
@@ -17,6 +18,8 @@ def __init__(
         self.logix_scheduler = iter(logix_scheduler)
         self.args = args
 
+        self.accumulated_log = []
+
         self._log_dataloader = None
 
     def on_init_end(self, args, state, control, **kwargs):
@@ -51,35 +54,31 @@ def on_train_begin(self, args, state, control, **kwargs):
 
     def on_step_end(self, args, state, control, **kwargs):
         if self.args.mode == "influence":
-            test_log = self.logix.get_log()
+            self.accumulated_log.append(self.logix.get_log(copy=True))
+            accumulated_log = merge_logs(self.accumulated_log)
+
             self.logix.influence.compute_influence_all(
-                test_log,
+                accumulated_log,
                 self.log_dataloader(),
                 mode=self.args.influence_mode,
                 damping=self.args.influence_damping,
                 save=True,
             )
+
+            self.accumulated_log = []
         elif self.args.mode == "self_influence":
-            test_log = self.logix.get_log()
+            self.accumulated_log.append(self.logix.get_log(copy=True))
+            accumulated_log = merge_logs(self.accumulated_log)
+
             self.logix.influence.compute_self_influence(
-                test_log, damping=self.args.influence_damping
+                accumulated_log, damping=self.args.influence_damping
             )
+
+            self.accumulated_log = []
 
     def on_substep_end(self, args, state, control, **kwargs):
-        if self.args.mode == "influence":
-            test_log = self.logix.get_log()
-            self.logix.influence.compute_influence_all(
-                test_log,
-                self.log_dataloader(),
-                mode=self.args.influence_mode,
-                damping=self.args.influence_damping,
-                save=True,
-            )
-        elif self.args.mode == "self_influence":
-            test_log = self.logix.get_log()
-            self.logix.influence.compute_self_influence(
-                test_log, damping=self.args.influence_damping
-            )
+        if self.args.mode in ["influence", "self_influence"]:
+            self.accumulated_log.append(self.logix.get_log(copy=True))
 
     def log_dataloader(self):
         if self._log_dataloader is None:
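Under gradient accumulation, the Trainer fires on_substep_end after every micro-batch that does not take an optimizer step and on_step_end at the optimizer step itself, so the callback above buffers one log per micro-batch and merges the buffer before computing influence. A simplified sketch of that control flow; merge_logs is stubbed as plain list concatenation purely for illustration, not as the real logix.utils.merge_logs behavior:

# Toy illustration of the accumulate-merge-reset cycle; the real per-sample
# log structure and merge_logs implementation in LogIX may differ.
from typing import List

def merge_logs_stub(logs: List[list]) -> list:
    merged = []
    for log in logs:          # one entry per micro-batch
        merged.extend(log)    # concatenate per-sample records
    return merged

accumulated_log = []
micro_batches = [["s0", "s1"], ["s2", "s3"], ["s4", "s5"], ["s6", "s7"]]

for i, batch_log in enumerate(micro_batches):
    last = i == len(micro_batches) - 1
    accumulated_log.append(batch_log)    # on_substep_end (or on_step_end) buffers the log
    if last:                             # on_step_end: merge, compute, reset
        step_log = merge_logs_stub(accumulated_log)
        # ...compute (self-)influence on step_log here...
        accumulated_log = []

assert len(step_log) == 8  # all eight samples from the four micro-batches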
15 changes: 0 additions & 15 deletions logix/logix.py
@@ -434,21 +434,6 @@ def compute_self_influence(
         )
         return result
 
-    def save_config(self) -> None:
-        """
-        Save LogIX state to disk.
-        """
-        config_file = os.path.join(self.log_dir, "config.yaml")
-        config_dict = asdict(self.config)
-        with open(config_file, "w", encoding="utf-8") as f:
-            yaml.dump(config_dict, f, default_flow_style=False)
-
-    def save_state(self) -> None:
-        """
-        Save Hessian state to disk.
-        """
-        self.state.save_state(self.log_dir)
-
     def save_lora(self) -> None:
         """
         Save LoRA state to disk.
