Skip to content

Commit

Permalink
feat(pt/tf): add bias changing param/interface (deepmodeling#3933)
Browse files Browse the repository at this point in the history
Add bias changing param/interface

For pt/tf, add `training/change_bias_after_training` to change out bias
once after training.

For pt, add a separate command `change-bias` to change trained
model(pt/pth, multi/single) out bias for specific data:

```
dp change-bias model.pt -s data -n 10 -m change
```

UTs for this feature are still in consideration.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added a new subcommand `change-bias` to adjust model output bias in
the PyTorch backend.
  - Introduced test cases for changing model biases via new test suite.
  
- **Documentation**
- Added documentation for the new `change-bias` command, including usage
and options.
- Updated `index.rst` to include a new entry for `change-bias` under the
`Model` section.

- **Bug Fixes**
- Adjusted data handling in `make_stat_input` to limit processing to a
specified number of batches.

- **Refactor**
- Restructured training configuration to include the parameter
`change_bias_after_training`.
  - Modularized data requirement handling and bias adjustment functions.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored and mtaillefumier committed Sep 18, 2024
1 parent 313240e commit a119394
Show file tree
Hide file tree
Showing 12 changed files with 534 additions and 84 deletions.
67 changes: 67 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser:
help="treat all types as a single type. Used with se_atten descriptor.",
)

# change_bias
parser_change_bias = subparsers.add_parser(
"change-bias",
parents=[parser_log],
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp change-bias model.pt -s data -n 10 -m change
"""
),
)
parser_change_bias.add_argument(
"INPUT", help="The input checkpoint file or frozen model file"
)
parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group()
parser_change_bias_source.add_argument(
"-s",
"--system",
default=".",
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_change_bias_source.add_argument(
"-b",
"--bias-value",
default=None,
type=float,
nargs="+",
help="The user defined value for each type in the type_map of the model, split with spaces.\n"
"For example, '-93.57 -187.1' for energy bias of two elements. "
"Only supports energy bias changing.",
)
parser_change_bias.add_argument(
"-n",
"--numb-batch",
default=0,
type=int,
help="The number of frames for bias changing in one data system. 0 means all data.",
)
parser_change_bias.add_argument(
"-m",
"--mode",
type=str,
default="change",
choices=["change", "set"],
help="The mode for changing energy bias: \n"
"change (default) : perform predictions using input model on target dataset, "
"and do least square on the errors to obtain the target shift as bias.\n"
"set : directly use the statistic bias in the target dataset.",
)
parser_change_bias.add_argument(
"-o",
"--output",
default=None,
type=str,
help="The model after changing bias.",
)
parser_change_bias.add_argument(
"--model-branch",
type=str,
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)

# --version
parser.add_argument(
"--version", action="version", version=f"DeePMD-kit v{__version__}"
Expand Down Expand Up @@ -831,6 +897,7 @@ def main():
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
Expand Down
137 changes: 137 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
import json
import logging
import os
Expand All @@ -23,6 +24,9 @@
from deepmd import (
__version__,
)
from deepmd.common import (
expand_sys_str,
)
from deepmd.env import (
GLOBAL_CONFIG,
)
Expand All @@ -44,6 +48,9 @@
from deepmd.pt.train import (
training,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt.utils import (
env,
)
Expand All @@ -59,6 +66,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand Down Expand Up @@ -376,6 +389,128 @@ def show(FLAGS):
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
model_params_string = old_model.get_model_def_script()
model_params = json.loads(model_params_string)
old_state_dict = old_model.state_dict()
model_state_dict = old_state_dict
else:
raise RuntimeError(
"The model provided must be a checkpoint file with a .pt extension "
"or a frozen model with a .pth extension"
)
multi_task = "model_dict" in model_params
model_branch = FLAGS.model_branch
bias_adjust_mode = (
"change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
)
if multi_task:
assert (
model_branch is not None
), "For multitask model, the model branch must be set!"
assert model_branch in model_params["model_dict"], (
f"For multitask model, the model branch must be in the 'model_dict'! "
f"Available options are : {list(model_params['model_dict'].keys())}."
)
log.info(f"Changing out bias for model {model_branch}.")
model = training.get_model_for_wrapper(model_params)
type_map = (
model_params["type_map"]
if not multi_task
else model_params["model_dict"][model_branch]["type_map"]
)
model_to_change = model if not multi_task else model[model_branch]
if FLAGS.INPUT.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])
else:
# for .pth
model.load_state_dict(old_state_dict)

if FLAGS.bias_value is not None:
# use user-defined bias
assert model_to_change.model_type in [
"ener"
], "User-defined bias is only available for energy model!"
assert (
len(FLAGS.bias_value) == len(type_map)
), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
old_bias = model_to_change.get_out_bias()
bias_to_set = torch.tensor(
FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
).view(old_bias.shape)
model_to_change.set_out_bias(bias_to_set)
log.info(
f"Change output bias of {type_map!s} "
f"from {to_numpy_array(old_bias).reshape(-1)!s} "
f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
)
updated_model = model_to_change
else:
# calculate bias on given systems
data_systems = process_systems(expand_sys_str(FLAGS.system))
data_single = DpLoaderSet(
data_systems,
1,
type_map,
)
mock_loss = training.get_loss(
{"inference": True}, 1.0, len(type_map), model_to_change
)
data_requirement = mock_loss.label_requirement
data_requirement += training.get_additional_data_requirement(model_to_change)
data_single.add_data_requirement(data_requirement)
nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
sampled_data = make_stat_input(
data_single.systems,
data_single.dataloaders,
nbatches,
)
updated_model = training.model_change_out_bias(
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
)

if not multi_task:
model = updated_model
else:
model[model_branch] = updated_model

if FLAGS.INPUT.endswith(".pt"):
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
torch.save(old_state_dict, output_path)
else:
# for .pth
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pth", "_updated.pth")
)
model = torch.jit.script(model)
torch.jit.save(
model,
output_path,
{},
)
log.info(f"Saved model to {output_path}")


@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
Expand All @@ -400,6 +535,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference
self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def init_out_stat(self):
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.out_bias = out_bias

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def forward_common(
def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.atomic_model.set_out_bias(out_bias)

def change_out_bias(
self,
merged,
Expand Down
Loading

0 comments on commit a119394

Please sign in to comment.