When to stop training #136

Open

dkokron opened this issue Feb 9, 2025 · 3 comments

Comments

@dkokron

dkokron commented Feb 9, 2025

From issue #80:

"and we trained with 300k steps batch 32 each, which corresponds to about ~180 epochs"

What metric did you monitor to know when to stop training?
Did you ever redo Figure 3 from https://arxiv.org/pdf/2312.15796 for a model trained for fewer than 300k steps?

@andrewlkd
Collaborator

"and we trained with 300k steps batch 32 each, which corresponds to about ~180 epochs"

Just to clarify, the 300k-step figure is with respect to GraphCast. As Table D1 in https://arxiv.org/pdf/2312.15796 shows, GenCast is trained for 2 million steps at 1deg and then a further 64k steps at 0.25deg. At batch size 32, this corresponds to ~1200 epochs and ~38 epochs respectively.
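
For reference, those epoch counts follow directly from steps × batch size ÷ dataset size. A quick sketch of the arithmetic in Python, assuming the ~54k-sample training set mentioned later in this thread:

def approx_epochs(total_steps: int, batch_size: int, dataset_size: int) -> float:
    # Epochs = total samples drawn during training / samples in the dataset.
    return total_steps * batch_size / dataset_size

DATASET_SIZE = 54_000  # approximate training set size (assumption taken from this thread)

print(approx_epochs(2_000_000, 32, DATASET_SIZE))  # ~1185, i.e. roughly 1200 epochs at 1deg
print(approx_epochs(64_000, 32, DATASET_SIZE))     # ~38 epochs at 0.25deg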

What metric did you monitor to know when to stop training?

The training steps were chosen empirically from sweeps.

Did you ever redo figure 3 from https://arxiv.org/pdf/2312.15796 for a model having less than 300K steps?

Unfortunately not.

@dkokron
Author

dkokron commented Feb 9, 2025

I took the 2M value in Table D1 to be the "decay_steps" value in the AdamW configuration. Am I correct in understanding that you used 2M for "decay_steps" and also ran 2M sampling steps, each drawing a batch of 32 samples from the 54K total samples in the dataset?

import optax

# Linear warmup to the peak learning rate, then cosine decay back to zero.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,         # learning rate at step 0
    peak_value=0.003,       # learning rate reached at the end of warmup
    warmup_steps=1_000,
    decay_steps=2_000_000,  # total schedule length, including warmup
    end_value=0.0,
    exponent=0.1,
)
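
For completeness, a minimal sketch (continuing from the snippet above, not taken from the paper) of how such a schedule would typically be passed to AdamW in optax; the weight_decay value is a placeholder:

# Hypothetical wiring of the schedule into AdamW; weight_decay is a
# placeholder, not a GenCast hyperparameter.
optimizer = optax.adamw(learning_rate=lr_schedule, weight_decay=0.1)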

@andrewlkd
Collaborator

As stated in the table, 2M is the "Total train steps". Since there are 1k warmup steps, that leaves (2M - 1k) steps of cosine decay.

Since the warmup/decay schedule only affects the learning rate, your reading is otherwise correct: there are 2M sampling steps, each drawing a batch of 32 samples from the 54K total samples in the dataset.
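
For concreteness, a minimal sketch checking that reading of the schedule, using the hyperparameters from the snippet above (optax treats decay_steps as the total schedule length, so the cosine phase spans the remaining 2M - 1k steps):

import optax

schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=0.003,
    warmup_steps=1_000, decay_steps=2_000_000,
    end_value=0.0, exponent=0.1,
)

print(schedule(0))          # 0.0    -> start of the 1k-step linear warmup
print(schedule(1_000))      # 0.003  -> peak, cosine decay begins
print(schedule(2_000_000))  # 0.0    -> end of the 2M-step schedule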

Hope this helps

-- Andrew
