diff --git a/auto3dseg/algorithm_templates/dints/configs/hyper_parameters.yaml b/auto3dseg/algorithm_templates/dints/configs/hyper_parameters.yaml index 1cbe0f0d..a96cd34b 100644 --- a/auto3dseg/algorithm_templates/dints/configs/hyper_parameters.yaml +++ b/auto3dseg/algorithm_templates/dints/configs/hyper_parameters.yaml @@ -49,10 +49,10 @@ training: momentum: 0.9 weight_decay: 4.0e-05 lr_scheduler: - _target_: torch.optim.lr_scheduler.StepLR + _target_: torch.optim.lr_scheduler.PolynomialLR optimizer: "$@training#optimizer" - step_size: "$max(@training#num_epochs // 5, 1)" - gamma: 0.5 + power: 0.5, + total_iters: '$@training#num_epochs // @training#num_epochs_per_validation + 1' # fine-tuning finetune: diff --git a/auto3dseg/algorithm_templates/dints/scripts/search.py b/auto3dseg/algorithm_templates/dints/scripts/search.py index 7519c94f..3e2d2cac 100644 --- a/auto3dseg/algorithm_templates/dints/scripts/search.py +++ b/auto3dseg/algorithm_templates/dints/scripts/search.py @@ -445,7 +445,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): "train_loss_arch", loss.item(), epoch_len * epoch + step ) - lr_scheduler.step() + lr_scheduler.step() if torch.cuda.device_count() > 1: dist.barrier() diff --git a/auto3dseg/algorithm_templates/dints/scripts/train.py b/auto3dseg/algorithm_templates/dints/scripts/train.py index 9713e508..c28cef5a 100644 --- a/auto3dseg/algorithm_templates/dints/scripts/train.py +++ b/auto3dseg/algorithm_templates/dints/scripts/train.py @@ -143,9 +143,10 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): validate_cache_rate = float(parser.get_parsed_content("validate_cache_rate")) train_ds = monai.data.CacheDataset( - data=train_files, + data=train_files * num_epochs_per_validation, transform=train_transforms, cache_rate=train_cache_rate, + hash_as_key=True, num_workers=parser.get_parsed_content("training#num_cache_workers"), progress=False, ) @@ -153,6 +154,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): data=val_files, transform=val_transforms, cache_rate=validate_cache_rate, + hash_as_key=True, num_workers=parser.get_parsed_content("training#num_cache_workers"), progress=False, ) @@ -253,11 +255,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): f.write("epoch\tmetric\tloss\tlr\ttime\titer\n") start_time = time.time() - for epoch in range(num_epochs): + + num_rounds = int(np.ceil(float(num_epochs) // float(num_epochs_per_validation))) + for _round in range(num_rounds): + epoch = (_round + 1) * num_epochs_per_validation lr = lr_scheduler.get_last_lr()[0] if torch.cuda.device_count() == 1 or dist.get_rank() == 0: print("-" * 10) - print(f"epoch {epoch + 1}/{num_epochs}") + print(f"epoch {_round * num_epochs_per_validation + 1}/{num_epochs}") print(f"learning rate is set to {lr}") model.train() @@ -303,9 +308,9 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): f"[{str(datetime.now())[:19]}] " + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}" ) - writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step) + writer.add_scalar("Loss/train", loss.item(), epoch_len * num_rounds + step) - lr_scheduler.step() + lr_scheduler.step() if torch.cuda.device_count() > 1: dist.barrier() @@ -315,144 +320,143 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): if torch.cuda.device_count() == 1 or dist.get_rank() == 0: loss_torch_epoch = loss_torch[0] / loss_torch[1] print( - f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, " + f"epoch {epoch} average loss: {loss_torch_epoch:.4f}, " f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}" ) - if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs: - torch.cuda.empty_cache() - model.eval() - with torch.no_grad(): - metric = torch.zeros(metric_dim * 2, dtype=torch.float, device=device) - metric_sum = 0.0 - metric_count = 0 - metric_mat = [] - val_images = None - val_labels = None - val_outputs = None - - _index = 0 - for val_data in val_loader: - val_images = ( - val_data["image"].to(device) - if sw_input_on_cpu is False - else val_data["image"] - ) - val_labels = ( - val_data["label"].to(device) - if sw_input_on_cpu is False - else val_data["label"] + torch.cuda.empty_cache() + model.eval() + with torch.no_grad(): + metric = torch.zeros(metric_dim * 2, dtype=torch.float, device=device) + metric_sum = 0.0 + metric_count = 0 + metric_mat = [] + val_images = None + val_labels = None + val_outputs = None + + _index = 0 + for val_data in val_loader: + val_images = ( + val_data["image"].to(device) + if sw_input_on_cpu is False + else val_data["image"] + ) + val_labels = ( + val_data["label"].to(device) + if sw_input_on_cpu is False + else val_data["label"] + ) + + with torch.cuda.amp.autocast(enabled=amp): + val_outputs = sliding_window_inference( + val_images, + patch_size_valid, + num_sw_batch_size, + model, + mode="gaussian", + overlap=overlap_ratio, + sw_device=device, ) - with torch.cuda.amp.autocast(enabled=amp): - val_outputs = sliding_window_inference( - val_images, - patch_size_valid, - num_sw_batch_size, - model, - mode="gaussian", - overlap=overlap_ratio, - sw_device=device, - ) + val_outputs = post_pred(val_outputs[0, ...]) + val_outputs = val_outputs[None, ...] - val_outputs = post_pred(val_outputs[0, ...]) - val_outputs = val_outputs[None, ...] + if softmax: + val_labels = post_label(val_labels[0, ...]) + val_labels = val_labels[None, ...] - if softmax: - val_labels = post_label(val_labels[0, ...]) - val_labels = val_labels[None, ...] + value = compute_dice( + y_pred=val_outputs, y=val_labels, include_background=not softmax + ) - value = compute_dice( - y_pred=val_outputs, y=val_labels, include_background=not softmax - ) + print(_index + 1, "/", len(val_loader), value) - print(_index + 1, "/", len(val_loader), value) + metric_count += len(value) + metric_sum += value.sum().item() + metric_vals = value.cpu().numpy() + if len(metric_mat) == 0: + metric_mat = metric_vals + else: + metric_mat = np.concatenate((metric_mat, metric_vals), axis=0) - metric_count += len(value) - metric_sum += value.sum().item() - metric_vals = value.cpu().numpy() - if len(metric_mat) == 0: - metric_mat = metric_vals - else: - metric_mat = np.concatenate((metric_mat, metric_vals), axis=0) - - for _c in range(metric_dim): - val0 = torch.nan_to_num(value[0, _c], nan=0.0) - val1 = 1.0 - torch.isnan(value[0, 0]).float() - metric[2 * _c] += val0 * val1 - metric[2 * _c + 1] += val1 - - _index += 1 - - if torch.cuda.device_count() > 1: - dist.barrier() - dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) - - metric = metric.tolist() - if torch.cuda.device_count() == 1 or dist.get_rank() == 0: - for _c in range(metric_dim): - print( - f"evaluation metric - class {_c + 1:d}:", - metric[2 * _c] / metric[2 * _c + 1], - ) - avg_metric = 0 - for _c in range(metric_dim): - avg_metric += metric[2 * _c] / metric[2 * _c + 1] - avg_metric = avg_metric / float(metric_dim) - print("avg_metric", avg_metric) - - writer.add_scalar("Accuracy/validation", avg_metric, epoch) - - if avg_metric > best_metric: - best_metric = avg_metric - best_metric_epoch = epoch + 1 - if torch.cuda.device_count() > 1: - torch.save( - model.module.state_dict(), - os.path.join(ckpt_path, "best_metric_model.pt"), - ) - else: - torch.save( - model.state_dict(), - os.path.join(ckpt_path, "best_metric_model.pt"), - ) - print("saved new best metric model") - - dict_file = {} - dict_file["best_avg_dice_score"] = float(best_metric) - dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch) - dict_file["best_avg_dice_score_iteration"] = int(idx_iter) - with open( - os.path.join(ckpt_path, "progress.yaml"), "a" - ) as out_file: - yaml.dump([dict_file], stream=out_file) + for _c in range(metric_dim): + val0 = torch.nan_to_num(value[0, _c], nan=0.0) + val1 = 1.0 - torch.isnan(value[0, 0]).float() + metric[2 * _c] += val0 * val1 + metric[2 * _c + 1] += val1 + + _index += 1 + + if torch.cuda.device_count() > 1: + dist.barrier() + dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) + metric = metric.tolist() + if torch.cuda.device_count() == 1 or dist.get_rank() == 0: + for _c in range(metric_dim): print( - "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( - epoch + 1, avg_metric, best_metric, best_metric_epoch - ) + f"evaluation metric - class {_c + 1:d}:", + metric[2 * _c] / metric[2 * _c + 1], ) + avg_metric = 0 + for _c in range(metric_dim): + avg_metric += metric[2 * _c] / metric[2 * _c + 1] + avg_metric = avg_metric / float(metric_dim) + print("avg_metric", avg_metric) + + writer.add_scalar("Accuracy/validation", avg_metric, epoch) + + if avg_metric > best_metric: + best_metric = avg_metric + best_metric_epoch = epoch + if torch.cuda.device_count() > 1: + torch.save( + model.module.state_dict(), + os.path.join(ckpt_path, "best_metric_model.pt"), + ) + else: + torch.save( + model.state_dict(), + os.path.join(ckpt_path, "best_metric_model.pt"), + ) + print("saved new best metric model") - current_time = time.time() - elapsed_time = (current_time - start_time) / 60.0 + dict_file = {} + dict_file["best_avg_dice_score"] = float(best_metric) + dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch) + dict_file["best_avg_dice_score_iteration"] = int(idx_iter) with open( - os.path.join(ckpt_path, "accuracy_history.csv"), "a" - ) as f: - f.write( - "{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.1f}\t{:d}\n".format( - epoch + 1, - avg_metric, - loss_torch_epoch, - lr, - elapsed_time, - idx_iter, - ) + os.path.join(ckpt_path, "progress.yaml"), "a" + ) as out_file: + yaml.dump([dict_file], stream=out_file) + + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch, avg_metric, best_metric, best_metric_epoch + ) + ) + + current_time = time.time() + elapsed_time = (current_time - start_time) / 60.0 + with open( + os.path.join(ckpt_path, "accuracy_history.csv"), "a" + ) as f: + f.write( + "{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.1f}\t{:d}\n".format( + epoch, + avg_metric, + loss_torch_epoch, + lr, + elapsed_time, + idx_iter, ) + ) - if torch.cuda.device_count() > 1: - dist.barrier() + if torch.cuda.device_count() > 1: + dist.barrier() - torch.cuda.empty_cache() + torch.cuda.empty_cache() if torch.cuda.device_count() == 1 or dist.get_rank() == 0: print( diff --git a/auto3dseg/algorithm_templates/segresnet2d/scripts/train.py b/auto3dseg/algorithm_templates/segresnet2d/scripts/train.py index 124cb901..63e333a2 100644 --- a/auto3dseg/algorithm_templates/segresnet2d/scripts/train.py +++ b/auto3dseg/algorithm_templates/segresnet2d/scripts/train.py @@ -305,7 +305,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): ) writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step) - lr_scheduler.step() + lr_scheduler.step() if torch.cuda.device_count() > 1: dist.barrier() diff --git a/auto3dseg/algorithm_templates/swinunetr/scripts/train.py b/auto3dseg/algorithm_templates/swinunetr/scripts/train.py index bcd52f6b..1f748322 100644 --- a/auto3dseg/algorithm_templates/swinunetr/scripts/train.py +++ b/auto3dseg/algorithm_templates/swinunetr/scripts/train.py @@ -324,7 +324,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): ) writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step) - lr_scheduler.step() + lr_scheduler.step() if torch.cuda.device_count() > 1: dist.barrier()