Skip to content

Commit

Permalink
Merge pull request #117 from N720720/strong_type_numpy_fix
Browse files Browse the repository at this point in the history
problems with strong typing and numpy
  • Loading branch information
N720720 authored Feb 15, 2021
2 parents 05206df + 3c1d46e commit ec50be3
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 25 deletions.
3 changes: 2 additions & 1 deletion lindemann/index/per_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def lindemann_per_frames(frames: np.ndarray) -> np.ndarray:
return lindex_array


def calculate(indices: np.ndarray) -> np.ndarray:
# def calculate(indices: np.ndarray) -> np.ndarray:
def calculate(indices):
"""
Small helper function, since numba has not implemented the np.nanmean with axis parameter
I cant implemnet this in the jit function for now.
Expand Down
3 changes: 2 additions & 1 deletion lindemann/index/per_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def lindemann_per_frames_for_each_atom(frames: np.ndarray) -> np.ndarray:
return lindex_array


def calculate(indices: np.ndarray) -> np.ndarray:
# def calculate(indices: np.ndarray) -> np.ndarray:
def calculate(indices):
"""
Small helper function, since numba has not implemented the np.nanmean with axis parameter
I cant implemnet this in the jit function for now.
Expand Down
6 changes: 4 additions & 2 deletions lindemann/index/per_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@


@nb.njit(fastmath=True) # , cache=True)
def lindemann_per_atom(frames: np.ndarray) -> np.ndarray:
# def lindemann_per_atom(frames: np.ndarray) -> np.ndarray:
def lindemann_per_atom(frames):

"""Calculates the lindemann index for """
natoms = len(frames[0])
Expand Down Expand Up @@ -51,7 +52,8 @@ def lindemann_per_atom(frames: np.ndarray) -> np.ndarray:
return lindemann_indices


def calculate(frames: np.ndarray) -> np.float64:
# def calculate(frames: np.ndarray) -> float:
def calculate(frames):
"""
Small helper function, since numba has not implemented the np.nanmean with axis parameter
I cant implemnet this in the jit function for now.
Expand Down
2 changes: 1 addition & 1 deletion lindemann/trajectory/plt_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def lindemann_vs_frames(indices: np.ndarray) -> str:
plt.xlabel("Frames")
plt.ylabel("Lindemann index")
plt.plot(np.arange(0, len(indices)), indices, "+")
plt.tight_layout()
# plt.tight_layout()
# plt.show()
plt.savefig("lindemann_per_frame.pdf")
return "lindemann_per_frame.pdf"
23 changes: 10 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions tests/test_example/test_hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
[
(
"tests/test_example/459_01.lammpstrj",
np.round(np.float(0.025923892565654555), 12),
np.round(float(0.025923892565654555), 12),
),
(
"tests/test_example/459_02.lammpstrj",
np.round(np.float(0.026426709832984754), 12),
np.round(float(0.026426709832984754), 12),
),
],
)
# def test_setup(trajectory: str, lindemannindex: np.float) -> bool:
# def test_setup(trajectory: str, lindemannindex: float) -> bool:


def test_tra(trajectory, lindemannindex):
Expand All @@ -35,11 +35,11 @@ def test_tra(trajectory, lindemannindex):
[
(
"tests/test_example/459_01.lammpstrj",
np.round(np.float(0.025923892565654555), 12),
np.round(float(0.025923892565654555), 12),
),
(
"tests/test_example/459_02.lammpstrj",
np.round(np.float(0.026426709832984754), 12),
np.round(float(0.026426709832984754), 12),
),
],
)
Expand All @@ -55,11 +55,11 @@ def test_frames(trajectory, lindemannindex):
[
(
"tests/test_example/459_01.lammpstrj",
np.round(np.float(0.025923892565654555), 12),
np.round(float(0.025923892565654555), 12),
),
(
"tests/test_example/459_02.lammpstrj",
np.round(np.float(0.026426709832984754), 12),
np.round(float(0.026426709832984754), 12),
),
],
)
Expand Down

0 comments on commit ec50be3

Please sign in to comment.