Skip to content

Commit

Permalink
ARROW-15219: [Python] Export the random compute function
Browse files Browse the repository at this point in the history
Also change the underlying generator from pcg64_fast to pcg64_oneseq: with pcg64_fast, seeds that differ only in their 2 low bits (e.g. 0, 1, 2, 3) produce the same output.

Closes apache#12054 from pitrou/ARROW-15219-py-random

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
pitrou authored and pull[bot] committed Jul 21, 2022
1 parent 71e7227 commit 04ed7e6
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 13 deletions.
22 changes: 15 additions & 7 deletions cpp/src/arrow/compute/kernels/scalar_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,36 @@ namespace internal {

namespace {

// We use the PCG64 single-stream ("oneseq") generator because:
// - we don't need multiple streams
// - we want deterministic output for a given seed (ruling out the unique-stream
// PCG generators)
// - the PCG64 no-stream ("fast") generator produces identical outputs for seeds
// which differ only by their 2 low bits (for example, 0, 1, 2, 3 all produce
// the same output).

// Generates a random floating point number in range [0, 1).
//
// Draws a single value from `rng` (advancing its state) and maps it to a
// double. `rng` must be non-null: it is dereferenced unconditionally.
double generate_uniform(random::pcg64_fast* rng) {
double generate_uniform(random::pcg64_oneseq* rng) {
// This equation is copied from numpy. It keeps the upper 53 bits of the
// 64-bit draw (`>> 11`) and scales them by 1 / 2^53
// (9007199254740992 == 2^53), i.e. roughly `rng() / 2^64`. The largest
// possible result is (2^53 - 1) / 2^53, so the return value is strictly
// less than 1.
//
// The asserts below check that the generator yields the full 64-bit range
// [0, 2^64 - 1], which the shift-and-scale above relies on.
static_assert(random::pcg64_fast::min() == 0ULL, "");
static_assert(random::pcg64_fast::max() == ~0ULL, "");
static_assert(random::pcg64_oneseq::min() == 0ULL, "");
static_assert(random::pcg64_oneseq::max() == ~0ULL, "");
return ((*rng)() >> 11) * (1.0 / 9007199254740992.0);
}

using RandomState = OptionsWrapper<RandomOptions>;

// Builds and returns a PCG generator whose seed material is pulled from
// std::random_device through the vendored `seed_seq_from` adaptor.
// ExecRandom stores the result in a function-local static named `seed_gen`.
random::pcg64_fast MakeSeedGenerator() {
random::pcg64_oneseq MakeSeedGenerator() {
// Seed sequence backed by std::random_device.
arrow_vendored::pcg_extras::seed_seq_from<std::random_device> seed_source;
random::pcg64_fast seed_gen(seed_source);
random::pcg64_oneseq seed_gen(seed_source);
return seed_gen;
}

Status ExecRandom(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
static random::pcg64_fast seed_gen = MakeSeedGenerator();
static random::pcg64_oneseq seed_gen = MakeSeedGenerator();
static std::mutex seed_gen_mutex;

random::pcg64_fast gen;
random::pcg64_oneseq gen;
const RandomOptions& options = RandomState::Get(ctx);
if (options.length < 0) {
return Status::Invalid("Negative number of elements");
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/util/pcg_random.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ using pcg32 = ::arrow_vendored::pcg32;
using pcg64 = ::arrow_vendored::pcg64;
using pcg32_fast = ::arrow_vendored::pcg32_fast;
using pcg64_fast = ::arrow_vendored::pcg64_fast;
using pcg32_oneseq = ::arrow_vendored::pcg32_oneseq;
using pcg64_oneseq = ::arrow_vendored::pcg64_oneseq;

} // namespace random
} // namespace arrow
26 changes: 26 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,32 @@ class Utf8NormalizeOptions(_Utf8NormalizeOptions):
self._set_options(form)


cdef class _RandomOptions(FunctionOptions):
    # Cython-level wrapper around the C++ RandomOptions. `_set_options`
    # chooses the C++ factory that matches `initializer` and stores the
    # resulting options object in `self.wrapped`.
    def _set_options(self, length, initializer):
        if initializer == 'system':
            # Let the C++ side pick the seed from the system random source.
            self.wrapped.reset(new CRandomOptions(
                CRandomOptions.FromSystemRandom(length)))
        else:
            seed = initializer
            if not isinstance(seed, int):
                # Anything that is not an integer is reduced to its hash;
                # unhashable objects are rejected with a descriptive error.
                try:
                    seed = hash(seed)
                except TypeError:
                    raise TypeError(
                        f"initializer should be 'system', an integer, "
                        f"or a hashable object; got {initializer!r}")

            # FromSeed takes an unsigned 64-bit seed, so negative values
            # (e.g. negative hashes) are shifted up by 2**64.
            if seed < 0:
                seed += 2**64
            self.wrapped.reset(new CRandomOptions(
                CRandomOptions.FromSeed(length, seed)))


# Options for the "random" compute function.
#
# length: number of elements to generate; passed through to _set_options
#     unchanged (the C++ kernel rejects negative values).
# initializer: keyword-only, defaults to 'system'. _set_options treats
#     'system' as system-random seeding, an int as an explicit seed, and
#     any other hashable object as a seed derived from its hash(); an
#     unhashable object raises TypeError.
class RandomOptions(_RandomOptions):
def __init__(self, length, *, initializer='system'):
self._set_options(length, initializer)


def _group_by(args, keys, aggregations):
cdef:
vector[CDatum] c_args
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
PadOptions,
PartitionNthOptions,
QuantileOptions,
RandomOptions,
ReplaceSliceOptions,
ReplaceSubstringOptions,
RoundOptions,
Expand Down
22 changes: 16 additions & 6 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,22 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CUtf8NormalizeOptions(CUtf8NormalizeForm form)
CUtf8NormalizeForm form

# Declaration of the C++ class arrow::compute::SetLookupOptions: one
# constructor taking both fields, plus the two fields themselves.
cdef cppclass CSetLookupOptions \
"arrow::compute::SetLookupOptions"(CFunctionOptions):
CSetLookupOptions(CDatum value_set, c_bool skip_nulls)
CDatum value_set
c_bool skip_nulls

# Declaration of the C++ class arrow::compute::RandomOptions.
cdef cppclass CRandomOptions \
"arrow::compute::RandomOptions"(CFunctionOptions):
# Copy constructor. _compute.pyx uses it as
# `new CRandomOptions(CRandomOptions.From...(...))` to heap-allocate a copy
# of the value returned by one of the static factories below.
CRandomOptions(CRandomOptions)

# Static factory taking only the length; called when the Python-level
# initializer is 'system'.
@staticmethod
CRandomOptions FromSystemRandom(int64_t length)

# Static factory taking the length and an explicit unsigned 64-bit seed.
@staticmethod
CRandomOptions FromSeed(int64_t length, uint64_t seed)

cdef enum DatumType" arrow::Datum::type":
DatumType_NONE" arrow::Datum::NONE"
DatumType_SCALAR" arrow::Datum::SCALAR"
Expand All @@ -2236,12 +2252,6 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
const shared_ptr[CTable]& table() const
const shared_ptr[CScalar]& scalar() const

cdef cppclass CSetLookupOptions \
"arrow::compute::SetLookupOptions"(CFunctionOptions):
CSetLookupOptions(CDatum value_set, c_bool skip_nulls)
CDatum value_set
c_bool skip_nulls


cdef extern from * namespace "arrow::compute":
# inlined from compute/function_internal.h to avoid exposing
Expand Down
31 changes: 31 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datetime import datetime
from functools import lru_cache, partial
import inspect
import os
import pickle
import pytest
import random
Expand Down Expand Up @@ -146,6 +147,7 @@ def test_option_class_equality():
pc.PadOptions(5),
pc.PartitionNthOptions(1, null_placement="at_start"),
pc.QuantileOptions(),
pc.RandomOptions(10),
pc.ReplaceSliceOptions(0, 1, "a"),
pc.ReplaceSubstringOptions("a", "b"),
pc.RoundOptions(2, "towards_infinity"),
Expand Down Expand Up @@ -174,6 +176,7 @@ def test_option_class_equality():
options.append(pc.AssumeTimezoneOptions("Europe/Ljubljana"))

classes = {type(option) for option in options}

for cls in exported_option_classes:
# Timezone database is not available on Windows yet
if cls not in classes and sys.platform != 'win32' and \
Expand All @@ -182,6 +185,7 @@ def test_option_class_equality():
options.append(cls())
except TypeError:
pytest.fail(f"Options class is not tested: {cls}")

for option in options:
assert option == option
assert repr(option).startswith(option.__class__.__name__)
Expand Down Expand Up @@ -2370,3 +2374,30 @@ def test_utf8_normalize():
ValueError,
match='"NFZ" is not a valid Unicode normalization form'):
pc.utf8_normalize(arr, form="NFZ")


def test_random():
    # Every supported kind of initializer (including a negative integer)
    # must be accepted; a zero-length request yields an empty float64 array.
    empty_f64 = pa.array([], type=pa.float64())
    for initializer in ['system', 42, -42, b"abcdef"]:
        assert pc.random(0, initializer=initializer) == empty_f64

    # Default (system) initialization: no two calls give the same values.
    system_draws = {tuple(pc.random(100).to_pylist()) for _ in range(10)}
    assert len(system_draws) == 10

    # Explicit integer seeds: 100 calls cycling through 7 seeds must give
    # exactly 7 distinct outputs (same seed => same output, and different
    # seeds => different output).
    seeded_draws = {
        tuple(pc.random(100, initializer=n % 7).to_pylist())
        for n in range(100)
    }
    assert len(seeded_draws) == 7

    # Arbitrary hashable objects can be given as initializer
    initializers = [object(), (4, 5, 6), "foo"]
    initializers += [os.urandom(10) for _ in range(10)]
    hashed_draws = {
        tuple(pc.random(100, initializer=init).to_pylist())
        for init in initializers
    }
    assert len(hashed_draws) == len(initializers)

    # Unhashable initializers are rejected with an explicit message.
    expected_msg = (r"initializer should be 'system', an integer, "
                    r"or a hashable object; got \[\]")
    with pytest.raises(TypeError, match=expected_msg):
        pc.random(100, initializer=[])

0 comments on commit 04ed7e6

Please sign in to comment.