Merge pull request #571 from ACEsuit/main
Backport the log changes
ilyes319 committed Aug 27, 2024
2 parents b542c35 + 2a5ebec commit e78ae91
Showing 11 changed files with 400 additions and 167 deletions.
3 changes: 2 additions & 1 deletion mace/cli/preprocess_data.py
@@ -155,6 +155,7 @@ def run(args: argparse.Namespace):

# Data preparation
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
@@ -211,7 +212,7 @@ def run(args: argparse.Namespace):
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
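For context, the average-neighbor statistic logged here can be reproduced outside of `pool_compute_stats` (whose internals this diff does not show) with ASE's neighbor list; the sketch below is illustrative only.

```python
# Hedged sketch: average number of neighbors within r_max over a dataset of
# ASE Atoms objects. A stand-in for the quantity logged above, not MACE's code.
import numpy as np
from ase.neighborlist import neighbor_list

def average_num_neighbors(atoms_list, r_max: float) -> float:
    per_atom_counts = []
    for atoms in atoms_list:
        senders = neighbor_list("i", atoms, r_max)  # one sender index per edge
        per_atom_counts.append(np.bincount(senders, minlength=len(atoms)))
    return float(np.concatenate(per_atom_counts).mean())
```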
163 changes: 113 additions & 50 deletions mace/cli/run_train.py
@@ -55,6 +55,7 @@ def run(args: argparse.Namespace) -> None:
"""
This script runs the training/fine-tuning for MACE
"""
args, input_log_messages = tools.check_args(args)
tag = tools.get_tag(name=args.name, seed=args.seed)
if args.distributed:
try:
@@ -74,6 +75,9 @@ def run(args: argparse.Namespace) -> None:
# Setup
tools.set_seeds(args.seed)
tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank)
logging.info("===========VERIFYING SETTINGS===========")
for message, loglevel in input_log_messages:
logging.log(level=loglevel, msg=message)

if args.distributed:
torch.cuda.set_device(local_rank)
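The new `check_args` call and replay loop implement deferred logging: validation runs before `setup_logger` has configured any handlers, so messages are collected as `(message, level)` pairs and emitted afterwards via `logging.log`. A minimal sketch of the pattern (this `check_args_sketch` is illustrative, not MACE's implementation):

```python
import logging

def check_args_sketch(args):
    """Validate args, returning log messages instead of emitting them."""
    messages = []
    if args.batch_size <= 0:
        messages.append(("batch_size must be a positive integer", logging.ERROR))
    messages.append((f"Random seed: {args.seed}", logging.INFO))
    return args, messages

# After logging is configured, replay exactly as run() does:
# for message, loglevel in messages:
#     logging.log(level=loglevel, msg=message)
```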
@@ -84,7 +88,7 @@ def run(args: argparse.Namespace) -> None:
logging.info(f"MACE version: {mace.__version__}")
except AttributeError:
logging.info("Cannot find MACE version, please install MACE via pip")
logging.info(f"Configuration: {args}")
logging.debug(f"Configuration: {args}")

tools.set_default_dtype(args.default_dtype)
device = tools.init_device(args.device)
@@ -132,6 +136,8 @@ def run(args: argparse.Namespace) -> None:
args.compute_avg_num_neighbors = False
args.E0s = statistics["atomic_energies"]

logging.info("")
logging.info("===========LOADING INPUT DATA===========")
# Data preparation
if args.train_file.endswith(".xyz"):
if args.valid_file is not None:
@@ -140,6 +146,7 @@ def run(args: argparse.Namespace) -> None:
), "valid_file if given must be same format as train_file"
config_type_weights = get_config_type_weights(args.config_type_weights)
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
@@ -154,11 +161,16 @@ def run(args: argparse.Namespace) -> None:
charges_key=args.charges_key,
keep_isolated_atoms=args.keep_isolated_atoms,
)
if len(collections.train) < args.batch_size:
logging.error(
f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})"
)
if len(collections.valid) < args.valid_batch_size:
logging.warning(
f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})"
)
args.valid_batch_size = len(collections.valid)

logging.info(
f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, "
f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]"
)
else:
atomic_energies_dict = None
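Condensed, the new size guards behave as sketched below (function name and signature assumed for illustration): a hard error when the training set cannot fill one batch, and a soft clamp for validation.

```python
import logging

def clamp_valid_batch_size(n_train, n_valid, batch_size, valid_batch_size):
    # Error: a single training batch cannot be filled.
    if n_train < batch_size:
        logging.error(
            f"Batch size ({batch_size}) is larger than the number of training data ({n_train})"
        )
    # Warning: shrink the validation batch size instead of failing.
    if n_valid < valid_batch_size:
        logging.warning(
            f"Validation batch size ({valid_batch_size}) is larger than the number of validation data ({n_valid})"
        )
        valid_batch_size = n_valid
    return valid_batch_size
```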

@@ -181,12 +193,11 @@ def run(args: argparse.Namespace) -> None:
assert isinstance(zs_list, list)
z_table = tools.get_atomic_number_table_from_zs(zs_list)
# yapf: enable
logging.info(z_table)
logging.info(f"Atomic Numbers used: {z_table.zs}")

if atomic_energies_dict is None or len(atomic_energies_dict) == 0:
if args.E0s.lower() == "foundation":
assert args.foundation_model is not None
logging.info("Using atomic energies from foundation model")
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
@@ -196,6 +207,9 @@ def run(args: argparse.Namespace) -> None:
].item()
for z in z_table.zs
}
logging.info(
f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}"
)
else:
if args.train_file.endswith(".xyz"):
atomic_energies_dict = get_atomic_energies(
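The dictionary comprehension above pairs each element in the run's `z_table` with the E0 stored by the foundation model, going through the foundation model's own index table. A self-contained illustration (this `AtomicNumberTableSketch` is a stand-in for MACE's `AtomicNumberTable`; the energies are made up):

```python
import torch

class AtomicNumberTableSketch:
    """Minimal stand-in for MACE's AtomicNumberTable."""
    def __init__(self, zs):
        self.zs = sorted(set(zs))
    def z_to_index(self, z: int) -> int:
        return self.zs.index(z)

foundation_e0s = torch.tensor([-13.6, -2041.8])  # hypothetical E0s in eV (H, O)
z_table_foundation = AtomicNumberTableSketch([1, 8])
atomic_energies_dict = {
    z: foundation_e0s[z_table_foundation.z_to_index(z)].item() for z in [1, 8]
}
```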
@@ -226,7 +240,9 @@ def run(args: argparse.Namespace) -> None:
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
logging.info(
f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}"
)
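The triple braces in that f-string are intentional: `{{` and `}}` escape to literal braces around the interpolated join. A quick demonstration with made-up values:

```python
atomic_energies_dict = {1: -13.6, 8: -2041.8}  # illustrative values, eV
print(
    f"Atomic Energies used (z: eV): "
    f"{{{', '.join(f'{z}: {e}' for z, e in atomic_energies_dict.items())}}}"
)
# Atomic Energies used (z: eV): {1: -13.6, 8: -2041.8}
```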

if args.train_file.endswith(".xyz"):
train_set = [
@@ -286,7 +302,8 @@ def run(args: argparse.Namespace) -> None:
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
)

logging.info("")
logging.info("===========MODEL DETAILS===========")
if args.loss == "weighted":
loss_fn = modules.WeightedEnergyForcesLoss(
energy_weight=args.energy_weight, forces_weight=args.forces_weight
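All branches of this selection build a loss of the same shape: weighted mean-squared errors per target. A hedged sketch of the energy/forces case (MACE's `modules.WeightedEnergyForcesLoss` is the real implementation; this stand-in omits per-configuration weights and batch bookkeeping):

```python
import torch

def weighted_energy_forces_loss(pred, ref, energy_weight=1.0, forces_weight=100.0):
    energy_term = torch.mean((pred["energy"] - ref["energy"]) ** 2)
    forces_term = torch.mean((pred["forces"] - ref["forces"]) ** 2)
    return energy_weight * energy_term + forces_weight * forces_term
```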
@@ -336,7 +353,6 @@ def run(args: argparse.Namespace) -> None:
else:
# Unweighted Energy and Forces loss by default
loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0)
logging.info(loss_fn)

if args.compute_avg_num_neighbors:
avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
@@ -350,7 +366,12 @@ def run(args: argparse.Namespace) -> None:
args.avg_num_neighbors = (num_neighbors / num_graphs).item()
else:
args.avg_num_neighbors = avg_num_neighbors
logging.info(f"Average number of neighbors: {args.avg_num_neighbors}")
if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100:
logging.warning(
f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}"
)
else:
logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}")

# Selecting outputs
compute_virials = False
@@ -369,7 +390,10 @@ def run(args: argparse.Namespace) -> None:
"stress": args.compute_stress,
"dipoles": compute_dipole,
}
logging.info(f"Selected the following outputs: {output_args}")

logging.info(
f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}"
)
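That report line simply joins the keys of `output_args` whose flags are true; for example:

```python
output_args = {"energy": True, "forces": True, "virials": False,
               "stress": False, "dipoles": False}
print(", ".join(report for report, value in output_args.items() if value))
# energy, forces
```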

if args.scaling == "no_scaling":
args.std = 1.0
@@ -380,11 +404,14 @@ def run(args: argparse.Namespace) -> None:
)
# Build model
if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]:
logging.info("Building model")
logging.info("Loading FOUNDATION model")
model_config_foundation = extract_config_mace_model(model_foundation)
model_config_foundation["atomic_numbers"] = z_table.zs
model_config_foundation["num_elements"] = len(z_table)
args.max_L = model_config_foundation["hidden_irreps"].lmax
args.num_channels = list(
{irrep.mul for irrep in o3.Irreps(model_config_foundation["hidden_irreps"])}
)[0]
model_config_foundation["atomic_inter_shift"] = (
model_foundation.scale_shift.shift.item()
)
@@ -394,23 +421,35 @@ def run(args: argparse.Namespace) -> None:
model_config_foundation["atomic_energies"] = atomic_energies
args.model = "FoundationMACE"
model_config = model_config_foundation # pylint
logging.info(
f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({model_config_foundation['hidden_irreps']})"
)
logging.info(
f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}"
)
logging.info(
f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)"
)
logging.info(
f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}"
)
else:
logging.info("Building model")
if args.num_channels is not None and args.max_L is not None:
assert args.num_channels > 0, "num_channels must be positive integer"
assert args.max_L >= 0, "max_L must be non-negative integer"
args.hidden_irreps = o3.Irreps(
(args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
.sort()
.irreps.simplify()
)

assert (
len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"

logging.info(f"Hidden irreps: {args.hidden_irreps}")

logging.info(
f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})"
)
logging.info(
f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}"
)
logging.info(
f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions"
)
logging.info(
f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)"
)
logging.info(
f"Distance transform for radial basis functions: {args.distance_transform}"
)
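Both branches rely on the same e3nn construction: repeat the spherical-harmonic irreps `num_channels` times, sort, and simplify. That guarantees a single multiplicity per L, which is also how the foundation branch recovers `num_channels` from the loaded model. Standalone, with illustrative numbers:

```python
from e3nn import o3

num_channels, max_L = 128, 2
hidden_irreps = o3.Irreps(
    (num_channels * o3.Irreps.spherical_harmonics(max_L)).sort().irreps.simplify()
)
print(hidden_irreps)  # 128x0e+128x1o+128x2e
# All L-channels share one multiplicity, so num_channels can be read back:
assert {irrep.mul for irrep in hidden_irreps} == {num_channels}
```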
model_config = dict(
r_max=args.r_max,
num_bessel=args.num_radial_basis,
@@ -522,6 +561,20 @@ def run(args: argparse.Namespace) -> None:
)
model.to(device)

logging.debug(model)
logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
logging.info("")
logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}")
if args.ema:
logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
logging.info(
f"Number of gradient updates: {int(args.max_num_epochs*len(collections.train)/args.batch_size)}"
)
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn)
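Two of the logged quantities are worth unpacking. The number of gradient updates is `int(max_num_epochs * len(train) / batch_size)`, e.g. 100 epochs over 5000 configurations at batch size 32 gives 15625 updates. The `ema_decay` line refers to the standard exponential-moving-average rule for parameters, sketched below (a stand-in for illustration, not MACE's exact EMA code):

```python
import torch

@torch.no_grad()
def ema_update(shadow_params, model_params, decay=0.99):
    # shadow <- decay * shadow + (1 - decay) * param
    for shadow, param in zip(shadow_params, model_params):
        shadow.mul_(decay).add_(param, alpha=1.0 - decay)
```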

# Optimizer
decay_interactions = {}
no_decay_interactions = {}
@@ -592,13 +645,9 @@ def run(args: argparse.Namespace) -> None:
swas.append(True)
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
else:
if args.start_swa > args.max_num_epochs:
logging.info(
f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
)
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
logging.info(f"Setting start Stage Two to {args.start_swa}")
logging.info(
f"Stage Two will start after {args.start_swa} epochs with loss function:"
)
if args.loss == "forces_only":
raise ValueError("Can not select Stage Two with forces only loss.")
if args.loss == "virials":
Expand All @@ -619,17 +668,12 @@ def run(args: argparse.Namespace) -> None:
forces_weight=args.swa_forces_weight,
dipole_weight=args.swa_dipole_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}"
)
else:
loss_fn_energy = modules.WeightedEnergyForcesLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
)
logging.info(
f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}"
)
logging.info(loss_fn_energy)
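Stage Two swaps the loss and relies on PyTorch's stochastic-weight-averaging utilities, which `tools.SWAContainer` wraps. A minimal, self-contained sketch of that machinery with illustrative values (not MACE's exact wiring):

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
swa_model = AveragedModel(model)             # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=1e-3)

max_num_epochs = 100
start_swa = max(1, max_num_epochs // 4 * 3)  # default shown above: epoch 75
for epoch in range(max_num_epochs):
    # ... regular optimizer steps on the Stage Two loss would go here ...
    if epoch >= start_swa:
        swa_model.update_parameters(model)
        swa_scheduler.step()
```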
swa = tools.SWAContainer(
model=AveragedModel(model),
scheduler=SWALR(
@@ -673,10 +717,6 @@ def run(args: argparse.Namespace) -> None:
for group in optimizer.param_groups:
group["lr"] = args.lr

logging.info(model)
logging.info(f"Number of parameters: {tools.count_parameters(model)}")
logging.info(f"Optimizer: {optimizer}")

if args.wandb:
logging.info("Using Weights and Biases for logging")
import wandb
@@ -726,7 +766,8 @@ def run(args: argparse.Namespace) -> None:
train_sampler=train_sampler,
rank=rank,
)

logging.info("")
logging.info("===========RESULTS===========")
logging.info("Computing metrics for training, validation, and test sets")

all_data_loaders = {
@@ -781,6 +822,13 @@ def run(args: argparse.Namespace) -> None:
)
all_data_loaders[test_name] = test_loader

train_valid_data_loader = {
k: v for k, v in all_data_loaders.items() if k in ["train", "valid"]
}
test_data_loader = {
k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"]
}

for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
@@ -791,21 +839,36 @@ def run(args: argparse.Namespace) -> None:
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
model_to_evaluate = model if not args.distributed else distributed_model
logging.info(f"Loaded model from epoch {epoch}")
if swa_eval:
logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation")
else:
logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation")

for param in model.parameters():
param.requires_grad = False
table = create_error_table(

table_train = create_error_table(
table_type=args.error_table,
all_data_loaders=train_valid_data_loader,
model=model_to_evaluate,
loss_fn=loss_fn,
output_args=output_args,
log_wandb=args.wandb,
device=device,
distributed=args.distributed,
)
table_test = create_error_table(
table_type=args.error_table,
all_data_loaders=all_data_loaders,
all_data_loaders=test_data_loader,
model=model_to_evaluate,
loss_fn=loss_fn,
output_args=output_args,
log_wandb=args.wandb,
device=device,
distributed=args.distributed,
)
logging.info("\n" + str(table))
logging.info("Error-table on TRAIN and VALID:\n" + str(table_train))
logging.info("Error-table on TEST:\n" + str(table_test))

if rank == 0:
# Save entire model