This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

fixed multi-GPU termination in train.py #379

Merged 1 commit on Jun 23, 2020
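The substance of this PR is the last hunk below: summary_writer.close() is now guarded by local_rank == 0, so only rank 0 touches the TensorBoard writer at shutdown and a multi-GPU run can terminate cleanly. As a standalone illustration of that convention, here is a minimal sketch; it is not this repo's code, it uses torch.utils.tensorboard directly, and names such as the run() helper and the "runs/demo" log directory are purely illustrative.

import os
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter  # requires the tensorboard package


def run(local_rank: int):
    # one process per GPU; the launcher provides the rendezvous env variables
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # rank 0 owns the writer; the other ranks never open the event files
    summary_writer = SummaryWriter("runs/demo") if local_rank == 0 else None

    for step in range(10):  # stand-in for the real training loop
        loss = torch.rand(1)
        if summary_writer is not None:
            summary_writer.add_scalar("train/loss", loss.item(), step)

    if local_rank == 0:  # the same guard this PR adds around close()
        summary_writer.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK; torch.distributed.launch passes --local_rank instead
    run(int(os.environ.get("LOCAL_RANK", 0)))

Every rank executes run(), but only rank 0 ever opens or closes the log files, which is the behaviour the last hunk below enforces for this training script.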
experiments/interpretation/dutchf3_patch/train.py: 29 changes (15 additions, 14 deletions)
@@ -152,17 +152,17 @@ def run(*options, cfg=None, local_rank=0, debug=False, input=None, distributed=F

     n_classes = train_set.n_classes
     val_set = TrainPatchLoader(config, split="val", is_transform=True, augmentations=val_aug, debug=debug,)

     logger.info(val_set)

     if debug:
         data_flow_dict = dict()

-        data_flow_dict['train_patch_loader_length'] = len(train_set)
-        data_flow_dict['validation_patch_loader_length'] = len(val_set)
-        data_flow_dict['train_input_shape'] = train_set.seismic.shape
-        data_flow_dict['train_label_shape'] = train_set.labels.shape
-        data_flow_dict['n_classes'] = n_classes
+        data_flow_dict["train_patch_loader_length"] = len(train_set)
+        data_flow_dict["validation_patch_loader_length"] = len(val_set)
+        data_flow_dict["train_input_shape"] = train_set.seismic.shape
+        data_flow_dict["train_label_shape"] = train_set.labels.shape
+        data_flow_dict["n_classes"] = n_classes

         logger.info("Running in debug mode..")
         train_range = min(config.TRAIN.BATCH_SIZE_PER_GPU * config.NUM_DEBUG_BATCHES, len(train_set))
@@ -171,9 +171,9 @@ def run(*options, cfg=None, local_rank=0, debug=False, input=None, distributed=F
         valid_range = min(config.VALIDATION.BATCH_SIZE_PER_GPU, len(val_set))
         val_set = data.Subset(val_set, range(valid_range))

-        data_flow_dict['train_length_subset'] = len(train_set)
-        data_flow_dict['validation_length_subset'] = len(val_set)
+        data_flow_dict["train_length_subset"] = len(train_set)
+        data_flow_dict["validation_length_subset"] = len(val_set)

     train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, num_replicas=world_size, rank=local_rank)
     val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, num_replicas=world_size, rank=local_rank)
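For context on the two DistributedSampler lines just above: each rank is handed a disjoint shard of the dataset, which is why the sampler is built with num_replicas=world_size and rank=local_rank. Below is a toy sketch of that partitioning, with the replica count and rank fixed by hand instead of read from an initialised process group; the eight-element dataset is purely illustrative.

import torch
from torch.utils import data

dataset = data.TensorDataset(torch.arange(8).float())  # toy stand-in for train_set
sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=2, rank=0)
loader = data.DataLoader(dataset, batch_size=2, sampler=sampler)

sampler.set_epoch(0)  # reshuffles the shard assignment each epoch
for (batch,) in loader:
    print(batch)  # rank 0 iterates over 4 of the 8 samples; rank 1 would get the other 4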

@@ -185,13 +185,13 @@ def run(*options, cfg=None, local_rank=0, debug=False, input=None, distributed=F
     )

     if debug:
-        data_flow_dict['train_loader_length'] = len(train_loader)
-        data_flow_dict['validation_loader_length'] = len(val_loader)
+        data_flow_dict["train_loader_length"] = len(train_loader)
+        data_flow_dict["validation_loader_length"] = len(val_loader)
         config_file_name = "default_config" if not cfg else cfg.split("/")[-1].split(".")[0]
         fname = f"data_flow_train_{config_file_name}_{config.TRAIN.MODEL_DIR}.json"
-        with open(fname, 'w') as f:
+        with open(fname, "w") as f:
             json.dump(data_flow_dict, f, indent=2)

     # Model:
     model = getattr(models, config.MODEL.NAME).get_seg_model(config)
     device = "cuda" if torch.cuda.is_available() else "cpu"
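The debug branch above serialises data_flow_dict to a JSON file whose name is built from the config. Reading it back needs nothing beyond json; the file name used here is only an assumption that follows the f-string pattern in the diff, since the real name depends on config.TRAIN.MODEL_DIR.

import json

# illustrative file name; substitute whatever data_flow_train_*.json the run produced
with open("data_flow_train_default_config_output.json") as f:
    data_flow = json.load(f)

print(data_flow["train_patch_loader_length"], data_flow["n_classes"])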
@@ -322,7 +322,8 @@ def log_validation_results(engine):

     logger.info("Starting training")
     trainer.run(train_loader, max_epochs=config.TRAIN.END_EPOCH, epoch_length=len(train_loader), seed=config.SEED)
-    summary_writer.close()
+    if local_rank == 0:
+        summary_writer.close()


 if __name__ == "__main__":
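A slightly stricter variant of the same shutdown, shown only as a sketch and not part of this PR: synchronise all ranks with a barrier before rank 0 closes its writer, so that every rank has reached the end of training before the event files are flushed and closed.

import torch.distributed as dist


def shutdown(summary_writer, local_rank):
    # wait for every rank to reach this point before tearing anything down
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    if local_rank == 0 and summary_writer is not None:
        summary_writer.close()  # same rank-0 guard as the diff above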