Skip to content

Commit

Permalink
Add setters for higher dimensional arrays.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550619522
  • Loading branch information
jwhpryor authored and copybara-github committed Jul 24, 2023
1 parent a083543 commit c47fdc3
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 28 deletions.
80 changes: 54 additions & 26 deletions implementation/array_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,60 @@ class ArrayRef : public ScopedArrayImpl<JniT> {
std::atomic<std::size_t> length_ = kNoIdx;
};

// |SpanType| is object.
// Shared behaviour for object like arrays.
template <typename JniT>
class ArrayRef<
JniT, std::enable_if_t<std::is_same_v<typename JniT::SpanType, jobject>>>
: public ScopedArrayImpl<JniT> {
class ArrayRefBase : public ScopedArrayImpl<JniT> {
public:
using Base = ScopedArrayImpl<JniT>;
using Base::Base;
using SpanType = jobject;
using SpanType = typename JniT::SpanType;

// Construct array with given size and null values.
explicit ArrayRef(std::size_t size)
explicit ArrayRefBase(std::size_t size)
: Base(JniArrayHelper<jobject, JniT::kRank>::NewArray(
size, ClassRef_t<JniT>::GetAndMaybeLoadClassRef(nullptr),
static_cast<jobject>(nullptr))) {}

// Construct from jobject lvalue (object is used as template).
explicit ArrayRef(std::size_t size, jobject obj)
explicit ArrayRefBase(std::size_t size, jobject obj)
: Base(JniArrayHelper<jobject, JniT::kRank>::NewArray(
size,
ClassRef_t<JniT>::GetAndMaybeLoadClassRef(
static_cast<jobject>(obj)),
static_cast<jobject>(obj))) {}

// Object arrays cannot be efficiently pinned like primitive types can.
ArrayView<SpanType, JniT::kRank> Pin() {
return {Base::object_ref_, false, Length()};
}

std::size_t Length() {
return JniArrayHelper<jobject, JniT::kRank>::GetLength(Base::object_ref_);
}

// Note: Globals are not permitted in a local array because it makes reasoning
// about them confusing.
//
// TODO(b/406948932): Permit lvalues of locals and globals as technically
// they're both viable (the scope will be extended as expected).
void Set(
std::size_t idx,
LocalObject<JniT::class_v, JniT::class_loader_v, JniT::jvm_v>&& val) {
JniArrayHelper<jobject, JniT::kRank>::SetArrayElement(Base::object_ref_,
idx, val.Release());
}
};

// |SpanType| is object and rank is 1.
template <typename JniT>
class ArrayRef<
JniT, std::enable_if_t<(std::is_same_v<typename JniT::SpanType, jobject> &&
JniT::kRank == 1)>> : public ArrayRefBase<JniT> {
public:
using Base = ArrayRefBase<JniT>;
using Base::Base;
using SpanType = typename JniT::SpanType;

// Construct from LocalObject lvalue (object is used as template).
//
// e.g.
Expand All @@ -105,31 +135,29 @@ class ArrayRef<
const ObjectContainer<class_v, class_loader_v, jvm_v>& obj)
: ArrayRef(size, static_cast<jobject>(obj)) {}

std::size_t Length() {
return JniArrayHelper<jobject, JniT::kRank>::GetLength(Base::object_ref_);
}

// Object arrays cannot be efficiently pinned like primitive types can.
ArrayView<jobject, JniT::kRank> Pin() {
return {Base::object_ref_, false, Length()};
}

LocalObject<JniT::class_v, JniT::class_loader_v, JniT::jvm_v> Get(
std::size_t idx) {
return {JniArrayHelper<jobject, JniT::kRank>::GetArrayElement(
Base::object_ref_, idx)};
}
};

