Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #231

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.2.0
hooks:
- id: black

Expand Down
14 changes: 8 additions & 6 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,14 @@ def dataset_from_dicts(
for key, val in labels["fixed"].items():
labels["fixed"][key] = tf.constant(val)

ds = tf.data.Dataset.from_tensor_slices((
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
))
ds = tf.data.Dataset.from_tensor_slices(
(
inputs["ragged"],
inputs["fixed"],
labels["ragged"],
labels["fixed"],
)
)
return ds


Expand Down
10 changes: 6 additions & 4 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

Expand Down
2 changes: 1 addition & 1 deletion apax/utils/jax_md_reduced/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def neighbor_fn(position_and_error, max_occupancy=None):
if not is_sparse(format):
capacity_limit = N - 1 if mask_self else N
elif format is NeighborListFormat.Sparse:
capacity_limit = N * (N - 1) if mask_self else N ** 2
capacity_limit = N * (N - 1) if mask_self else N**2
else:
capacity_limit = N * (N - 1) // 2
if max_occupancy > capacity_limit:
Expand Down
12 changes: 7 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ def modify_xyz_file(file_path, target_string, replacement_string):

@pytest.fixture()
def get_sample_input():
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
atomic_numbers = np.array([1, 1, 8])
box = np.diag(np.zeros(3))
offsets = np.full([3, 3], 0)
Expand Down
12 changes: 7 additions & 5 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,13 @@ def test_ase_calc(get_tmp_path):
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
positions = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
atomic_numbers = np.array([1, 1, 8])
box = np.diag([cell_size] * 3)
offsets = jnp.full([3, 3], 0)
Expand Down
94 changes: 55 additions & 39 deletions tests/unit_tests/model/test_apax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@


def test_apax_variable_size():
R = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
])
R = np.array(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
]
)

Z = np.array([1, 2, 2])

idx = np.array([
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
])
idx = np.array(
[
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
]
)
offsets = jnp.full([6, 3], 0)
box = np.array([0, 0, 0])

Expand Down Expand Up @@ -65,20 +69,24 @@ def test_apax_variable_size():
def test_atomistic_model():
key = jax.random.PRNGKey(0)

dR = np.array([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0],
[-1.0, 1.0, 0.0],
[0.0, -1.0, 0.0],
[1.0, -1.0, 0.0],
])
dR = np.array(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0],
[-1.0, 1.0, 0.0],
[0.0, -1.0, 0.0],
[1.0, -1.0, 0.0],
]
)
Z = np.array([1, 2, 2])

idx = np.array([
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
])
idx = np.array(
[
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
]
)

model = AtomisticModel(mask_atoms=False)

Expand All @@ -91,18 +99,22 @@ def test_atomistic_model():
def test_energy_model():
key = jax.random.PRNGKey(0)

R = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
])
R = np.array(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
]
)

Z = np.array([1, 2, 2])

idx = np.array([
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
])
idx = np.array(
[
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
]
)
offsets = jnp.full([6, 3], 0)
box = np.array([0.0, 0.0, 0.0])

Expand All @@ -117,18 +129,22 @@ def test_energy_model():
def test_energy_force_model():
key = jax.random.PRNGKey(0)

R = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
])
R = np.array(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
]
)

Z = np.array([1, 2, 2])

idx = np.array([
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
])
idx = np.array(
[
[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2],
]
)
offsets = jnp.full([6, 3], 0)

box = np.array([0.0, 0.0, 0.0])
Expand Down
88 changes: 55 additions & 33 deletions tests/unit_tests/train/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ def test_weighted_squared_error():
assert loss.shape == ()
assert abs(loss - ref) < 1e-6

pred = jnp.array([
[0.6, 0.4, 0.2, -0.5],
[0.1, -0.1, 0.8, 0.6],
])
pred = jnp.array(
[
[0.6, 0.4, 0.2, -0.5],
[0.1, -0.1, 0.8, 0.6],
]
)
loss = weighted_squared_error(energy_label, pred, divisor=1.0)
loss = jnp.sum(loss)
ref = 0.25
Expand All @@ -35,23 +37,31 @@ def test_weighted_squared_error():


def test_force_angle_loss():
F_pred = jnp.array([[
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.5, 0.5, 0.0],
[0.0, 0.5, 0.0],
[0.0, 0.5, 0.0],
[0.0, 0.0, 0.0], # padding
]])

F_0 = jnp.array([[
[0.5, 0.0, 0.0],
[0.9, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.9, 0.0, 0.0],
[0.0, 0.0, 0.0], # padding
]])
F_pred = jnp.array(
[
[
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.5, 0.5, 0.0],
[0.0, 0.5, 0.0],
[0.0, 0.5, 0.0],
[0.0, 0.0, 0.0], # padding
]
]
)

F_0 = jnp.array(
[
[
[0.5, 0.0, 0.0],
[0.9, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.9, 0.0, 0.0],
[0.0, 0.0, 0.0], # padding
]
]
)

F_angle_loss = force_angle_loss(F_pred, F_0)
F_angle_loss = jnp.arccos(-F_angle_loss + 1) * 360 / (2 * np.pi)
Expand All @@ -74,17 +84,25 @@ def test_force_loss():
"n_atoms": jnp.array([2]),
}
label = {
name: jnp.array([[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.1],
]]),
name: jnp.array(
[
[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.1],
]
]
),
}

pred = {
name: jnp.array([[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.1],
]]),
name: jnp.array(
[
[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.1],
]
]
),
}
loss_func = Loss(name=name, loss_type=loss_type, weight=weight)
loss = loss_func(inputs=inputs, label=label, prediction=pred)
Expand All @@ -93,10 +111,14 @@ def test_force_loss():
assert abs(loss - ref_loss) < 1e-6

pred = {
name: jnp.array([[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.6],
]]),
name: jnp.array(
[
[
[0.4, 0.2, 0.5],
[0.3, 0.8, 0.6],
]
]
),
}
loss_func = Loss(name=name, loss_type=loss_type, weight=weight)
loss = loss_func(inputs=inputs, label=label, prediction=pred)
Expand Down
Loading
Loading