Update auto3dseg algorithm templates #184

Merged
Changes from all commits
@@ -49,10 +49,10 @@ training:
     momentum: 0.9
     weight_decay: 4.0e-05
   lr_scheduler:
-    _target_: torch.optim.lr_scheduler.StepLR
+    _target_: torch.optim.lr_scheduler.PolynomialLR
     optimizer: "$@training#optimizer"
-    step_size: "$max(@training#num_epochs // 5, 1)"
-    gamma: 0.5
+    power: 0.5
+    total_iters: '$@training#num_epochs // @training#num_epochs_per_validation + 1'
 
 # fine-tuning
 finetune:
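Note on the scheduler swap: PolynomialLR (available in torch.optim.lr_scheduler since PyTorch 1.13) decays the learning rate to zero over exactly total_iters calls to step(), following lr(t) = base_lr * (1 - t / total_iters) ** power. A minimal sketch of the schedule the new config describes, with hypothetical numbers standing in for the parsed @training# references:

import torch
from torch.optim.lr_scheduler import PolynomialLR

# Hypothetical stand-ins for @training#num_epochs and
# @training#num_epochs_per_validation from the config above.
num_epochs, num_epochs_per_validation = 200, 10
total_iters = num_epochs // num_epochs_per_validation + 1  # 21

optimizer = torch.optim.SGD(
    [torch.nn.Parameter(torch.zeros(1))], lr=0.2, momentum=0.9
)
scheduler = PolynomialLR(optimizer, total_iters=total_iters, power=0.5)

for _ in range(total_iters):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.0]: decayed smoothly from 0.2 over 21 ticks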
auto3dseg/algorithm_templates/dints/scripts/search.py (1 addition & 1 deletion)
@@ -445,7 +445,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
                         "train_loss_arch", loss.item(), epoch_len * epoch + step
                     )
 
-                lr_scheduler.step()
+            lr_scheduler.step()
 
         if torch.cuda.device_count() > 1:
             dist.barrier()
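Note: the re-indent above changes how often lr_scheduler.step() runs relative to the surrounding loops. With PolynomialLR that frequency has to match the configured total_iters, since the schedule is exhausted after exactly that many calls. A sketch of the failure mode with hypothetical numbers, not the template's code:

import torch
from torch.optim.lr_scheduler import PolynomialLR

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = PolynomialLR(optimizer, total_iters=5, power=0.5)

# Stepping more often than planned (e.g. per batch instead of per
# epoch or per round) exhausts the schedule early and pins lr at 0.
for _ in range(20):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.0] already after the first 5 calls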
auto3dseg/algorithm_templates/dints/scripts/train.py (129 additions & 125 deletions)
@@ -143,16 +143,18 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     validate_cache_rate = float(parser.get_parsed_content("validate_cache_rate"))
 
     train_ds = monai.data.CacheDataset(
-        data=train_files,
+        data=train_files * num_epochs_per_validation,
         transform=train_transforms,
         cache_rate=train_cache_rate,
+        hash_as_key=True,
         num_workers=parser.get_parsed_content("training#num_cache_workers"),
         progress=False,
     )
     val_ds = monai.data.CacheDataset(
         data=val_files,
         transform=val_transforms,
         cache_rate=validate_cache_rate,
+        hash_as_key=True,
         num_workers=parser.get_parsed_content("training#num_cache_workers"),
         progress=False,
     )
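Note on the two additions: multiplying the datalist makes one pass over train_ds cover num_epochs_per_validation real epochs, and hash_as_key=True keeps memory flat because CacheDataset hashes each item and caches duplicated entries only once. A small sketch with hypothetical file names:

import monai

# Hypothetical datalist standing in for the parsed training files.
train_files = [
    {"image": "case0_img.nii.gz", "label": "case0_lbl.nii.gz"},
    {"image": "case1_img.nii.gz", "label": "case1_lbl.nii.gz"},
]
num_epochs_per_validation = 5

train_ds = monai.data.CacheDataset(
    data=train_files * num_epochs_per_validation,  # list is 10 entries long...
    transform=None,
    cache_rate=1.0,
    hash_as_key=True,  # ...but only the 2 unique items are cached
    progress=False,
)
print(len(train_ds))  # 10: one "epoch" now covers the data 5 times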
@@ -253,11 +255,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
             f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")
 
     start_time = time.time()
-    for epoch in range(num_epochs):
+
+    num_rounds = int(np.ceil(float(num_epochs) / float(num_epochs_per_validation)))
+    for _round in range(num_rounds):
+        epoch = (_round + 1) * num_epochs_per_validation
         lr = lr_scheduler.get_last_lr()[0]
         if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
             print("-" * 10)
-            print(f"epoch {epoch + 1}/{num_epochs}")
+            print(f"epoch {_round * num_epochs_per_validation + 1}/{num_epochs}")
             print(f"learning rate is set to {lr}")
 
         model.train()
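Note on the loop restructuring: each _round now consumes num_epochs_per_validation epochs' worth of data in a single pass over the multiplied dataset, and epoch is only bookkeeping. A sketch of the arithmetic with hypothetical values:

import numpy as np

# Hypothetical config values.
num_epochs = 100
num_epochs_per_validation = 8

num_rounds = int(np.ceil(float(num_epochs) / float(num_epochs_per_validation)))  # 13
for _round in range(num_rounds):
    first = _round * num_epochs_per_validation + 1    # epoch printed at round start
    epoch = (_round + 1) * num_epochs_per_validation  # epoch reached after the round
    print(f"round {_round}: epochs {first}..{epoch} of {num_epochs}")
# The last round reports epoch 104, slightly past num_epochs, whenever
# num_epochs is not a multiple of num_epochs_per_validation.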
@@ -303,9 +308,9 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
                     f"[{str(datetime.now())[:19]}] "
                     + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                 )
-                writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step)
+                writer.add_scalar("Loss/train", loss.item(), epoch_len * _round + step)
 
-            lr_scheduler.step()
+        lr_scheduler.step()
 
         if torch.cuda.device_count() > 1:
             dist.barrier()
@@ -315,144 +320,143 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
             loss_torch_epoch = loss_torch[0] / loss_torch[1]
             print(
-                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
+                f"epoch {epoch} average loss: {loss_torch_epoch:.4f}, "
                 f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
             )
 
-        if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
-            torch.cuda.empty_cache()
-            model.eval()
-            with torch.no_grad():
-                metric = torch.zeros(metric_dim * 2, dtype=torch.float, device=device)
-                metric_sum = 0.0
-                metric_count = 0
-                metric_mat = []
-                val_images = None
-                val_labels = None
-                val_outputs = None
+        torch.cuda.empty_cache()
+        model.eval()
+        with torch.no_grad():
+            metric = torch.zeros(metric_dim * 2, dtype=torch.float, device=device)
+            metric_sum = 0.0
+            metric_count = 0
+            metric_mat = []
+            val_images = None
+            val_labels = None
+            val_outputs = None
 
-                _index = 0
-                for val_data in val_loader:
-                    val_images = (
-                        val_data["image"].to(device)
-                        if sw_input_on_cpu is False
-                        else val_data["image"]
-                    )
-                    val_labels = (
-                        val_data["label"].to(device)
-                        if sw_input_on_cpu is False
-                        else val_data["label"]
-                    )
+            _index = 0
+            for val_data in val_loader:
+                val_images = (
+                    val_data["image"].to(device)
+                    if sw_input_on_cpu is False
+                    else val_data["image"]
+                )
+                val_labels = (
+                    val_data["label"].to(device)
+                    if sw_input_on_cpu is False
+                    else val_data["label"]
+                )
 
-                    with torch.cuda.amp.autocast(enabled=amp):
-                        val_outputs = sliding_window_inference(
-                            val_images,
-                            patch_size_valid,
-                            num_sw_batch_size,
-                            model,
-                            mode="gaussian",
-                            overlap=overlap_ratio,
-                            sw_device=device,
-                        )
+                with torch.cuda.amp.autocast(enabled=amp):
+                    val_outputs = sliding_window_inference(
+                        val_images,
+                        patch_size_valid,
+                        num_sw_batch_size,
+                        model,
+                        mode="gaussian",
+                        overlap=overlap_ratio,
+                        sw_device=device,
+                    )
 
-                    val_outputs = post_pred(val_outputs[0, ...])
-                    val_outputs = val_outputs[None, ...]
+                val_outputs = post_pred(val_outputs[0, ...])
+                val_outputs = val_outputs[None, ...]
 
-                    if softmax:
-                        val_labels = post_label(val_labels[0, ...])
-                        val_labels = val_labels[None, ...]
+                if softmax:
+                    val_labels = post_label(val_labels[0, ...])
+                    val_labels = val_labels[None, ...]
 
-                    value = compute_dice(
-                        y_pred=val_outputs, y=val_labels, include_background=not softmax
-                    )
+                value = compute_dice(
+                    y_pred=val_outputs, y=val_labels, include_background=not softmax
+                )
 
-                    print(_index + 1, "/", len(val_loader), value)
+                print(_index + 1, "/", len(val_loader), value)
 
-                    metric_count += len(value)
-                    metric_sum += value.sum().item()
-                    metric_vals = value.cpu().numpy()
-                    if len(metric_mat) == 0:
-                        metric_mat = metric_vals
-                    else:
-                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)
+                metric_count += len(value)
+                metric_sum += value.sum().item()
+                metric_vals = value.cpu().numpy()
+                if len(metric_mat) == 0:
+                    metric_mat = metric_vals
+                else:
+                    metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)
 
-                    for _c in range(metric_dim):
-                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
-                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
-                        metric[2 * _c] += val0 * val1
-                        metric[2 * _c + 1] += val1
+                for _c in range(metric_dim):
+                    val0 = torch.nan_to_num(value[0, _c], nan=0.0)
+                    val1 = 1.0 - torch.isnan(value[0, 0]).float()
+                    metric[2 * _c] += val0 * val1
+                    metric[2 * _c + 1] += val1
 
-                    _index += 1
+                _index += 1
 
-                if torch.cuda.device_count() > 1:
-                    dist.barrier()
-                    dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
+            if torch.cuda.device_count() > 1:
+                dist.barrier()
+                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
 
-                metric = metric.tolist()
-                if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
-                    for _c in range(metric_dim):
-                        print(
-                            f"evaluation metric - class {_c + 1:d}:",
-                            metric[2 * _c] / metric[2 * _c + 1],
-                        )
-                    avg_metric = 0
-                    for _c in range(metric_dim):
-                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
-                    avg_metric = avg_metric / float(metric_dim)
-                    print("avg_metric", avg_metric)
+            metric = metric.tolist()
+            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+                for _c in range(metric_dim):
+                    print(
+                        f"evaluation metric - class {_c + 1:d}:",
+                        metric[2 * _c] / metric[2 * _c + 1],
+                    )
+                avg_metric = 0
+                for _c in range(metric_dim):
+                    avg_metric += metric[2 * _c] / metric[2 * _c + 1]
+                avg_metric = avg_metric / float(metric_dim)
+                print("avg_metric", avg_metric)
 
-                    writer.add_scalar("Accuracy/validation", avg_metric, epoch)
+                writer.add_scalar("Accuracy/validation", avg_metric, epoch)
 
-                    if avg_metric > best_metric:
-                        best_metric = avg_metric
-                        best_metric_epoch = epoch + 1
-                        if torch.cuda.device_count() > 1:
-                            torch.save(
-                                model.module.state_dict(),
-                                os.path.join(ckpt_path, "best_metric_model.pt"),
-                            )
-                        else:
-                            torch.save(
-                                model.state_dict(),
-                                os.path.join(ckpt_path, "best_metric_model.pt"),
-                            )
-                        print("saved new best metric model")
+                if avg_metric > best_metric:
+                    best_metric = avg_metric
+                    best_metric_epoch = epoch
+                    if torch.cuda.device_count() > 1:
+                        torch.save(
+                            model.module.state_dict(),
+                            os.path.join(ckpt_path, "best_metric_model.pt"),
+                        )
+                    else:
+                        torch.save(
+                            model.state_dict(),
+                            os.path.join(ckpt_path, "best_metric_model.pt"),
+                        )
+                    print("saved new best metric model")
 
-                        dict_file = {}
-                        dict_file["best_avg_dice_score"] = float(best_metric)
-                        dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
-                        dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
-                        with open(
-                            os.path.join(ckpt_path, "progress.yaml"), "a"
-                        ) as out_file:
-                            yaml.dump([dict_file], stream=out_file)
+                    dict_file = {}
+                    dict_file["best_avg_dice_score"] = float(best_metric)
+                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
+                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
+                    with open(
+                        os.path.join(ckpt_path, "progress.yaml"), "a"
+                    ) as out_file:
+                        yaml.dump([dict_file], stream=out_file)
 
-                    print(
-                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
-                            epoch + 1, avg_metric, best_metric, best_metric_epoch
-                        )
-                    )
+                print(
+                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
+                        epoch, avg_metric, best_metric, best_metric_epoch
+                    )
+                )
 
-                    current_time = time.time()
-                    elapsed_time = (current_time - start_time) / 60.0
-                    with open(
-                        os.path.join(ckpt_path, "accuracy_history.csv"), "a"
-                    ) as f:
-                        f.write(
-                            "{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.1f}\t{:d}\n".format(
-                                epoch + 1,
-                                avg_metric,
-                                loss_torch_epoch,
-                                lr,
-                                elapsed_time,
-                                idx_iter,
-                            )
-                        )
+                current_time = time.time()
+                elapsed_time = (current_time - start_time) / 60.0
+                with open(
+                    os.path.join(ckpt_path, "accuracy_history.csv"), "a"
+                ) as f:
+                    f.write(
+                        "{:d}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.1f}\t{:d}\n".format(
+                            epoch,
+                            avg_metric,
+                            loss_torch_epoch,
+                            lr,
+                            elapsed_time,
+                            idx_iter,
+                        )
+                    )
 
-                if torch.cuda.device_count() > 1:
-                    dist.barrier()
+            if torch.cuda.device_count() > 1:
+                dist.barrier()
 
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
     if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
         print(
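Net effect of the hunk above: the validation block loses its val_interval guard and one indent level, because validation and checkpointing now run after every round. A control-flow sketch with hypothetical helper names summarizing the before and after:

def train_one_round():          # stand-in for the per-round training loop
    ...

def validate_and_checkpoint():  # stand-in for the validation block above
    ...

num_rounds = 13  # hypothetical

# Before: per-epoch loop, validation gated on val_interval:
#     for epoch in range(num_epochs):
#         train_one_epoch()
#         if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
#             validate_and_checkpoint()

# After: every round trains num_epochs_per_validation epochs' worth of
# data (via the multiplied CacheDataset), then always validates.
for _round in range(num_rounds):
    train_one_round()
    validate_and_checkpoint()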
auto3dseg/algorithm_templates/segresnet2d/scripts/train.py (1 addition & 1 deletion)
@@ -305,7 +305,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
                     )
                     writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step)
 
-                lr_scheduler.step()
+            lr_scheduler.step()
 
         if torch.cuda.device_count() > 1:
             dist.barrier()
auto3dseg/algorithm_templates/swinunetr/scripts/train.py (1 addition & 1 deletion)
@@ -324,7 +324,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
                     )
                     writer.add_scalar("Loss/train", loss.item(), epoch_len * epoch + step)
 
-                lr_scheduler.step()
+            lr_scheduler.step()
 
         if torch.cuda.device_count() > 1:
             dist.barrier()