Skip to content

Commit

Permalink
Merge pull request #562 from elcorto/feature-improve-predictor
Browse files Browse the repository at this point in the history
Improve predictor module
  • Loading branch information
RandomDefaultUser authored Oct 18, 2024
2 parents e062deb + e8ba079 commit 49ed05f
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions mala/network/predictor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tester class for testing a network."""
"""Predictor class."""

from time import perf_counter

Expand Down Expand Up @@ -61,13 +61,6 @@ def predict_from_qeout(self, path_to_file, gather_ldos=False):
predicted_ldos : numpy.array
Precicted LDOS for these atomic positions.
"""
self.data.grid_dimension = self.parameters.inference_data_grid
self.data.grid_size = (
self.data.grid_dimension[0]
* self.data.grid_dimension[1]
* self.data.grid_dimension[2]
)

self.data.target_calculator.read_additional_calculation_data(
path_to_file, "espresso-out"
)
Expand Down Expand Up @@ -240,18 +233,17 @@ def _forward_snap_descriptors(
)

for i in range(0, self.number_of_batches_per_snapshot):
inputs = snap_descriptors[
i
* self.parameters.mini_batch_size : (i + 1)
* self.parameters.mini_batch_size
]
inputs = inputs.to(self.parameters._configuration["device"])
predicted_outputs[
i
* self.parameters.mini_batch_size : (i + 1)
* self.parameters.mini_batch_size
] = self.data.output_data_scaler.inverse_transform(
self.network(inputs).to("cpu"), as_numpy=True
sl = slice(
i * self.parameters.mini_batch_size,
(i + 1) * self.parameters.mini_batch_size,
)
inputs = snap_descriptors[sl].to(
self.parameters._configuration["device"]
)
predicted_outputs[sl] = (
self.data.output_data_scaler.inverse_transform(
self.network(inputs).to("cpu"), as_numpy=True
)
)

# Restricting the actual quantities to physical meaningful values,
Expand Down

0 comments on commit 49ed05f

Please sign in to comment.