Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
dxli94 authored Jul 24, 2023
1 parent 7fe1dd5 commit 61989d7
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions lavis/runners/runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,12 @@ def train(self):
# training phase
if not self.evaluate_only:
logging.info("Start training")
if cur_epoch == self.start_epoch:
self.task.before_training(
model=self.unwrap_dist_model(self.model),
dataset=self.datasets["train"],
)
# See https://github.com/salesforce/LAVIS/issues/449
# if cur_epoch == self.start_epoch:
# self.task.before_training(
# model=self.unwrap_dist_model(self.model),
# dataset=self.datasets["train"],
# )
train_stats = self.train_epoch(cur_epoch)
self.log_stats(split_name="train", stats=train_stats)

Expand Down

0 comments on commit 61989d7

Please sign in to comment.