This repository has been archived by the owner on Mar 14, 2024. It is now read-only.

Weight decay #265

Open · wants to merge 3 commits into main

Changes from 1 commit
early stopping
adamlerer committed Sep 21, 2022
commit 8fc291326076590cd9ddd631bc48d8a47e8d27b0
6 changes: 6 additions & 0 deletions torchbiggraph/config.py
@@ -396,6 +396,12 @@ class ConfigSchema(Schema):
             "after each training step."
         },
     )
+    early_stopping: bool = attr.ib(
+        default=False,
+        metadata={
+            "help": "Stop training when validation loss increases."
+        }
+    )
 
     # expert options
 
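For context, a PBG config module exposes options like this one through a get_torchbiggraph_config() function. Below is a minimal, hypothetical sketch of a config that turns the new flag on; apart from early_stopping (added by this commit) and eval_fraction (needed so per-chunk eval stats with a loss exist), every path, name, and value is a placeholder rather than anything taken from this PR.

def get_torchbiggraph_config():
    # Hypothetical example config; paths, entity/relation names, and sizes are placeholders.
    return dict(
        entity_path="data/example",
        edge_paths=["data/example/edges_partitioned"],
        checkpoint_path="model/example",
        entities={"all": {"num_partitions": 1}},
        relations=[
            {"name": "all_edges", "lhs": "all", "rhs": "all", "operator": "complex_diagonal"}
        ],
        dynamic_relations=False,
        dimension=200,
        num_epochs=50,
        eval_fraction=0.05,   # hold out some edges so each chunk gets eval stats with a loss
        early_stopping=True,  # the option added by this commit
    )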
83 changes: 47 additions & 36 deletions torchbiggraph/train_cpu.py
@@ -578,6 +578,7 @@ def train(self) -> None:
                 eval_stats_chunk_avg,
             )
 
+        last_chunk_loss = float("inf")
         for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
             logger.info(
                 f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
@@ -729,10 +730,17 @@ def train(self) -> None:
 
             current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1
 
-            self._maybe_write_checkpoint(
+            all_stats_dicts = self._maybe_write_checkpoint(
                 epoch_idx, edge_path_idx, edge_chunk_idx, current_index
             )
 
+            if config.early_stopping:
+                assert iteration_manager.num_edge_paths == 1
+                chunk_loss = all_stats_dicts[-1]["eval_stats_chunk_avg"]["metrics"]["loss"]
+                if chunk_loss > last_chunk_loss:
+                    break
+                last_chunk_loss = chunk_loss
+
             # now we're sure that all partition files exist,
             # so be strict about loading them
             self.strict = True
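In words: the loop keeps the chunk-averaged validation loss from the previous chunk and breaks out of training as soon as it goes back up. A standalone sketch of that stopping rule, with an illustrative function name that is not part of this diff:

from typing import Any, Dict, List

def should_stop(all_stats_dicts: List[Dict[str, Any]], last_chunk_loss: float) -> bool:
    # The last entry written by _maybe_write_checkpoint is the per-chunk summary,
    # whose averaged eval stats carry the validation loss for the chunk just trained.
    chunk_loss = all_stats_dicts[-1]["eval_stats_chunk_avg"]["metrics"]["loss"]
    return chunk_loss > last_chunk_loss

The rule compares consecutive chunks, which is presumably why the diff asserts num_edge_paths == 1: with a single edge path, successive chunks come from the same dataset and their losses are comparable.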
@@ -922,7 +930,7 @@ def _maybe_write_checkpoint(
         epoch_idx: int,
         edge_path_idx: int,
         edge_chunk_idx: int,
         current_index: int,
-    ) -> None:
+    ) -> List[Dict[str, Any]]:
 
         config = self.config
 
@@ -963,42 +971,43 @@ def _maybe_write_checkpoint(
                 state_dict, self.trainer.model_optimizer.state_dict()
             )
 
-            logger.info("Writing the training stats")
-            all_stats_dicts: List[Dict[str, Any]] = []
-            bucket_eval_stats_list = []
-            chunk_stats_dict = {
-                "epoch_idx": epoch_idx,
-                "edge_path_idx": edge_path_idx,
-                "edge_chunk_idx": edge_chunk_idx,
-            }
-            for stats in self.bucket_scheduler.get_stats_for_pass():
-                stats_dict = {
-                    "lhs_partition": stats.lhs_partition,
-                    "rhs_partition": stats.rhs_partition,
-                    "index": stats.index,
-                    "stats": stats.train.to_dict(),
-                }
-                if stats.eval_before is not None:
-                    stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
-                    bucket_eval_stats_list.append(stats.eval_before)
-
-                if stats.eval_after is not None:
-                    stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
-
-                stats_dict.update(chunk_stats_dict)
-                all_stats_dicts.append(stats_dict)
-
-            if len(bucket_eval_stats_list) != 0:
-                eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
-                self.stats_handler.on_stats(
-                    index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
-                )
-                chunk_stats_dict["index"] = current_index
-                chunk_stats_dict[
-                    "eval_stats_chunk_avg"
-                ] = eval_stats_chunk_avg.to_dict()
-                all_stats_dicts.append(chunk_stats_dict)
+        all_stats_dicts: List[Dict[str, Any]] = []
+        bucket_eval_stats_list = []
+        chunk_stats_dict = {
+            "epoch_idx": epoch_idx,
+            "edge_path_idx": edge_path_idx,
+            "edge_chunk_idx": edge_chunk_idx,
+        }
+        for stats in self.bucket_scheduler.get_stats_for_pass():
+            stats_dict = {
+                "lhs_partition": stats.lhs_partition,
+                "rhs_partition": stats.rhs_partition,
+                "index": stats.index,
+                "stats": stats.train.to_dict(),
+            }
+            if stats.eval_before is not None:
+                stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
+                bucket_eval_stats_list.append(stats.eval_after)
+
+            if stats.eval_after is not None:
+                stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
+
+            stats_dict.update(chunk_stats_dict)
+            all_stats_dicts.append(stats_dict)
+
+        if len(bucket_eval_stats_list) != 0:
+            eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
+            chunk_stats_dict["index"] = current_index
+            chunk_stats_dict[
+                "eval_stats_chunk_avg"
+            ] = eval_stats_chunk_avg.to_dict()
+            all_stats_dicts.append(chunk_stats_dict)
 
+        if self.rank == 0:
+            logger.info("Writing the training stats")
+            self.stats_handler.on_stats(
+                index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
+            )
             self.checkpoint_manager.append_stats(all_stats_dicts)
 
         logger.info("Writing the checkpoint")
@@ -1029,3 +1038,5 @@ def _maybe_write_checkpoint(
             self.checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
         if not preserve_old_checkpoint:
             self.checkpoint_manager.remove_old_version(config)
+
+        return all_stats_dicts
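The early-stopping check in train() reads the last element of this returned list, which is the chunk_stats_dict summary appended above (it is only appended when the chunk produced eval stats, i.e. when bucket_eval_stats_list is non-empty). A hypothetical example of its shape; the values are invented, the top-level keys mirror chunk_stats_dict, and ["metrics"]["loss"] is the path the early-stopping check reads:

example_chunk_summary = {
    "epoch_idx": 0,
    "edge_path_idx": 0,
    "edge_chunk_idx": 3,
    "index": 39,
    # eval_stats_chunk_avg comes from Stats.average_list(...).to_dict()
    "eval_stats_chunk_avg": {"metrics": {"loss": 0.412}},
}
loss = example_chunk_summary["eval_stats_chunk_avg"]["metrics"]["loss"]  # 0.412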