Skip to content

Commit

Permalink
Merge pull request #27 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Fix a minor bug, and release v0.3
  • Loading branch information
WenjieDu authored Nov 23, 2023
2 parents 6682bcc + df23ec9 commit 9460d8c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pygrinder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.2"
__version__ = "0.3"

from .missing_completely_at_random import mcar, mcar_little_test
from .missing_at_random import mar_logistic
Expand Down
8 changes: 4 additions & 4 deletions pygrinder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def masked_fill(
Mask filled X.
"""
if isinstance(X, list):
X = np.asarray(X)
mask = np.asarray(mask)

assert X.shape == mask.shape, (
"Shapes of X and mask must match, "
f"but X.shape={X.shape}, mask.shape={mask.shape}"
Expand All @@ -73,10 +77,6 @@ def masked_fill(
"Data types of X and mask must match, " f"but got {type(X)} and {type(mask)}"
)

if isinstance(X, list):
X = np.asarray(X)
mask = np.asarray(mask)

if isinstance(X, np.ndarray):
filled_X = X.copy()
mask = mask.copy()
Expand Down
15 changes: 15 additions & 0 deletions tests/test_pygrinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def test_0_mcar(self):
assert np.sum(X_with_missing[(1 - missing_mask).astype(bool)]) == NaN * np.sum(
1 - missing_mask
)
# as list
list_X_with_missing = masked_fill(
X_with_missing.tolist(),
(1 - missing_mask).tolist(),
np.nan,
).tolist()
_ = cal_missing_rate(list_X_with_missing)
# as torch tensor
tensor_X_with_missing = masked_fill(
torch.from_numpy(X_with_missing),
torch.from_numpy(1 - missing_mask),
torch.nan,
)
_ = cal_missing_rate(tensor_X_with_missing)
# as numpy array
X_with_missing = masked_fill(X_with_missing, 1 - missing_mask, np.nan)
actual_missing_rate = cal_missing_rate(X_with_missing)
assert (
Expand Down

0 comments on commit 9460d8c

Please sign in to comment.