bug: fix spin nlist in spin_model #3718

Merged
merged 8 commits into from
May 8, 2024
19 changes: 12 additions & 7 deletions deepmd/pt/model/model/spin_model.py
@@ -170,15 +170,20 @@ def extend_nlist(extended_atype, nlist):
nlist_shift = nlist + nall
nlist[~nlist_mask] = -1
nlist_shift[~nlist_mask] = -1
self_spin = torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device) + nall
self_spin = self_spin.view(1, -1, 1).expand(nframes, -1, -1)
# self spin + real neighbor + virtual neighbor
self_real = (
torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device)
.view(1, -1, 1)
.expand(nframes, -1, -1)
)
self_spin = self_real + nall
# real atom's neighbors: self spin + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
real_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
# spin atom's neighbors: real + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
spin_nlist = torch.cat([self_real, nlist, nlist_shift], dim=-1)
# nf x (nloc + nloc) x (1 + nnei + nnei)
extended_nlist = torch.cat(
[extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2
)
extended_nlist = torch.cat([real_nlist, spin_nlist], dim=-2)
# update the index for switch
first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall)
second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
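
The concatenation in this hunk can be illustrated with a minimal single-frame sketch in plain Python lists. It is not the PR's torch code: the name `extend_nlist_sketch` is hypothetical, `-1` marks an empty neighbor slot as in the hunk, and the index-switch step that follows in the real function is left out.

```python
def extend_nlist_sketch(nlist, nall):
    """Build a (nloc + nloc) x (1 + nnei + nnei) neighbor list for one frame.

    nlist: nloc rows of nnei neighbor indices, with -1 for an empty slot.
    nall:  number of real atoms including ghosts; the virtual (spin)
           atom of real atom k is stored at index k + nall.
    """
    real_rows, spin_rows = [], []
    for i, row in enumerate(nlist):
        # Shift real neighbor indices to their virtual counterparts,
        # keeping empty slots empty.
        shifted = [j + nall if j != -1 else -1 for j in row]
        # Real atom i: its own spin atom + real neighbors + virtual neighbors.
        real_rows.append([i + nall, *row, *shifted])
        # Spin atom of i: the real atom i + real neighbors + virtual neighbors.
        spin_rows.append([i, *row, *shifted])
    return real_rows + spin_rows


# Two local atoms, three atoms in total (one ghost), two neighbor slots each.
nlist = [[1, 2], [0, -1]]
extended = extend_nlist_sketch(nlist, nall=3)
for row in extended:
    print(row)
```

In the lines this hunk replaces, the second half of the list was filled with `-1`, so the spin atoms had no neighbors at all; the sketch reflects the new behavior, where each spin atom sees its own real atom followed by the same neighbors.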
63 changes: 58 additions & 5 deletions doc/model/train-energy-spin.md
@@ -1,14 +1,35 @@
# Fit spin energy {{ tensorflow_icon }}
# Fit spin energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/NiO/se_e2_a/input.json` as an example of the input file.
To train a model that takes additional spin information as input, you only need to modify the following sections to define the spin-specific settings,
keeping other sections the same as the normal energy model's input script.

:::{warning}
When spin is added to the model, the program automatically makes some implicit modifications:

- In the TensorFlow backend, the `se_e2_a` descriptor will treat those atom types with spin as new (virtual) types,
and duplicate their corresponding selected numbers of neighbors ({ref}`sel <model/descriptor[se_e2_a]/sel>`) from their real atom types.
- In the PyTorch backend, once spin settings are added, every type (with or without spin) gets a corresponding virtual type.
The `se_e2_a` descriptor therefore doubles the {ref}`sel <model/descriptor[se_e2_a]/sel>` list,
while for mixed-type descriptors (such as `dpa1` or `dpa2`) the sel number is left unchanged for clarity.
If you use a mixed-type descriptor, consider extending the sel number manually (for example, doubling it),
depending on the balance you want between model performance and computational efficiency.
:::
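
As a rough illustration of the sel adjustment described in the warning above, the sketch below assumes that "doubling" a per-type `sel` list means appending one copy of it for the virtual types, and that a single integer `sel` (as used by mixed-type descriptors) is simply multiplied by two. The helper name `extend_sel_for_spin` is hypothetical and is not part of DeePMD-kit.

```python
def extend_sel_for_spin(sel):
    """Return a sel setting with room for virtual (spin) neighbors.

    Assumption: a per-type list gets one extra entry per virtual type,
    copied from the real types; a single integer is doubled.
    """
    if isinstance(sel, int):
        return 2 * sel
    return list(sel) + list(sel)


print(extend_sel_for_spin([60, 60]))  # per-type list, se_e2_a style
print(extend_sel_for_spin(120))       # single number, mixed-type style
```

For `se_e2_a` the text says this happens automatically; for mixed-type descriptors the equivalent change would have to be written into the input file by hand.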

## Spin

The construction of the fitting net is give by section {ref}`spin <model/spin>`
The spin settings are given by the {ref}`spin <model/spin>` section, which sets the magnetism for each atom type, as described in the following sections.

:::{note}
Note that the construction of spin settings is different between TensorFlow and PyTorch/DP.
:::

### Spin settings in TensorFlow

The TensorFlow implementation supports only the `se_e2_a` descriptor. See the example in `$deepmd_source_dir/examples/spin/se_e2_a/input_tf.json`, where the {ref}`spin <model/spin>` section is defined as follows:

```json
"spin" : {
@@ -18,10 +39,38 @@ The construction of the fitting net is give by section {ref}`spin <model/spin>`
},
```

- {ref}`use_spin <model/spin[ener_spin]/use_spin>` determines whether to turn on the magnetism of the atoms.The index of this option matches option `type_map <model/type_map>`.
- {ref}`use_spin <model/spin[ener_spin]/use_spin>` is a list of boolean values indicating whether to use atomic spin for each atom type:
True to use spin, False otherwise. The index of this option matches option {ref}`type_map <model/type_map>`.
- {ref}`virtual_len <model/spin[ener_spin]/virtual_len>` specifies the distance between a virtual atom and the real atom it belongs to.
- {ref}`spin_norm <model/spin[ener_spin]/spin_norm>` gives the magnitude of the magnetic moment for each magnetic atom.

### Spin settings in PyTorch/DP

In PyTorch/DP, the spin implementation is more flexible and so far supports the following descriptors:

- `se_e2_a`
- `dpa1` (`se_atten`)
- `dpa2`

See the `se_e2_a` example in `$deepmd_source_dir/examples/spin/se_e2_a/input_torch.json`, where the {ref}`spin <model/spin>` section is defined as follows, with a much clearer interface:

```json
"spin": {
"use_spin": [true, false],
"virtual_scale": [0.3140]
},
```

- {ref}`use_spin <model/spin[ener_spin]/use_spin>` is either a list of boolean values indicating whether to use atomic spin for each atom type, or a list of the type indexes that use atomic spin.
The index of this option matches option {ref}`type_map <model/type_map>`.
- {ref}`virtual_scale <model/spin[ener_spin]/virtual_scale>` defines the scaling factor that determines the virtual distance
between a virtual atom representing spin and its corresponding real atom
for each atom type with spin. The factor is the virtual distance
divided by the magnitude of the atomic spin for that atom type.
The virtual coordinate is the real coordinate plus spin \* virtual_scale.
It is a list of float values of length `ntypes` or `ntypes_spin`, or a single float value applied to all types,
and is used only for atom types whose {ref}`use_spin <model/spin[ener_spin]/use_spin>` is True.
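
The relation "virtual coordinate = real coordinate + spin \* virtual_scale" can be sketched in a few lines of plain Python. The function name `virtual_coords`, the choice of one scale per type (length `ntypes`), and the decision to return `None` for atoms whose type has no spin are all assumptions made for illustration; this is not the backend's actual code.

```python
def virtual_coords(coords, spins, atype, use_spin, virtual_scale):
    """Place a virtual atom at coord + spin * virtual_scale for each spin atom.

    coords, spins: one [x, y, z] list per atom.
    atype:         type index per atom.
    use_spin:      one boolean per type.
    virtual_scale: one float per type (assumed layout: length ntypes).
    """
    out = []
    for r, s, t in zip(coords, spins, atype):
        if not use_spin[t]:
            out.append(None)  # this sketch skips types without spin
            continue
        scale = virtual_scale[t]
        out.append([ri + si * scale for ri, si in zip(r, s)])
    return out


coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
spins = [[0.0, 0.0, 2.0], [0.0, 0.0, 0.0]]
virt = virtual_coords(coords, spins, atype=[0, 1],
                      use_spin=[True, False], virtual_scale=[0.314, 0.0])
print(virt)
```

With a spin of magnitude 2 along z and a scale of 0.314, the virtual atom sits 0.628 length units above its real atom, which matches the definition of the scale as virtual distance divided by spin magnitude.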

## Spin Loss

The spin loss function $L$ for training energy is given by
@@ -59,3 +108,7 @@ The {ref}`loss <loss>` section in the `input.json` is
The options {ref}`start_pref_e <loss[ener_spin]/start_pref_e>`, {ref}`limit_pref_e <loss[ener_spin]/limit_pref_e>`, {ref}`start_pref_fr <loss[ener_spin]/start_pref_fr>`, {ref}`limit_pref_fm <loss[ener_spin]/limit_pref_fm>`, {ref}`start_pref_v <loss[ener_spin]/start_pref_v>` and {ref}`limit_pref_v <loss[ener_spin]/limit_pref_v>` determine the start and limit prefactors of energy, atomic force, magnetic force and virial, respectively.

To train without virial, set the virial prefactors {ref}`start_pref_v <loss[ener_spin]/start_pref_v>` and {ref}`limit_pref_v <loss[ener_spin]/limit_pref_v>` to 0.
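
The loss formula itself is not visible in this diff. Assuming the spin loss follows the same learning-rate-coupled schedule that DeePMD-kit documents for the standard energy loss, where each prefactor moves from its start value to its limit value as the learning rate decays, the role of the start and limit options can be sketched as below. The function name `prefactor` and the example numbers are hypothetical.

```python
def prefactor(start_pref, limit_pref, lr, start_lr):
    """Interpolate a loss prefactor between its start and limit values.

    Assumed schedule: the prefactor equals start_pref when lr == start_lr
    and approaches limit_pref as lr decays toward zero.
    """
    ratio = lr / start_lr
    return limit_pref + (start_pref - limit_pref) * ratio


# Energy prefactor growing from 0.02 to 1 as the learning rate decays.
for lr in (1e-3, 1e-4, 1e-8):
    print(lr, prefactor(0.02, 1.0, lr, start_lr=1e-3))

# With both virial prefactors set to 0 the virial term never contributes.
print(prefactor(0.0, 0.0, 1e-4, start_lr=1e-3))
```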

## Data preparation

(Documentation of the data format for TensorFlow and PyTorch/DP is still needed.)
30 changes: 22 additions & 8 deletions source/tests/pt/model/test_forward_lower.py
@@ -82,7 +82,9 @@ def test(
) = extend_input_and_build_neighbor_list(
coord.unsqueeze(0),
atype.unsqueeze(0),
self.model.get_rcut(),
self.model.get_rcut() + 1.0
if test_spin
else self.model.get_rcut(), # buffer region for spin nlist
self.model.get_sel(),
mixed_types=self.model.mixed_types(),
box=cell.unsqueeze(0),
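
The cutoff argument added in this hunk spans three lines, so it is worth spelling out how Python parses it. A minimal sketch, with a hypothetical helper name `build_rcut` and the 1.0 buffer taken from the hunk:

```python
def build_rcut(model_rcut, test_spin, buffer=1.0):
    # Parsed as (model_rcut + buffer) if test_spin else model_rcut:
    # the conditional expression binds more loosely than "+",
    # so the buffer is added only in the spin case.
    return model_rcut + buffer if test_spin else model_rcut


print(build_rcut(6.0, test_spin=True))
print(build_rcut(6.0, test_spin=False))
```

The spin tests therefore build the neighbor list with a cutoff one unit larger than the model's own, which the inline comment describes as a buffer region for the spin neighbor list; the non-spin tests keep the model cutoff unchanged.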
@@ -128,15 +130,13 @@ class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_se_e2_a)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_dpa1)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


@@ -151,24 +151,38 @@ def setUp(self):
"repinit_nsel"
]
model_params = copy.deepcopy(model_dpa2)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_zbl)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest):
def setUp(self):
# still need to figure out why only 1e-5 rtol and atol
self.prec = 1e-5
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA1(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinDPA2(unittest.TestCase, ForwardLowerTest):
def setUp(self):
self.prec = 1e-10
model_params = copy.deepcopy(model_spin)
self.type_split = False
model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"]
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)
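
The new spin test classes reuse the spin parameter set and swap in another model's descriptor. The sketch below shows that pattern in isolation and why both `copy.deepcopy` calls matter. The dictionaries are small hypothetical stand-ins, not the real `model_spin` and `model_dpa1` definitions, and `spin_params_with_descriptor` is a name invented for this example.

```python
import copy

# Hypothetical stand-ins for the shared, module-level parameter dicts.
model_spin = {
    "descriptor": {"type": "se_e2_a", "sel": [20, 20]},
    "spin": {"use_spin": [True, False]},
}
model_dpa1 = {"descriptor": {"type": "dpa1", "sel": 30}}


def spin_params_with_descriptor(spin_template, descriptor_source):
    """Return spin parameters that use another model's descriptor."""
    params = copy.deepcopy(spin_template)
    # Copy the donor as well, so later edits do not leak into it.
    params["descriptor"] = copy.deepcopy(descriptor_source)["descriptor"]
    return params


params = spin_params_with_descriptor(model_spin, model_dpa1)
params["descriptor"]["sel"] = 60  # changes the copy only
print(params["descriptor"], model_dpa1["descriptor"])
```

Because the parameter dicts are shared by several test classes, mutating them in one `setUp` without copying could silently change the model built in another test.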

4 changes: 2 additions & 2 deletions source/tests/pt/model/test_permutation.py
@@ -116,10 +116,10 @@
"type": "dpa2",
"repinit_rcut": 6.0,
"repinit_rcut_smth": 2.0,
"repinit_nsel": 30,
"repinit_nsel": 100,
"repformer_rcut": 4.0,
"repformer_rcut_smth": 0.5,
"repformer_nsel": 20,
"repformer_nsel": 40,
"repinit_neuron": [2, 4, 8],
"repinit_axis_neuron": 4,
"repinit_activation": "tanh",