Skip to content

Commit

Permalink
String Tensor SplitToSequence fix (#19942)
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp authored and rachguo committed Mar 21, 2024
1 parent 0941cc7 commit 674c359
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
int num_remaining_splits = 0;
InlinedVector<int64_t> split_sizes;
const bool is_string_type = input.IsDataTypeString();
const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size();
const size_t element_size = input.DataType()->Size();

// figure out split_scalar or split_sizes
if (p_split_input) {
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
test.Run();
}

TEST(SequenceOpsTest, SplitToSequence_StringSplit) {
OpTester test("SplitToSequence", 11);
test.AddInput<std::string>("input", {3}, std::vector<std::string>({"Test string", "Another string", "A third and much longer string"}));

Check warning on line 447 in onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc:447: Lines should be <= 120 characters long [whitespace/line_length] [2]
int64_t axis = 0;
test.AddAttribute("axis", axis);
SeqTensors<std::string> output;
output.AddTensor({1}, {"Test string"});
output.AddTensor({1}, {"Another string"});
output.AddTensor({1}, {"A third and much longer string"});
test.AddSeqOutput("S2", output);
test.Run();
}

TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
OpTester test("SplitToSequence", 11);
test.AddInput<float>("input", {5, 2}, GetConsecutiveVector<float>(1.f, 10));
Expand Down

0 comments on commit 674c359

Please sign in to comment.