Hi @drewoldag, thanks for asking. If I understood correctly, here is a working example of what you would like to do:

import signal

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, global_step_from_engine

train_data = range(10)
eval_data = range(4)
max_epochs = 5

model = nn.Linear(10, 10)


def train_step(engine, batch):
    print(f"{engine.state.epoch} / {engine.state.max_epochs} | {engine.state.iteration} - batch: {batch}", flush=True)


trainer = Engine(train_step)

to_save = {
    "model": model,
}

checkpoint = Checkpoint(
    to_save,
    "./checkpoints",
    n_saved=1,
    global_step_transform=global_step_from_engine(trainer)
)


# Simulate Ctrl-C by raising SIGINT once, at iteration 23
@trainer.on(Events.ITERATION_STARTED(once=23))
def send_signal():
    signal.raise_signal(signal.SIGINT)


def terminate_training_and_checkpoint(*args, **kwargs):
    print("Call checkpoint object to save the model etc")
    checkpoint(trainer)
    print("Terminate training")
    trainer.terminate()


signal.signal(signal.SIGINT, terminate_training_and_checkpoint)

trainer.run(train_data, max_epochs=max_epochs)

Output:
The main problem in your code is that you have to call the Checkpoint instance with the trainer as its argument to create the checkpoint:

class Trainer():
    def __init__(self, config):
        signal.signal(signal.SIGINT, self.signal_handler)

    def run(self):
        """Run the training process for a given model and data loader.
        """
        self.model = model_cls(config=self.config, shape=data_loader.shape())

        # Create trainer, a pytorch-ignite `Engine` object
        self.trainer = self._create_trainer(self.model)

        self.checkpointer = self.checkpoint_handler()
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.checkpointer)

        # Run the training process
        self.trainer.run(dist_data_loader, max_epochs=self.config["model"]["epochs"])

    def checkpoint_handler(self):
        to_save = {
            'model': self.model,
            'optimizer': self.model.optimizer,
            'trainer': self.trainer
        }

        logger.info("Creating checkpoint.")
        return Checkpoint(
            to_save,
            DiskSaver(Path("./checkpoints"), require_empty=False),
            n_saved=1,
            global_step_transform=global_step_from_engine(self.trainer),
        )

    def signal_handler(self, sig, frame):
        if sig == signal.SIGINT:
            logger.info("SIGINT received, creating checkpoint and exiting.")
            # Call the Checkpoint instance with the engine to write the checkpoint
            self.checkpointer(self.trainer)
            self.trainer.terminate()
            # this may be omitted?
            # sys.exit(0)

Hope this helps
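A possible variation, sketched with the standard library only so it runs without ignite: the signal handler just records that a stop was requested, and the training loop writes the checkpoint once the current batch has finished. The handler also ignores the signal in any process other than the one that created it, which is one plausible explanation for log messages repeated once per dataloader worker, since forked workers inherit the parent's handler and Ctrl-C is delivered to the whole foreground process group. `StopRequest`, `train`, and `save_checkpoint` are hypothetical names for illustration, and registering SIGTERM assumes the scheduler sends that signal before killing the job.

```python
import os
import signal


class StopRequest:
    """Signal handler that only records that a stop was requested."""

    def __init__(self):
        self.owner_pid = os.getpid()  # process that created the handler
        self.requested = False

    def __call__(self, signum, frame):
        # Forked worker processes inherit this handler; do nothing there.
        if os.getpid() != self.owner_pid:
            return
        self.requested = True


def train(batches, stop, save_checkpoint):
    """Toy loop: finish the current batch, then checkpoint and stop if asked."""
    done = 0
    for batch in batches:
        done += 1  # stand-in for one real training step
        if stop.requested:
            save_checkpoint(done)
            break
    return done


if __name__ == "__main__":
    stop = StopRequest()
    signal.signal(signal.SIGINT, stop)
    signal.signal(signal.SIGTERM, stop)  # assumption: eviction arrives as SIGTERM

    def batches():
        for i in range(10):
            if i == 3:
                signal.raise_signal(signal.SIGINT)  # simulate Ctrl-C
            yield i

    saved = []
    train(batches(), stop, saved.append)
    print("checkpointed after step(s):", saved)
```

With ignite, the same check could live in a handler attached to `Events.ITERATION_COMPLETED` that calls `checkpointer(trainer)` and then `trainer.terminate()`, so the checkpoint is never written in the middle of a training step.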
A computing cluster where I'll be training some models with ignite uses a "condo model" such that I can request access to any resources not currently being used, but if the owner of those resources needs them, my jobs will be stopped within about 10-15 seconds.
This feels like the right place to use Checkpointing, but I'm not well enough versed in ignite to know how to shut down the trainer as part of my signal handler. Any advice would be very much appreciated!
My code looks something like this:
What I expected was that I could call the run method of the Trainer class, press Ctrl-C at some point during the training run, and get a checkpoint file.
Unfortunately that doesn't seem to work. Checkpoints are produced at the end of each epoch as expected. But when I press Ctrl-C, I see multiple instances of the log message "Creating checkpoint" printed out (one for each dataloader worker, plus one), and usually (but not always) a stack trace that looks like the following.
I'm assuming that there must be an elegant way to shut down the engine, and I thought that trainer.terminate() was the way to do it, but this doesn't seem correct.
Am I just completely misusing this functionality? At the end of the day, I could just use the EPOCH_COMPLETED event to trigger the creation of a checkpoint, but it would be nice to be able to resume in the middle of an epoch if we're evicted from the hardware we're training on.
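To make the idea of resuming in the middle of an epoch concrete, here is a minimal standard-library sketch of the bookkeeping involved. It is not ignite's implementation: `save_state`, `load_state`, `train`, and the `stop_after` parameter are hypothetical names, the "weights" are a plain running sum, and skipping already-processed batches only reproduces the original run if the data order is deterministic.

```python
import json
import os
import tempfile


def save_state(path, epoch, iteration, weights):
    """Write the state to a temp file first, then rename it into place."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"epoch": epoch, "iteration": iteration, "weights": weights}, f)
    os.replace(tmp, path)  # avoids leaving a half-written checkpoint


def load_state(path):
    """Return the saved state, or a fresh one if no checkpoint exists."""
    if not os.path.exists(path):
        return {"epoch": 0, "iteration": 0, "weights": 0}
    with open(path) as f:
        return json.load(f)


def train(path, data, max_epochs, stop_after=None):
    """Toy loop that resumes at the saved (epoch, iteration) position.

    `stop_after` simulates an eviction after that many steps in this call.
    Returns (weights, finished).
    """
    state = load_state(path)
    first_epoch, first_batch = state["epoch"], state["iteration"]
    weights = state["weights"]
    steps = 0
    for epoch in range(first_epoch, max_epochs):
        skip = first_batch if epoch == first_epoch else 0
        for i, batch in enumerate(data):
            if i < skip:
                continue  # already processed before the interruption
            weights += batch  # stand-in for one real training step
            steps += 1
            if stop_after is not None and steps == stop_after:
                save_state(path, epoch, i + 1, weights)
                return weights, False
    save_state(path, max_epochs, 0, weights)
    return weights, True


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "state.json")
    print(train(path, [1, 2, 3, 4], max_epochs=3, stop_after=6))  # interrupted run
    print(train(path, [1, 2, 3, 4], max_epochs=3))  # resumed run
```

In ignite itself, the code in this thread already includes the trainer in `to_save`, and `Checkpoint.load_objects` can restore the saved objects before calling `trainer.run` again. Whether the engine then continues from the exact batch depends on the ignite version and on the data loader, so that part is worth verifying on a small run.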