Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add size and replace arguments to deepmd.utils.random.choice #3195

Merged
merged 4 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@
def __getitem__(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand All @@ -892,7 +892,7 @@
def get_training_batch(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)

Check warning on line 895 in deepmd/pt/utils/dataset.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataset.py#L895

Added line #L895 was not covered by tests
b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand Down
27 changes: 21 additions & 6 deletions deepmd/utils/random.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
Tuple,
Union,
)

import numpy as np

_RANDOM_GENERATOR = np.random.RandomState()


def choice(a: np.ndarray, p: Optional[np.ndarray] = None):
def choice(
a: Union[np.ndarray, int],
size: Optional[Union[int, Tuple[int, ...]]] = None,
replace: bool = True,
p: Optional[np.ndarray] = None,
):
"""Generates a random sample from a given 1-D array.

Parameters
----------
a : np.ndarray
A random sample is generated from its elements.
p : np.ndarray
The probabilities associated with each entry in a.
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements. If an int,
the random sample is generated as if it were np.arange(a)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples
are drawn. Default is None, in which case a single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement. Default is True, meaning
that a value of a can be selected multiple times.
p : 1-D array-like, optional
The probabilities associated with each entry in a. If not given, the sample
assumes a uniform distribution over all entries in a.

Returns
-------
np.ndarray
arrays with results and their shapes
"""
return _RANDOM_GENERATOR.choice(a, p=p)
return _RANDOM_GENERATOR.choice(a, size=size, replace=replace, p=p)


def random(size=None):
Expand Down
Loading