Skip to content

Commit

Permalink
Ensure 0-index in constant buffer is carried through
Browse files Browse the repository at this point in the history
Differential Revision: D62209852

Pull Request resolved: pytorch#5145
  • Loading branch information
lucylq authored Sep 6, 2024
1 parent 617f9d8 commit 8afdc48
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 5 deletions.
7 changes: 6 additions & 1 deletion exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,12 @@ def serialize_pte_binary(
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
program.constant_buffer, tensor_alignment=constant_tensor_alignment
)
if len(constant_segment_data) > 0:

# If there are no constants, len(constant_segment_data) = 0. However, there may
# be non-constants, in which case len(constant_segment_offsets) = 1, containing
# the placeholder value 0. Ensure the placeholder value is put into
# program.constant_segment.offsets.
if len(constant_segment_offsets) > 0:
# Update program.constant_segment with constant subsegment offset information.
program.constant_segment = SubsegmentOffsets(
segment_index=len(segments), offsets=constant_segment_offsets
Expand Down
27 changes: 27 additions & 0 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,33 @@ def test_round_trip_with_segments(self) -> None:
program2 = deserialize_pte_binary(pte_data)
self.assert_programs_equal(program, program2)

def test_no_constants(self) -> None:
program = get_test_program()
# Insert placeholder for non-const tensors.
add_constant_data(program, [b""])

pte_data = bytes(
serialize_pte_binary(
program,
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
)
)
# The input Program should not be modified.
self.assertEqual(program.segments, [])

# Peek inside the actual flatbuffer data to see the segments.
flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data))

# Constant buffer should be empty.
self.assertEqual(len(flatbuffer_program.constant_buffer), 0)

# Constant segment should contain the placeholder.
self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0)
self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1)
self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0)

def test_unused_inline_delegate_blobs_with_segments(self) -> None:
# Create a program with some delegate data blobs.
program = get_test_program()
Expand Down
30 changes: 26 additions & 4 deletions runtime/executor/test/program_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,31 @@ TEST_F(ProgramTest, DEPRECATEDLoad) {
EXPECT_EQ(program_res.error(), Error::Ok);
}

TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
Result<Program> program =
Program::load(add_loader_.get(), kDefaultVerification);
ASSERT_EQ(program.error(), Error::Ok);

// Load constant segment data should fail.
const auto segment_info = DataLoader::SegmentInfo(
DataLoader::SegmentInfo::Type::Constant,
/*segment_index=*/0);
Result<FreeableBuffer> segment =
ProgramTestFriend::LoadSegment(&program.get(), segment_info);
EXPECT_NE(segment.error(), Error::Ok);

const executorch_flatbuffer::Program* flatbuffer_program =
ProgramTestFriend::GetInternalProgram(&program.get());

// The constant buffer should be empty.
EXPECT_EQ(flatbuffer_program->constant_buffer()->size(), 0);

// Expect 1 constant segment, placeholder for non-const tensors.
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);
}

TEST_F(ProgramTest, LoadConstantSegment) {
// Load the serialized ModuleLinear data, with constants in the segment and no
// constants in the flatbuffer.
// Load the serialized ModuleLinear data, with constants in the segment.
const char* linear_path = std::getenv("ET_MODULE_LINEAR_PATH");
Result<FileDataLoader> linear_loader = FileDataLoader::from(linear_path);
ASSERT_EQ(linear_loader.error(), Error::Ok);
Expand Down Expand Up @@ -504,8 +526,8 @@ TEST_F(ProgramTest, LoadFromMutableSegment) {
const executorch_flatbuffer::Program* flatbuffer_program =
ProgramTestFriend::GetInternalProgram(&program.get());

// Expect 1 segment. 1 mutable segment and no constant segment.
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);
// Expect 2 segments. 1 mutable segment and 1 constant segment.
EXPECT_EQ(flatbuffer_program->segments()->size(), 2);

// Expect a mutable data segment.
EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1);
Expand Down

0 comments on commit 8afdc48

Please sign in to comment.