// Note: Globals are not permitted in a local array because it makes reasoning
// about them confusing.
//
// TODO(b/406948932): Permit lvalues of locals and globals as technically
// they're both viable (the scope will be extended as expected).
void Set(
std::size_t idx,
LocalObject<JniT::class_v, JniT::class_loader_v, JniT::jvm_v>&& val) {
return JniArrayHelper<jobject, JniT::kRank>::SetArrayElement(
Base::object_ref_, idx, val.Release());
// |SpanType| is object or rank is > 1.
template <typename JniT>
class ArrayRef<JniT, std::enable_if_t<(JniT::kRank > 1)>>
: public ArrayRefBase<JniT> {
public:
using Base = ArrayRefBase<JniT>;
using Base::Base;

template <typename SpanType, std::size_t kRank_, const auto& class_v_,
const auto& class_loader_v_, const auto& jvm_v_>
void Set(std::size_t idx, const LocalArray<SpanType, kRank_, class_v_,
class_loader_v_, jvm_v_>& val) {
using ElementT =
typename JniArrayHelper<SpanType, JniT::kRank - 1>::AsArrayType;
JniArrayHelper<ElementT, JniT::kRank>::SetArrayElement(
Base::object_ref_, idx, static_cast<ElementT>(val));
}
};

Expand Down
4 changes: 4 additions & 0 deletions implementation/jni_helper/jni_array_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ struct JniArrayHelperBase {
// Rank 2+ arrays all behave like object arrays.
template <typename SpannedType, std::size_t kRank>
struct JniArrayHelper : public JniArrayHelperBase {
using AsArrayType = jobjectArray;

static inline jobjectArray NewArray(std::size_t size,
jclass class_id = nullptr,
jobject initial_element = nullptr) {
Expand Down Expand Up @@ -247,6 +249,8 @@ struct JniArrayHelper<jdouble, 1> : public JniArrayHelperBase {
// is unlike any other new array construction.
template <std::size_t kRank>
struct JniArrayHelper<jobject, kRank> : public JniArrayHelperBase {
using AsArrayType = jobjectArray;

static inline jobjectArray NewArray(std::size_t size, jclass class_id,
jobject initial_element) {
return jni::JniEnv::GetEnv()->NewObjectArray(size, class_id,
Expand Down
40 changes: 38 additions & 2 deletions implementation/local_array_multidimensional_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using ::testing::_;
using ::testing::Return;
using ::testing::StrEq;

static constexpr Class kClass{"kClass"};

////////////////////////////////////////////////////////////////////////////////
// Multi-Dimensional Construction.
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -52,6 +54,42 @@ TEST_F(JniTest, Array_BuildsFromSizeForMultiDimensionalArray_primitive_lvalue) {
LocalArray<jint, 2>{std::size_t{10}, arr};
}

////////////////////////////////////////////////////////////////////////////////
// Setters.
////////////////////////////////////////////////////////////////////////////////
TEST_F(JniTest, Array_SetsIntValues) {
EXPECT_CALL(
*env_, SetObjectArrayElement(Fake<jobjectArray>(), 0, Fake<jintArray>()));
EXPECT_CALL(
*env_, SetObjectArrayElement(Fake<jobjectArray>(), 1, Fake<jintArray>()));
EXPECT_CALL(
*env_, SetObjectArrayElement(Fake<jobjectArray>(), 2, Fake<jintArray>()));

LocalArray<jint, 1> array_arg{Fake<jintArray>()};
LocalArray<jint, 2> arr{std::size_t{10}, Fake<jobjectArray>()};
arr.Set(0, array_arg);
arr.Set(1, array_arg);
arr.Set(2, std::move(array_arg));
}

TEST_F(JniTest, Array_SetsObjectValues) {
EXPECT_CALL(*env_, SetObjectArrayElement(Fake<jobjectArray>(1), 0,
Fake<jobjectArray>(2)));
EXPECT_CALL(*env_, SetObjectArrayElement(Fake<jobjectArray>(1), 1,
Fake<jobjectArray>(2)));
EXPECT_CALL(*env_, SetObjectArrayElement(Fake<jobjectArray>(1), 2,
Fake<jobjectArray>(2)));

LocalArray<jobject, 1, kClass> array_arg{Fake<jobjectArray>(2)};
LocalArray<jint, 2> arr{Fake<jobjectArray>(1)};
arr.Set(0, array_arg);
arr.Set(1, array_arg);
arr.Set(2, std::move(array_arg));
}

////////////////////////////////////////////////////////////////////////////////
// Iteration.
////////////////////////////////////////////////////////////////////////////////
TEST_F(JniTest, Array_IteratesOver1DRange) {
std::array expected{10, 20, 30, 40, 50, 60, 70, 80, 90, 100};

Expand Down Expand Up @@ -84,7 +122,6 @@ TEST_F(JniTest, Array_WorksWithSTLComparison) {

TEST_F(JniTest, Array_WorksWithSTLComparisonOfObjects) {
std::array expected{Fake<jobject>(1), Fake<jobject>(2), Fake<jobject>(3)};
static constexpr Class kClass{"kClass"};

EXPECT_CALL(*env_, GetArrayLength).WillOnce(Return(3));
EXPECT_CALL(*env_, GetObjectArrayElement)
Expand All @@ -99,7 +136,6 @@ TEST_F(JniTest, Array_WorksWithSTLComparisonOfObjects) {
}

TEST_F(JniTest, Array_WorksWithSTLComparisonOfRichlyDecoratedObjects) {
static constexpr Class kClass{"kClass"};
std::array expected{LocalObject<kClass>{Fake<jobject>(1)},
LocalObject<kClass>{Fake<jobject>(2)},
LocalObject<kClass>{Fake<jobject>(3)}};
Expand Down

0 comments on commit c47fdc3

Please sign in to comment.