Skip to content

Commit

Permalink
add generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Oct 19, 2023
1 parent 1f71c36 commit bf121ba
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/paddle/io/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ class SubsetRandomSampler(Sampler):
Args:
indices (sequence): a sequence of indices
generator(Generator, optional): specify a generator to sample the :code:`data_source`. Default None, disabled.
Examples:
Expand All @@ -368,15 +369,15 @@ class SubsetRandomSampler(Sampler):
5
1
see `paddle.io.Sampler`
"""

def __init__(self, indices):
def __init__(self, indices, generator=None):
if len(indices) == 0:
raise ValueError(
"The length of `indices` in SubsetRandomSampler should be greater than 0."
)
self.indices = indices
assert generator is None

def __iter__(self):
for i in randperm(len(self.indices)):
Expand Down

0 comments on commit bf121ba

Please sign in to comment.